使用Keras中在ImageNet上预训练的VGG16模型作为特征提取器，将其应用于猫狗分类，同时进行了数据增强。代码如下：
from keras import models
from keras import layers
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

# Convolutional base of VGG16 loaded with ImageNet weights. The original
# fully-connected classifier is dropped (include_top=False) so a new
# binary head can be attached; inputs are 150x150 RGB images.
conv_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(150, 150, 3))

# Build the model: VGG16 conv base followed by a small dense classifier.
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
# Single sigmoid unit: one probability for the two-class (cat/dog) task.
model.add(layers.Dense(1, activation='sigmoid'))

# summary() prints the table itself and returns None; wrapping it in
# print() would only emit an extra "None" line.
model.summary()

print('Trainable weight tensors before freezing the conv base:',
      len(model.trainable_weights))
# Freeze the convolutional base so only the new dense head is trained.
# This has to happen before model.compile() for the change to take effect.
conv_base.trainable = False
print('Trainable weight tensors after freezing the conv base:',
      len(model.trainable_weights))

# Cats-vs-dogs image set: 2000 training images, 1000 each for validation
# and test. flow_from_directory expects one subdirectory per class.
train_dir = './datasets/train/'
validation_dir = './datasets/validation'
# NOTE(review): test_dir is not used in the visible part of this script.
test_dir = './datasets/test'
# Data augmentation, applied to the training set only: random rotations,
# shifts, shears, zooms and horizontal flips; pixels rescaled to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation images must not be augmented; they are only rescaled.
test_datagen = ImageDataGenerator(rescale=1./255)

BATCH_SIZE = 20
NUM_EPOCHS = 30

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=BATCH_SIZE,
    class_mode='binary',  # 0/1 labels, matching the sigmoid output unit
)
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=BATCH_SIZE,
    class_mode='binary',
)

# `learning_rate` replaces the deprecated `lr` keyword (rejected by Keras 3).
model.compile(optimizer=optimizers.RMSprop(learning_rate=2e-5),
              loss='binary_crossentropy',
              metrics=['acc'])

# model.fit accepts generators directly; fit_generator is deprecated and
# has been removed from recent TensorFlow/Keras releases.
# len(generator) is the number of batches in one pass over the directory
# (2000 / 20 = 100 for training, 1000 / 20 = 50 for validation here), so
# the step counts stay correct if BATCH_SIZE or the dataset size changes.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=NUM_EPOCHS,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)
model.save('cat_and_dog_pre_train_gpu.h5')

# The history keys follow the metric name given to compile() ('acc').
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)

# Accuracy curves: dots for training, solid line for validation.
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

# Loss curves in a separate figure.
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()