V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
fanqieipnet
V2EX  ›  推广

用什么方法可以实现深度残差网络?

  •  
  •   fanqieipnet · 2020-12-01 18:01:54 +08:00 · 508 次点击
    这是一个创建于 1261 天前的主题,其中的信息可能已经有所发展或是发生改变。
    AlexNet,VGG,GoogLeNet 等网络模型的出现将神经网络的发展带入了几十层的阶段,研究人员发现网络的层数越深,越有可能获得更好的泛化能力。但是当模型加深以后,网络变得越来越难训练,这主要是由于梯度弥散现象造成的。在较深层数的神经网络中间,梯度信息由网络的末层逐层传向网络的首层时,传递的过程中会出现梯度接近于 0 的现象。网络层数越深,梯度弥散现象可能会越严重。用什么方法可以实现深度残差网络?今天番茄加速就来分析一下。

      为了解决这个问题,研究人员尝试给深层神经网络添加一种回退到浅层神经网络的机制。当深层神经网络可以轻松地回退到浅层神经网络时,深层神经网络可以获得与浅层神经网络相当的模型性能,而不至于更糟糕。

      通过在输入和输出之间添加一条直接连接的 Skip Connection 可以让神经网络具有回退的能力。以 VGG13 深度神经网络为例,假设观察到 VGG13 模型出现梯度弥散现象,而 10 层的网络模型并没有观测到梯度弥散现象,那么可以考虑在最后的两个卷积层添加 Skip Connection,通过这种方式网络模型可以自动选择是否经由这两个卷积层完成特征变换,还是直接跳过这两个卷积层而选择 Skip Connection,亦或结合两个卷积层和 Skip Connection 的输出。

       ResNet 通过在卷积层的输入和输出之间添加 Skip Connection 实现层数回退机制,输入𝑥通过两个卷积层,得到特征变换后的输出ℱ(𝑥),与输入𝑥进行对应元素的相加运算,得到最终输出:H(x) = x + f(x)

       ResNet 实现

       1.定义残差模块

      首先创建一个新类,在初始化阶段创建残差块中需要的卷积层,激活函数层等,首先新建 f(𝑥)卷积层:

       import tensorflow as tf

       from tensorflow import keras

       from tensorflow.keras import layers, Sequential

       class BasicBlock(layers.Layer):

      # 残差模块

       def __init__(self, filter_num, stride=1):

       super(BasicBlock, self).__init__()

      # 第一个卷积单元

       self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')

       self.bn1 = layers.BatchNormalization()

       self.relu = layers.Activation('relu')

      # 第二个卷积单元

       self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')

       self.bn2 = layers.BatchNormalization()

       if stride != 1:# 通过 1x1 卷积完成 shape 匹配

       self.downsample = Sequential()

       self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))

       else:# shape 匹配,直接连接

       self.downsample = lambda x:x

       def call(self, inputs, training=None):

      # [b, h, w, c],通过第一个卷积单元

       out = self.conv1(inputs)

       out = self.bn1(out)

       out = self.relu(out)

      # 通过第二个卷积单元

       out = self.conv2(out)

       out = self.bn2(out)

      # 通过 identity 模块

       identity = self.downsample(inputs)

      # 2 条路径输出直接相加

       output = layers.add([out, identity])

       output = tf.nn.relu(output) # 激活函数

       return output

       2.实现 ResNet 类

       class ResNet(keras.Model):

      # 通用的 ResNet 实现类

       def __init__(self, layer_dims, num_classes=10): # [2, 2, 2, 2]

       super(ResNet, self).__init__()

      # 根网络,预处理

       self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),

       layers.BatchNormalization(),

       layers.Activation('relu'),

       layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')

      ])

      # 堆叠 4 个 Block,每个 block 包含了多个 BasicBlock,设置步长不一样

       self.layer1 = self.build_resblock(64, layer_dims[0])

       self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)

       self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)

       self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

      # 通过 Pooling 层将高宽降低为 1x1

       self.avgpool = layers.GlobalAveragePooling2D()

      # 最后连接一个全连接层分类

       self.fc = layers.Dense(num_classes)

       def call(self, inputs, training=None):

      # 通过根网络

       x = self.stem(inputs)

      # 一次通过 4 个模块

       x = self.layer1(x)

       x = self.layer2(x)

       x = self.layer3(x)

       x = self.layer4(x)

      # 通过池化层

       x = self.avgpool(x)

      # 通过全连接层

       x = self.fc(x)

       return x

       def build_resblock(self, filter_num, blocks, stride=1):

      # 辅助函数,堆叠 filter_num 个 BasicBlock

       res_blocks = Sequential()

      # 只有第一个 BasicBlock 的步长可能不为 1,实现下采样

       res_blocks.add(BasicBlock(filter_num, stride))

       for _ in range(1, blocks):#其他 BasicBlock 步长都为 1

       res_blocks.add(BasicBlock(filter_num, stride=1))

       return res_blocks

      上面就是用 Tensorflow 实现深度残差网络的原理和方法,欢迎指正。
    目前尚无回复
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   5732 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 28ms · UTC 07:00 · PVG 15:00 · LAX 00:00 · JFK 03:00
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.