Keras2.0 使用预训练模型权重创建网络的两种方式

Keras2.0 使用预训练模型权重创建网络的两种方式

使用预训练模型进行迁移学习,是一种常见的资源利用,加速训练的方法。
在图形分类问题中常用的网络有VGG16、VGG19、ResNet50等
Keras内置了这些常见经典算法的imagenet数据集上的训练权重。

迁移学习的基本模式是:

  • 加载预训练权重,但是忽略最上层的全连接网络(Network(weights=”imagenet”, include_top=False,
    input_tensor=Input(shape=(64, 64, 3)))).
  • 在加载的模型上根据自己的需要添加网络层
  • 设置预训练权重的各层(layer)训练中不更新(layer.trainable=False)。
  • 重新训练

如何把预加载权重的网络和自己新建的网络连接在一起,有两种常见的写法。
写法一:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#加载预训练模型
pretrainedModel = VGG16(weights="imagenet", include_top=False,
input_tensor=Input(shape=(64, 64, 3)))
#添加自己的top层
top_model = pretrainedModel.output
top_model = Flatten(name='flatten')(top_model)
top_model = Dense(256, activation='relu')(top_model)
top_model = Dropout(0.5)(top_model)
top_model = Dense(2, activation='sigmoid')(top_model)
#构建完整的模型
model = Model(inputs=pretrainedModel.input, outputs=top_model)
#锁定预训练层权重
for layer in modelResnet50.layers:
layer.trainable = False
#编译模型
opt = SGD(lr=INIT_LR, momentum=0.9)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
#训练模型
callbacks = [LearningRateScheduler(poly_decay)]
H = model.fit_generator(trainGen, steps_per_epoch=totalTrain // BS,
validation_data=valGen, validation_steps=totalVAl // BS,
epochs=NUM_EPOCHS, callbacks=callbacks)

方法二:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#加载预训练模型
input_tensor = Input(shape=(64,64,3))
modelVgg16 = applications.VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)
#重建一个顶层网络
top_model.add(Flatten(input_shape=modelVgg16.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(1, activation='sigmoid'))
#创建模型
model = Model(inputs=modelVgg16.input, output=top_model(modelVgg16.output))

#锁定预训练层
for layer in model.layers[:25]:
layer.trainable = False

# compile the model with a SGD/momentum optimizer
# and a very slow learning rate.
model.compile(loss='binary_crossentropy',
optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
metrics=['accuracy'])
model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples // batch_size,
validation_data=validation_generator,
validation_steps = nb_validation_samples // batch_size,
epochs=epochs)

第二种方式更简单一些,1.0版本里面可以直接add模型,2.0版本里面没有响应的api接口了。所以需要重新生成一个模型对象。