自定义训练循环
Note
在极少数情况下，`fit()` 方法可能不够灵活、无法满足需要，这时就需要编写自定义训练循环。
准备工作
下面用一个例子来说明。
import tensorflow as tf
from tensorflow import keras
import utils
# Load the previously saved Keras model from the "my_housing_model" path.
model = keras.models.load_model("my_housing_model")
# Fetch the California housing data already split into
# train / validation / test pairs by the project's utils helper.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    utils.load_california_housing()
)
import numpy as np
def random_batch(X, y, batch_size=32):
    """Return a random mini-batch drawn from the paired arrays ``X`` and ``y``.

    Row indices are sampled uniformly from ``[0, len(X))`` with replacement,
    so a batch may contain the same row more than once.

    Args:
        X: Indexable feature array (rows are samples).
        y: Indexable target array aligned row-for-row with ``X``.
        batch_size: Number of rows to draw.

    Returns:
        A ``(X_batch, y_batch)`` tuple selected with the same indices.
    """
    picks = np.random.randint(0, len(X), batch_size)
    X_batch = X[picks]
    y_batch = y[picks]
    return X_batch, y_batch
def print_status_bar(iteration, total, loss, metrics=None):
    """Redraw the current console line with training progress.

    Args:
        iteration: Progress count shown on the left of the slash (the
            training loop passes the number of samples seen so far).
        total: Count that marks a finished epoch.
        loss: Running-loss tracker exposing ``.name`` and ``.result()``.
        metrics: Optional list of additional trackers with the same interface.

    The line begins with a carriage return so successive calls overwrite it
    in place; once ``iteration`` is no longer below ``total`` the line is
    ended with a newline so the next epoch starts on a fresh line.
    """
    trackers = [loss] + (metrics or [])
    fields = []
    for tracker in trackers:
        fields.append("{}: {:.4f}".format(tracker.name, tracker.result()))

    if iteration < total:
        line_end = ""
    else:
        line_end = "\n"

    print("\r{}/{} - ".format(iteration, total), " - ".join(fields),
          sep="", end=line_end)
# Loop sizing: an epoch is len(X_train) // batch_size steps, so any
# remainder samples are not counted toward the step total.
n_epochs = 5
batch_size = 32
n_steps = len(X_train) // batch_size

# Running trackers: a mean of the per-step loss plus the reported metrics.
mean_loss = keras.metrics.Mean()
metrics = [keras.metrics.MeanAbsoluteError()]

# Optimizer and loss function used by the manual gradient step; the loop
# applies tf.reduce_mean to the loss function's output.
optimizer = keras.optimizers.Nadam(learning_rate=0.01)
loss_fn = keras.losses.mean_squared_error
训练循环
# Custom training loop: each epoch runs n_steps random mini-batches, applies
# a manual gradient step via tf.GradientTape, and reports the running loss
# and metrics on a single in-place status line.
n_samples = len(y_train)
for epoch in range(1, n_epochs + 1):
    print("Epoch {}/{}".format(epoch, n_epochs))
    for step in range(1, n_steps + 1):
        # Pass batch_size explicitly so the sampled batch size always matches
        # the batch_size used for n_steps and the progress counter below.
        X_batch, y_batch = random_batch(X_train, y_train, batch_size)
        # Record the forward pass so the tape can differentiate the loss.
        with tf.GradientTape() as tape:
            y_pred = model(X_batch)
            # Align the targets with the predictions before computing the
            # loss: the functional mean_squared_error averages over the last
            # axis and (in tf.keras 2.x) does not align ranks, so 1-D targets
            # of shape (batch,) against (batch, 1) predictions would broadcast
            # into a (batch, batch) matrix. A size mismatch now raises instead.
            y_true = tf.reshape(tf.cast(y_batch, y_pred.dtype), tf.shape(y_pred))
            main_loss = tf.reduce_mean(loss_fn(y_true, y_pred))
            # Add any extra losses the model registers (e.g. regularization).
            loss = tf.add_n([main_loss] + model.losses)
        # Backpropagate and take one optimizer step.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Accumulate the running loss and metrics for this epoch.
        mean_loss(loss)
        for metric in metrics:
            metric(y_batch, y_pred)
        print_status_bar(step * batch_size, n_samples, mean_loss, metrics)
    # The last step only reaches n_samples (and prints its own newline) when
    # n_samples is a multiple of batch_size; otherwise finish the epoch line
    # here, so the final status is printed exactly once.
    if n_steps * batch_size < n_samples:
        print_status_bar(n_samples, n_samples, mean_loss, metrics)
    # Clear the trackers so each epoch reports its own averages.
    # NOTE(review): newer Keras versions prefer reset_state(); kept
    # reset_states() for compatibility with the TF version this file targets.
    for metric in [mean_loss] + metrics:
        metric.reset_states()
Epoch 1/5
11610/11610 - mean: 1.5009 - mean_absolute_error: 0.9104
Epoch 2/5
11610/11610 - mean: 2.5215 - mean_absolute_error: 0.9303
Epoch 3/5
11610/11610 - mean: 1.6988 - mean_absolute_error: 0.9081
Epoch 4/5
11610/11610 - mean: 1.5036 - mean_absolute_error: 0.8840
Epoch 5/5
11610/11610 - mean: 1.3136 - mean_absolute_error: 0.8738