Pretraining + Fine-Tuning

Note

Using a pretrained model is like standing on the shoulders of giants, and many models today follow the pretrain-then-fine-tune pattern.
In this section we take a model trained on Fashion-MNIST and adapt it to MNIST.

Getting the Model and Data

from tensorflow import keras

# The model previously trained on Fashion-MNIST
model_A = keras.models.load_model("my_fashion_mnist_model")
# Reuse the pretrained layers; train a new output layer from scratch
model_B_on_A = keras.models.Sequential(model_A.layers[:-1])
# softmax (not sigmoid) is the standard activation for a 10-class output
model_B_on_A.add(keras.layers.Dense(10, activation="softmax"))
# Load the MNIST dataset
(X_train_val, y_train_val), (X_test, y_test) = keras.datasets.mnist.load_data()

X_val, X_train = X_train_val[:5000] / 255., X_train_val[5000:] / 255.
y_val, y_train = y_train_val[:5000], y_train_val[5000:]
X_test = X_test / 255.
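One caveat worth noting: `Sequential(model_A.layers[:-1])` reuses `model_A`'s layer objects rather than copying them, so training `model_B_on_A` will also change `model_A`'s weights. If the original model should stay intact, clone it first. A minimal sketch, using a hypothetical stand-in architecture in place of the model loaded from disk:

```python
from tensorflow import keras

# Hypothetical stand-in for the pretrained model (the real model_A is
# loaded from disk in the text above).
model_A = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(100, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])

# clone_model copies only the architecture; the weights must be copied
# separately via get_weights/set_weights.
model_A_clone = keras.models.clone_model(model_A)
model_A_clone.set_weights(model_A.get_weights())

# Build the new model from the clone so model_A's weights stay intact.
model_B_on_A = keras.models.Sequential(model_A_clone.layers[:-1])
model_B_on_A.add(keras.layers.Dense(10, activation="softmax"))
```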

Freezing the Pretrained Layers

# Freeze the pretrained layers
for layer in model_B_on_A.layers[:-1]:
    layer.trainable = False

# Compile
model_B_on_A.compile(loss="sparse_categorical_crossentropy", 
                     optimizer=keras.optimizers.SGD(learning_rate=1e-2),
                     metrics=["accuracy"])
# Train
history = model_B_on_A.fit(X_train, y_train, 
                           epochs=5,
                           validation_data=(X_val, y_val))
Epoch 1/5
1719/1719 [==============================] - 2s 887us/step - loss: 1.4689 - accuracy: 0.5684 - val_loss: 1.1026 - val_accuracy: 0.7056
Epoch 2/5
1719/1719 [==============================] - 1s 809us/step - loss: 1.0053 - accuracy: 0.7226 - val_loss: 0.9022 - val_accuracy: 0.7530
Epoch 3/5
1719/1719 [==============================] - 1s 805us/step - loss: 0.8703 - accuracy: 0.7576 - val_loss: 0.8063 - val_accuracy: 0.7794
Epoch 4/5
1719/1719 [==============================] - 1s 814us/step - loss: 0.7954 - accuracy: 0.7759 - val_loss: 0.7435 - val_accuracy: 0.8024
Epoch 5/5
1719/1719 [==============================] - 1s 818us/step - loss: 0.7451 - accuracy: 0.7887 - val_loss: 0.7010 - val_accuracy: 0.8128
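Freezing can be verified by inspecting the model's trainable weights. A minimal sketch with a hypothetical small model (layer sizes are illustrative, not the ones used above):

```python
from tensorflow import keras

# A small hypothetical model to illustrate the effect of freezing.
model = keras.models.Sequential([
    keras.layers.Dense(4, activation="relu", input_shape=(8,)),
    keras.layers.Dense(2, activation="softmax"),
])

# Freeze everything except the output layer, as done above.
for layer in model.layers[:-1]:
    layer.trainable = False

# Only the output layer's kernel and bias are left trainable;
# the frozen layer's kernel and bias move to non_trainable_weights.
print(len(model.trainable_weights))      # 2
print(len(model.non_trainable_weights))  # 2
```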

Fine-Tuning the Pretrained Layers

After training with the pretrained layers frozen, we can lift the restriction and fine-tune the whole network. Note that fine-tuning should use a smaller learning rate, so that large updates do not wreck the pretrained weights, and that changing a layer's trainable flag only takes effect after the model is recompiled.

# Unfreeze
for layer in model_B_on_A.layers[:-1]:
    layer.trainable = True

# Recompile with a smaller learning rate
model_B_on_A.compile(loss="sparse_categorical_crossentropy",
                     optimizer=keras.optimizers.SGD(learning_rate=1e-3),
                     metrics=["accuracy"])
# Train
history = model_B_on_A.fit(X_train, y_train, 
                           epochs=5,
                           validation_data=(X_val, y_val))
Epoch 1/5
1719/1719 [==============================] - 3s 2ms/step - loss: 0.5269 - accuracy: 0.8498 - val_loss: 0.4049 - val_accuracy: 0.8918
Epoch 2/5
1719/1719 [==============================] - 3s 2ms/step - loss: 0.3861 - accuracy: 0.8912 - val_loss: 0.3336 - val_accuracy: 0.9094
Epoch 3/5
1719/1719 [==============================] - 3s 2ms/step - loss: 0.3340 - accuracy: 0.9055 - val_loss: 0.2983 - val_accuracy: 0.9196
Epoch 4/5
1719/1719 [==============================] - 3s 2ms/step - loss: 0.3036 - accuracy: 0.9142 - val_loss: 0.2760 - val_accuracy: 0.9248
Epoch 5/5
1719/1719 [==============================] - 3s 2ms/step - loss: 0.2822 - accuracy: 0.9202 - val_loss: 0.2600 - val_accuracy: 0.9284
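After fine-tuning, the model would typically be evaluated on the held-out test set, i.e. `model_B_on_A.evaluate(X_test, y_test)` with the variables defined above. A self-contained sketch of the call, using a hypothetical untrained model and random data purely to illustrate the API:

```python
import numpy as np
from tensorflow import keras

# Hypothetical stand-in model and random data; in the text above the
# real call would be model_B_on_A.evaluate(X_test, y_test).
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="sgd",
              metrics=["accuracy"])

X_test = np.random.rand(32, 28, 28).astype("float32")
y_test = np.random.randint(0, 10, size=32)

# evaluate returns the loss followed by each compiled metric.
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
```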