A toy ResNet model
In addition to models with multiple inputs and outputs, the functional API makes it easy to build non-linear connectivity topologies: models whose layers are not connected sequentially, which the Sequential API cannot handle.
A common use case for this is residual connections. Let's build a toy ResNet model for CIFAR10 to demonstrate this:
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(32, 32, 3), name="img")

# Block 1: two unpadded convolutions followed by max pooling.
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
block_1_output = layers.MaxPooling2D(3)(x)

# Block 2: two same-padded convolutions, then add the block input back in
# (the residual connection).
x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_2_output = layers.add([x, block_1_output])

# Block 3: same pattern, with block 2's output as the residual input.
x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_3_output = layers.add([x, block_2_output])

# Classification head: the final Dense layer outputs 10 logits (no softmax).
x = layers.Conv2D(64, 3, activation="relu")(block_3_output)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10)(x)

model = keras.Model(inputs, outputs, name="toy_resnet")
model.summary()
Model: "toy_resnet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
img (InputLayer)                [(None, 32, 32, 3)]  0
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 30, 30, 32)   896         img[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 28, 28, 64)   18496       conv2d_8[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 9, 9, 64)     0           conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 9, 9, 64)     36928       max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 9, 9, 64)     36928       conv2d_10[0][0]
__________________________________________________________________________________________________
add (Add)                       (None, 9, 9, 64)     0           conv2d_11[0][0]
                                                                 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 9, 9, 64)     36928       add[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 9, 9, 64)     36928       conv2d_12[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, 9, 9, 64)     0           conv2d_13[0][0]
                                                                 add[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 7, 7, 64)     36928       add_1[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 64)           0           conv2d_14[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 256)          16640       global_average_pooling2d[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 256)          0           dense_6[0][0]
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 10)           2570        dropout[0][0]
==================================================================================================
Total params: 223,242
Trainable params: 223,242
Non-trainable params: 0
__________________________________________________________________________________________________
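The output shapes and parameter counts in the summary above can be checked by hand. The following is a minimal sketch in plain Python, not Keras code: the helper names (conv_out, pool_out, conv_params, dense_params) are made up for illustration, and it assumes stride-1 convolutions and a pooling stride equal to the pool size, which is what the layers above use when no stride is given.

```python
def conv_out(size, kernel, padding="valid"):
    """Spatial size after a stride-1 convolution."""
    # "same" padding keeps the size; "valid" loses kernel - 1 pixels.
    return size if padding == "same" else size - kernel + 1


def pool_out(size, pool):
    """Spatial size after unpadded pooling whose stride equals the pool size."""
    return (size - pool) // pool + 1


def conv_params(kernel, in_channels, out_channels):
    """Weights plus one bias per output channel of a 2D convolution."""
    return (kernel * kernel * in_channels + 1) * out_channels


def dense_params(in_units, out_units):
    """Weights plus one bias per output unit of a dense layer."""
    return (in_units + 1) * out_units


size = 32                      # CIFAR10 images are 32x32
size = conv_out(size, 3)       # conv2d_8: 32 -> 30
size = conv_out(size, 3)       # conv2d_9: 30 -> 28
size = pool_out(size, 3)       # max_pooling2d_2: 28 -> 9

# Inside the residual blocks the size must not change, otherwise the two
# inputs of layers.add would have different shapes.
block_size = conv_out(size, 3, padding="same")   # stays 9

final = conv_out(block_size, 3)                  # conv2d_14: 9 -> 7

total_params = (
    conv_params(3, 3, 32)          # conv2d_8
    + conv_params(3, 32, 64)       # conv2d_9
    + 5 * conv_params(3, 64, 64)   # conv2d_10 through conv2d_14
    + dense_params(64, 256)        # dense_6
    + dense_params(256, 10)        # dense_7
)

print(size, block_size, final, total_params)
```

The total matches the 223,242 parameters reported by model.summary(). Because the convolutions inside each residual block use padding="same", both inputs to layers.add have shape (9, 9, 64), which is what makes the elementwise sum valid.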
Plot the model:
keras.utils.plot_model(model, "mini_resnet.png", show_shapes=True)
Now train the model:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["acc"],
)
# We restrict the data to the first 1000 samples so as to limit execution time
# on Colab. Try to train on the entire dataset until convergence!
model.fit(x_train[:1000], y_train[:1000], batch_size=64, epochs=1, validation_split=0.2)
13/13 [==============================] - 2s 87ms/step - loss: 2.3145 - acc: 0.1124 - val_loss: 2.3046 - val_acc: 0.1150
<tensorflow.python.keras.callbacks.History at 0x147c01650>
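A loss of about 2.30 with an accuracy near 0.11 is what a barely trained 10-class classifier is expected to produce: when the logits are roughly equal, softmax gives each class a probability of about 0.1, and the cross-entropy is ln(10) ≈ 2.3026. The sketch below shows, in plain Python, what a categorical cross-entropy computed from logits does for a single sample. The function name is hypothetical and this is an illustration of the formula, not the Keras implementation.

```python
import math


def cross_entropy_from_logits(logits, one_hot):
    """Cross-entropy between a one-hot label and softmax(logits)."""
    # Subtract the maximum logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_sum for z in logits]
    return -sum(t * lp for t, lp in zip(one_hot, log_probs))


# Ten equal logits: the model has no preference among the classes.
uniform_logits = [0.0] * 10
label = [0.0] * 10
label[3] = 1.0  # arbitrary true class, chosen for illustration

chance_loss = cross_entropy_from_logits(uniform_logits, label)

# A large logit on the true class gives a much smaller loss.
confident_logits = list(uniform_logits)
confident_logits[3] = 5.0
confident_loss = cross_entropy_from_logits(confident_logits, label)

print(chance_loss, confident_loss)
```

Training on more data for more epochs should move the loss below this chance-level value, as the comment in the training code suggests.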