这篇教程用 TensorFlow 2 以尽量少的代码实现了 ResNet18 和 ResNet50 网络,写得比较实用,希望能帮到您。这两个模型我已经在自己的数据集(一个 4 分类、共 1600 张 118*118 的上下左右箭头图片数据集)上试验过了,均可以达到 100% 的 accuracy,因此模型应该是没有问题的。好了,话不多说,代码奉上。 resnet18:
import tensorflow as tf import numpy as np import os from tensorflow import keras from tensorflow.keras import layers,Sequential
class prepare(layers.Layer):
    """Network stem: 3x3 conv (64 filters) -> batch norm -> ReLU -> 2x2 max-pool.

    The max-pool uses stride 2, so the spatial resolution is halved.
    Extra keyword arguments (e.g. ``name``) are forwarded to
    ``layers.Layer`` so the layer can be named and re-created from its config.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = layers.Conv2D(64, (3, 3), strides=1, padding="same")
        self.bn = layers.BatchNormalization()
        self.Relu = layers.Activation('relu')
        self.mp = layers.MaxPool2D(pool_size=(2, 2), strides=2)

    def call(self, inputs, training=None):
        """Apply conv -> BN -> ReLU -> max-pool to ``inputs``."""
        x = self.conv1(inputs)
        # Forward the training flag explicitly so batch norm switches between
        # batch statistics and moving averages as intended.
        x = self.bn(x, training=training)
        x = self.Relu(x)
        x = self.mp(x)
        return x


class BasicBlock(layers.Layer):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    Args:
        filter_num: Number of filters of both 3x3 convolutions.
        stride: Stride of the first convolution. When it is not 1, the
            shortcut is a 1x1 convolution with the same stride so that the
            shapes of both branches match; otherwise the shortcut is the
            identity, which requires the input to already have
            ``filter_num`` channels.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filter_num, stride=1, **kwargs):
        super().__init__(**kwargs)
        # Kept so that get_config() can describe this layer.
        self.filter_num = filter_num
        self.stride = stride

        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()

        if stride != 1:
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:
            self.downsample = lambda x: x

    def call(self, input, training=None):
        """Return ``relu(F(input) + shortcut(input))``."""
        out = self.conv1(input)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(input)
        output = layers.add([out, identity])
        output = tf.nn.relu(output)
        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"filter_num": self.filter_num, "stride": self.stride})
        return config


def get_model(num_classes, input_shape=(112, 112, 3), blocks_per_stage=(2, 2, 2, 2)):
    """Build a ResNet-18 style classifier as a ``keras.Model``.

    The model is: ``prepare`` stem -> four stages of ``BasicBlock`` with
    64/128/256/512 filters -> global average pooling -> ``Dense(num_classes)``.
    The first block of every stage except the first uses stride 2.

    The returned model outputs the raw ``Dense`` values (logits); no softmax
    is applied, so train it with a loss configured with ``from_logits=True``.

    Args:
        num_classes: Number of units of the final ``Dense`` layer.
        input_shape: Shape of one input image, ``(height, width, channels)``.
        blocks_per_stage: Number of ``BasicBlock`` layers in each of the four
            stages. The default reproduces the original 2/2/2/2 layout.

    Returns:
        A ``keras.Model`` mapping images to ``num_classes`` logits.

    Raises:
        ValueError: If ``blocks_per_stage`` does not have exactly four entries.
    """
    stage_filters = (64, 128, 256, 512)
    if len(blocks_per_stage) != len(stage_filters):
        raise ValueError(
            "blocks_per_stage must have %d entries, got %d"
            % (len(stage_filters), len(blocks_per_stage))
        )

    input_image = layers.Input(shape=input_shape, dtype="float32")
    output = prepare()(input_image)

    for stage_index, (filter_num, num_blocks) in enumerate(
            zip(stage_filters, blocks_per_stage)):
        for block_index in range(num_blocks):
            # Only the first block of stages 2-4 downsamples.
            stride = 2 if (stage_index > 0 and block_index == 0) else 1
            output = BasicBlock(filter_num, stride)(output)

    output = layers.GlobalAveragePooling2D()(output)
    output = layers.Dense(num_classes)(output)
    # The original code had `output-layers.Activation('relu')(output)` here:
    # a subtraction whose result was discarded, so it never affected the
    # model. It is removed; the output stays the Dense logits as before.
    return keras.Model(inputs=input_image, outputs=output)
resnet50:
import tensorflow as tf import numpy as np import os from tensorflow import keras from tensorflow.keras import layers,Sequential
class prepare(layers.Layer):
    """Network stem: 3x3 conv (64 filters) -> batch norm -> ReLU -> 2x2 max-pool.

    The max-pool uses stride 2, so the spatial resolution is halved.
    Extra keyword arguments (e.g. ``name``) are forwarded to
    ``layers.Layer`` so the layer can be named and re-created from its config.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = layers.Conv2D(64, (3, 3), strides=1, padding="same")
        self.bn = layers.BatchNormalization()
        self.Relu = layers.Activation('relu')
        self.mp = layers.MaxPool2D(pool_size=(2, 2), strides=2)

    def call(self, inputs, training=None):
        """Apply conv -> BN -> ReLU -> max-pool to ``inputs``."""
        x = self.conv1(inputs)
        # Forward the training flag explicitly so batch norm switches between
        # batch statistics and moving averages as intended.
        x = self.bn(x, training=training)
        x = self.Relu(x)
        x = self.mp(x)
        return x


