自定义模型

Note

自定义模型就像自定义层,通过继承keras.Model来实现我们想要的功能。

残差网络

我们通过实现残差网络来进行阐述如何自定义模型。

keras.Model是keras.layers.Layer的子类,因此可以像自定义层一样定义和使用模型。

但是模型有些额外的功能,包括compile()、fit()、evaluate()、predict()方法。

from tensorflow import keras

class ResidualBlock(keras.layers.Layer):
    # 自定义残差块
    def __init__(self, n_layers, n_neurons, **kwargs):
        # kwargs handles standard args (e.g., name)
        super().__init__(**kwargs)
        self.hidden = [keras.layers.Dense(n_neurons, activation='relu') 
                       for _ in range(n_layers)]
        
    def call(self, inputs):
        Z = inputs
        for layer in self.hidden:
            Z = layer(Z)
        return inputs + Z
class ResidualRegressor(keras.models.Model):
    # 自定义残差网络
    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden1 = keras.layers.Dense(30, activation="relu")
        # 两个残差块
        self.block1 = ResidualBlock(2, 30)
        self.block2 = ResidualBlock(2, 30)
        self.out = keras.layers.Dense(output_dim)

    def call(self, inputs):
        Z = self.hidden1(inputs)
        # 重复残差两次
        for _ in range(2):
            Z = self.block1(Z)
        Z = self.block2(Z)
        return self.out(Z)

Note

同自定义层一样,与input_shape相关的initialization需在build()中实现