class block(layers.Layer):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv + shortcut.

    The last convolution has ``filter_num * 4`` filters, so the block outputs
    four times as many channels as ``filter_num``.

    Args:
        filter_num: Number of filters of the first two convolutions.
        stride: Stride of the 3x3 convolution.
        is_first: Set to true for a block whose input does not yet have
            ``filter_num * 4`` channels even though ``stride`` is 1.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).

    When ``stride`` is not 1 or ``is_first`` is true, the shortcut is a 1x1
    convolution with ``filter_num * 4`` filters and the same stride;
    otherwise it is the identity.
    """

    def __init__(self, filter_num, stride=1, is_first=False, **kwargs):
        super().__init__(**kwargs)
        # Kept so that get_config() can describe this layer.
        self.filter_num = filter_num
        self.stride = stride
        self.is_first = is_first

        self.conv1 = layers.Conv2D(filter_num, (1, 1), strides=1)
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(filter_num * 4, (1, 1), strides=1)
        self.bn3 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')

        if stride != 1 or is_first:
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num * 4, (1, 1), strides=stride))
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        """Return ``relu(F(inputs) + shortcut(inputs))``."""
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x, training=training)

        identity = self.downsample(inputs)
        output = layers.add([x, identity])
        output = tf.nn.relu(output)
        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "filter_num": self.filter_num,
            "stride": self.stride,
            "is_first": self.is_first,
        })
        return config


def get_model(num_classes, input_shape=(112, 112, 3), blocks_per_stage=(3, 3, 3, 3)):
    """Build a bottleneck-ResNet classifier as a ``keras.Model``.

    The model is: ``prepare`` stem -> four stages of bottleneck ``block``
    layers with 64/128/256/512 base filters -> global average pooling ->
    ``Dense(num_classes)``. The very first block is created with
    ``is_first=True``; the first block of each later stage uses stride 2.

    The returned model outputs the raw ``Dense`` values (logits); no softmax
    is applied, so train it with a loss configured with ``from_logits=True``.

    Args:
        num_classes: Number of units of the final ``Dense`` layer.
        input_shape: Shape of one input image, ``(height, width, channels)``.
        blocks_per_stage: Number of bottleneck blocks in each of the four
            stages. The default (3, 3, 3, 3) reproduces the original layout
            of this function. NOTE: the ResNet-50 of the original paper uses
            (3, 4, 6, 3); pass that value for the canonical depth.

    Returns:
        A ``keras.Model`` mapping images to ``num_classes`` logits.

    Raises:
        ValueError: If ``blocks_per_stage`` does not have exactly four entries.
    """
    stage_filters = (64, 128, 256, 512)
    if len(blocks_per_stage) != len(stage_filters):
        raise ValueError(
            "blocks_per_stage must have %d entries, got %d"
            % (len(stage_filters), len(blocks_per_stage))
        )

    input_image = layers.Input(shape=input_shape, dtype="float32")
    out = prepare()(input_image)

    for stage_index, (filter_num, num_blocks) in enumerate(
            zip(stage_filters, blocks_per_stage)):
        for block_index in range(num_blocks):
            starts_stage = block_index == 0
            # Stage 1 keeps the resolution but must widen the channels
            # (is_first); stages 2-4 downsample in their first block.
            stride = 2 if (starts_stage and stage_index > 0) else 1
            is_first = starts_stage and stage_index == 0
            out = block(filter_num, stride=stride, is_first=is_first)(out)

    out = layers.GlobalAveragePooling2D()(out)
    out = layers.Dense(num_classes)(out)
    # The original code had `out-layers.Activation('relu')(out)` here:
    # a subtraction whose result was discarded, so it never affected the
    # model. It is removed; the output stays the Dense logits as before.
    return keras.Model(inputs=input_image, outputs=out)
需要注意的是,prepare 类是网络最前面对输入做处理的一层(卷积 + BN + ReLU + 最大池化)。论文中这一层先使用 7x7 的卷积核,我这里使用的是 3x3,因为我输入的图片只有
112*112 大小;如果需要修改的话,改 prepare 就行了。
以上就是 ResNet 的 TensorFlow 2 实现,代码是不是可以接受呢,是不是能看懂呢?
只要输入 model=get_model(num_classes) 就能得到模型啦,num_classes 是你分类的数量。
之后定义好损失函数(模型最后一层是 Dense,输出的是未经 softmax 的 logits,损失函数需要设置 from_logits=True),使用 fit 训练就可以了。 who is the best in CIFAR-10 ? 基于Keras的ResNet实现 |