[Image Classification] [Deep Learning] [PyTorch] Inception-ResNet Model and Algorithm Explained
Table of Contents
- [Image Classification] [Deep Learning] [PyTorch] Inception-ResNet Model and Algorithm Explained
- Preface
- Inception-ResNet Explained
  - Inception-ResNet-V1
  - Inception-ResNet-V2
  - Scaling of the Residuals
  - Overall Model Structure of Inception-ResNet
- GoogLeNet (Inception-ResNet) PyTorch Code
  - Inception-ResNet-V1
  - Inception-ResNet-V2
- Complete Code
  - Inception-ResNet-V1
  - Inception-ResNet-V2
- Summary
Preface
GoogLeNet (Inception-ResNet) is an improved model proposed by Szegedy, Christian et al. of Google in "Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning" [AAAI-2017] [paper link]. Inspired by the strong performance of ResNet [reference] on very deep networks, the paper adds residual connections to the Inception structure and obtains two Inception-ResNet variants: the residual (shortcut) connection replaces the pooling branch of the original Inception block, and element-wise addition replaces the original concatenation, which speeds up the training of Inception networks.
Because InceptionV4, Inception-ResNet-v1, and Inception-ResNet-v2 all come from the same paper, many readers misunderstand InceptionV4 and assume it combines Inception modules with residual learning. In fact, InceptionV4 does not use residual learning at all; it largely follows the Inception v2/v3 design. Only Inception-ResNet-v1 and Inception-ResNet-v2 are the result of combining Inception modules with residual learning.
Inception-ResNet Explained
The core idea of Inception-ResNet is to fuse the Inception module with the ResNet residual connection so as to combine their respective strengths. The Inception module captures multi-scale features by running convolutions of several kernel sizes in parallel, while the ResNet residual connection alleviates vanishing and exploding gradients in deep networks and makes deep models easier to train. Inception-ResNet uses Inception modules similar to those of InceptionV4 [reference] and adds ResNet-style residual connections to them: each Inception-ResNet block consists of a multi-branch Inception path whose concatenated output is projected by a 1×1 convolution and then added to a shortcut (identity) branch. This design lets the model learn richer feature representations while allowing gradients to propagate more effectively during training.
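As a rough sketch of this pattern (illustrative only: the class name, branch widths, and kernel sizes here are placeholders; the concrete block variants used by Inception-ResNet are detailed in the sections below):

import torch
import torch.nn as nn

class InceptionResidualBlock(nn.Module):
    """Illustrative only: Inception-style branches fused by a residual sum."""
    def __init__(self, channels, scale=0.2):
        super().__init__()
        self.scale = scale
        # two simplified parallel "Inception" branches (channels assumed even)
        self.branch_1x1 = nn.Conv2d(channels, channels // 2, kernel_size=1)
        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(channels, channels // 2, kernel_size=1),
            nn.Conv2d(channels // 2, channels // 2, kernel_size=3, padding=1),
        )
        # 1x1 projection back to the input width so the sum is well defined
        self.project = nn.Conv2d(channels, channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = torch.cat([self.branch_1x1(x), self.branch_3x3(x)], dim=1)
        residual = self.project(residual)
        # the residual sum replaces the concatenation used by plain Inception blocks
        return self.relu(x + self.scale * residual)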
Inception-ResNet-V1
Inception-ResNet-v1: a structure with roughly the same computational cost as InceptionV3 [reference].
- Stem structure: the Stem of Inception-ResNet-V1 resembles the layers that precede the Inception block groups in the earlier InceptionV3 network.
  Convolutions not marked with a V use "SAME" padding, so at stride 1 the output spatial size equals the input size; convolutions marked with a V use "VALID" padding, so the output size depends on the kernel size and stride (see the short padding sketch after this list).
- Inception-resnet-A structure: a variant of the Inception-A block in InceptionV4; the trailing 1×1 convolution restores the channel count so that the main branch and the shortcut branch have exactly the same feature-map shape.
  In the Inception-resnet blocks, the residual (shortcut) connection replaces the pooling branch of the original Inception block, and element-wise addition of the residual replaces the original concatenation.
- Inception-resnet-B structure: a variant of the Inception-B block in InceptionV4; the 1×1 convolution keeps the main branch's feature-map shape identical to that of the shortcut branch.
- Inception-resnet-C structure: a variant of the Inception-C block in InceptionV4; the 1×1 convolution keeps the main branch's feature-map shape identical to that of the shortcut branch.
- Reduction-A structure: identical to the Reduction-A block in InceptionV4 except for the number of convolution kernels.
  k and l denote kernel counts; different networks use different k and l values in their Reduction-A blocks.
- Reduction-B structure:
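To illustrate the SAME/VALID distinction noted above, here is a quick check (a sketch only; the tensor shape and layer names are illustrative, and PyTorch expresses the padding choice through the padding argument rather than a padding-mode string):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 35, 35)

# "SAME"-style padding: padding=1 for a 3x3 kernel keeps 35x35 at stride 1
same_conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
print(same_conv(x).shape)      # torch.Size([1, 16, 35, 35])

# "VALID" padding: no padding, so a 3x3 kernel shrinks the map to 33x33
valid_conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=0)
print(valid_conv(x).shape)     # torch.Size([1, 16, 33, 33])

# "VALID" with stride 2 (the "V" convolutions in the stem): (35 - 3) // 2 + 1 = 17
valid_stride2 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=0)
print(valid_stride2(x).shape)  # torch.Size([1, 16, 17, 17])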
Inception-ResNet-V2
Inception-ResNet-v2: a structure whose computational cost roughly matches that of InceptionV4, but which trains faster than pure Inception-v4.
The overall framework of Inception-ResNet-v2 is the same as that of Inception-ResNet-v1. Its stem is identical to the stem of InceptionV4, and the remaining modules are structurally similar to those of Inception-ResNet-v1, only with more convolution kernels per layer.
- Stem structure: the stem of Inception-ResNet-v2 is identical to that of InceptionV4.
- Inception-resnet-A structure: a variant of the Inception-A block in InceptionV4; the 1×1 convolution keeps the main branch's feature-map shape identical to that of the shortcut branch.
- Inception-resnet-B structure: a variant of the Inception-B block in InceptionV4; the 1×1 convolution keeps the main branch's feature-map shape identical to that of the shortcut branch.
- Inception-resnet-C structure: a variant of the Inception-C block in InceptionV4; the 1×1 convolution keeps the main branch's feature-map shape identical to that of the shortcut branch.
- Reduction-A structure: identical to the Reduction-A block in InceptionV4 except for the number of convolution kernels.
  k and l denote kernel counts; different networks use different k and l values in their Reduction-A blocks (the values used in this article's code are listed after this list).
- Reduction-B structure:
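For reference, the (k, l, m, n) values adopted by the complete code later in this article (the defaults of Inception_ResNetv1 and Inception_ResNetv2) are:
- Inception-ResNet-V1: k=192, l=192, m=256, n=384
- Inception-ResNet-V2: k=256, l=256, m=384, n=384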
Scaling of the Residuals
If a single layer has a very large number of convolution kernels (beyond about 1000), the residual variants become unstable and the network starts to "die" early in training: after a few tens of thousands of iterations, the layers before the average pooling layer begin to output only zeros. Lowering the learning rate or adding extra BN layers does not prevent this. Scaling down the residual block's output before adding it to the shortcut branch stabilizes training.
In practice, the residual scaling factor is chosen between 0.1 and 0.3. Even where scaling is not strictly necessary, it does not appear to hurt the final accuracy, and it helps stabilize training.
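A minimal sketch of this stabilization trick (ScaledResidual and its arguments are illustrative; the module implementations below expose the same scale parameter directly on each Inception-resnet block):

import torch.nn as nn

class ScaledResidual(nn.Module):
    """Wrap any residual branch and damp its output before the addition."""
    def __init__(self, residual_branch, scale=0.2):  # scale typically in [0.1, 0.3]
        super().__init__()
        self.residual_branch = residual_branch
        self.scale = scale
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # shortcut + scaled residual keeps activations from blowing up early in training
        return self.relu(x + self.scale * self.residual_branch(x))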
Overall Model Structure of Inception-ResNet
The figure below, taken from the original paper, shows the detailed structure of Inception-ResNet-V1:
The figure below, taken from the original paper, shows the detailed structure of Inception-ResNet-V2:
Note that some of the channel counts annotated in the original paper for Inception-ResNet-V2 are wrong; be careful not to be misled by them when writing the code.
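One way to see the discrepancy is to add up the concatenated channels yourself; using the branch widths from the code below:
- Reduction-A (V2) outputs 384 (the n-channel 3×3 branch) + 384 (the k-l-m branch, m = 384) + 384 (the max-pooled input) = 1152 channels, so the Inception-resnet-B blocks operate on 1152 channels rather than the 1154 annotated in the paper's figure.
- Reduction-B (V2) outputs 384 + 288 + 320 + 1152 (the max-pooled input) = 2144 channels, so the Inception-resnet-C blocks operate on 2144 channels rather than 2048.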
The two versions share the same overall layout; the specific Stem, Inception blocks, and Reduction blocks differ slightly.
For image classification, Inception-ResNet-V1 and Inception-ResNet-V2 are split into two parts: a backbone consisting mainly of the Stem module, the Inception-resnet modules, and pooling layers, and a classifier consisting of fully connected layers.
GoogLeNet (Inception-ResNet) PyTorch Code
Inception-ResNet-V1
Convolution block: convolution layer + BN layer + activation function
# Convolution block: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
Stem module: convolution blocks + pooling layer
# Stem: BasicConv2d + MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()
        # conv3x3(32, stride 2, valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3x3(32, valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3x3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        # maxpool3x3(stride 2, valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)
        # conv1x1(80)
        self.conv5 = BasicConv2d(64, 80, kernel_size=1)
        # conv3x3(192, valid)
        self.conv6 = BasicConv2d(80, 192, kernel_size=3)
        # conv3x3(256, stride 2, valid)
        self.conv7 = BasicConv2d(192, 256, kernel_size=3, stride=2)

    def forward(self, x):
        x = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x = self.conv7(self.conv6(self.conv5(x)))
        return x
Inception_ResNet-A module: convolution blocks + residual connection
# Inception_ResNet_A: BasicConv2d + residual addition
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(32) + conv3x3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1x1(32) + conv3x3(32) + conv3x3(32)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1x1(256): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch3x3 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)
Inception_ResNet-B module: convolution blocks + residual connection
# Inception_ResNet_B: BasicConv2d + residual addition
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(128)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(128) + conv1x7(128) + conv7x1(128)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1x1(896): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)
Inception_ResNet-C module: convolution blocks + residual connection
# Inception_ResNet_C: BasicConv2d + residual addition
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply ReLU after the residual addition
        self.activation = activation
        # conv1x1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(192) + conv1x3(192) + conv3x1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1x1(1792): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res
redutionA module: convolution blocks + pooling layer
# redutionA: BasicConv2d + MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3x3(n, stride 2, valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1x1(k) + conv3x3(l) + conv3x3(m, stride 2, valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3x3(stride 2, valid)
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate the branches
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)
redutionB module: convolution blocks + pooling layer
# redutionB: BasicConv2d + MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1x1(256) + conv3x3(384, stride 2, valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1x1(256) + conv3x3(256, stride 2, valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1x1(256) + conv3x3(256) + conv3x3(256, stride 2, valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3x3(stride 2, valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)
Inception-ResNet-V2
Except for the Stem, the modules of Inception-ResNet-V2 are structurally identical to those of Inception-ResNet-V1.
Convolution block: convolution layer + BN layer + activation function
# Convolution block: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
Stem module: convolution blocks + pooling layer
# Stem: BasicConv2d + MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()
        # conv3x3(32, stride 2, valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3x3(32, valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3x3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        # maxpool3x3(stride 2, valid) & conv3x3(96, stride 2, valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv4 = BasicConv2d(64, 96, kernel_size=3, stride=2)
        # conv1x1(64) + conv3x3(96, valid)
        self.conv5_1_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_1_2 = BasicConv2d(64, 96, kernel_size=3)
        # conv1x1(64) + conv7x1(64) + conv1x7(64) + conv3x3(96, valid)
        self.conv5_2_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_2_2 = BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0))
        self.conv5_2_3 = BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3))
        self.conv5_2_4 = BasicConv2d(64, 96, kernel_size=3)
        # conv3x3(192, stride 2, valid) & maxpool3x3(stride 2, valid)
        self.conv6 = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool6 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x1_1 = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x1_2 = self.conv4(self.conv3(self.conv2(self.conv1(x))))
        x1 = torch.cat([x1_1, x1_2], 1)
        x2_1 = self.conv5_1_2(self.conv5_1_1(x1))
        x2_2 = self.conv5_2_4(self.conv5_2_3(self.conv5_2_2(self.conv5_2_1(x1))))
        x2 = torch.cat([x2_1, x2_2], 1)
        x3_1 = self.conv6(x2)
        x3_2 = self.maxpool6(x2)
        x3 = torch.cat([x3_1, x3_2], 1)
        return x3
Inception_ResNet-A module: convolution blocks + residual connection
# Inception_ResNet_A: BasicConv2d + residual addition
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(32) + conv3x3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1x1(32) + conv3x3(48) + conv3x3(64)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1x1(384): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch3x3 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)
Inception_ResNet-B module: convolution blocks + residual connection
# Inception_ResNet_B: BasicConv2d + residual addition
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(128) + conv1x7(160) + conv7x1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1x1(1152): restore the channel count (the paper's figure says 1154)
        self.conv = BasicConv2d(ch1x1 + ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)
Inception_ResNet-C module: convolution blocks + residual connection
# Inception_ResNet_C: BasicConv2d + residual addition
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply ReLU after the residual addition
        self.activation = activation
        # conv1x1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(192) + conv1x3(224) + conv3x1(256)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1x1(2144): restore the channel count (the paper's figure says 2048)
        self.conv = BasicConv2d(ch1x1 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res
redutionA module: convolution blocks + pooling layer
# redutionA: BasicConv2d + MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3x3(n, stride 2, valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1x1(k) + conv3x3(l) + conv3x3(m, stride 2, valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3x3(stride 2, valid)
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate the branches
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)
redutionB module: convolution blocks + pooling layer
# redutionB: BasicConv2d + MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1x1(256) + conv3x3(384, stride 2, valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1x1(256) + conv3x3(288, stride 2, valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1x1(256) + conv3x3(288) + conv3x3(320, stride 2, valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3x3(stride 2, valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)
Complete Code
The input image size for Inception-ResNet is 299×299.
Inception-ResNet-V1
import torch
import torch.nn as nn
from torchsummary import summary


# Convolution block: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


# Stem: BasicConv2d + MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()
        # conv3x3(32, stride 2, valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3x3(32, valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3x3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        # maxpool3x3(stride 2, valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)
        # conv1x1(80)
        self.conv5 = BasicConv2d(64, 80, kernel_size=1)
        # conv3x3(192, valid)
        self.conv6 = BasicConv2d(80, 192, kernel_size=3)
        # conv3x3(256, stride 2, valid)
        self.conv7 = BasicConv2d(192, 256, kernel_size=3, stride=2)

    def forward(self, x):
        x = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x = self.conv7(self.conv6(self.conv5(x)))
        return x


# Inception_ResNet_A: BasicConv2d + residual addition
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(32) + conv3x3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1x1(32) + conv3x3(32) + conv3x3(32)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1x1(256): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch3x3 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)


# Inception_ResNet_B: BasicConv2d + residual addition
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(128)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(128) + conv1x7(128) + conv7x1(128)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1x1(896): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)


# Inception_ResNet_C: BasicConv2d + residual addition
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply ReLU after the residual addition
        self.activation = activation
        # conv1x1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(192) + conv1x3(192) + conv3x1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1x1(1792): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res


# redutionA: BasicConv2d + MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3x3(n, stride 2, valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1x1(k) + conv3x3(l) + conv3x3(m, stride 2, valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3x3(stride 2, valid)
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate the branches
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)


# redutionB: BasicConv2d + MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1x1(256) + conv3x3(384, stride 2, valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1x1(256) + conv3x3(256, stride 2, valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1x1(256) + conv3x3(256) + conv3x3(256, stride 2, valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3x3(stride 2, valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)


class Inception_ResNetv1(nn.Module):
    def __init__(self, num_classes=1000, k=192, l=192, m=256, n=384):
        super(Inception_ResNetv1, self).__init__()
        blocks = []
        blocks.append(Stem(3))
        # 5 x Inception-resnet-A
        for i in range(5):
            blocks.append(Inception_ResNet_A(256, 32, 32, 32, 32, 32, 32, 256, 0.17))
        # Reduction-A
        blocks.append(redutionA(256, k, l, m, n))
        # 10 x Inception-resnet-B
        for i in range(10):
            blocks.append(Inception_ResNet_B(896, 128, 128, 128, 128, 896, 0.10))
        # Reduction-B
        blocks.append(redutionB(896, 256, 384, 256, 256, 256))
        # 5 x Inception-resnet-C (the last one without the final ReLU)
        for i in range(4):
            blocks.append(Inception_ResNet_C(1792, 192, 192, 192, 192, 1792, 0.20))
        blocks.append(Inception_ResNet_C(1792, 192, 192, 192, 192, 1792, activation=False))
        self.features = nn.Sequential(*blocks)
        self.conv = BasicConv2d(1792, 1536, 1)
        self.global_average_pooling = nn.AdaptiveAvgPool2d((1, 1))
        # the paper uses dropout with keep probability 0.8, i.e. drop probability 0.2
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(1536, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.global_average_pooling(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Inception_ResNetv1().to(device)
    summary(model, input_size=(3, 299, 299))
summary prints the network structure and parameter counts, which makes it easy to inspect the assembled network.
Inception-ResNet-V2
import torch
import torch.nn as nn
from torchsummary import summary


# Convolution block: Conv2d + BN + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


# Stem: BasicConv2d + MaxPool2d
class Stem(nn.Module):
    def __init__(self, in_channels):
        super(Stem, self).__init__()
        # conv3x3(32, stride 2, valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3x3(32, valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3x3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        # maxpool3x3(stride 2, valid) & conv3x3(96, stride 2, valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv4 = BasicConv2d(64, 96, kernel_size=3, stride=2)
        # conv1x1(64) + conv3x3(96, valid)
        self.conv5_1_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_1_2 = BasicConv2d(64, 96, kernel_size=3)
        # conv1x1(64) + conv7x1(64) + conv1x7(64) + conv3x3(96, valid)
        self.conv5_2_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_2_2 = BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0))
        self.conv5_2_3 = BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3))
        self.conv5_2_4 = BasicConv2d(64, 96, kernel_size=3)
        # conv3x3(192, stride 2, valid) & maxpool3x3(stride 2, valid)
        self.conv6 = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool6 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x1_1 = self.maxpool4(self.conv3(self.conv2(self.conv1(x))))
        x1_2 = self.conv4(self.conv3(self.conv2(self.conv1(x))))
        x1 = torch.cat([x1_1, x1_2], 1)
        x2_1 = self.conv5_1_2(self.conv5_1_1(x1))
        x2_2 = self.conv5_2_4(self.conv5_2_3(self.conv5_2_2(self.conv5_2_1(x1))))
        x2 = torch.cat([x2_1, x2_2], 1)
        x3_1 = self.conv6(x2)
        x3_2 = self.maxpool6(x2)
        x3 = torch.cat([x3_1, x3_2], 1)
        return x3


# Inception_ResNet_A: BasicConv2d + residual addition
class Inception_ResNet_A(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_A, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(32)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(32) + conv3x3(32)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, 1),
            BasicConv2d(ch3x3red, ch3x3, 3, stride=1, padding=1)
        )
        # conv1x1(32) + conv3x3(48) + conv3x3(64)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, 3, stride=1, padding=1),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, 3, stride=1, padding=1)
        )
        # conv1x1(384): restore the channel count so the residual matches the shortcut
        self.conv = BasicConv2d(ch1x1 + ch3x3 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1, x2), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)


# Inception_ResNet_B: BasicConv2d + residual addition
class Inception_ResNet_B(nn.Module):
    def __init__(self, in_channels, ch1x1, ch_red, ch_1, ch_2, ch1x1ext, scale=1.0):
        super(Inception_ResNet_B, self).__init__()
        # residual scaling factor
        self.scale = scale
        # conv1x1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(128) + conv1x7(160) + conv7x1(192)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch_red, 1),
            BasicConv2d(ch_red, ch_1, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(ch_1, ch_2, (7, 1), stride=1, padding=(3, 0))
        )
        # conv1x1(1152): restore the channel count (the paper's figure says 1154)
        self.conv = BasicConv2d(ch1x1 + ch_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        # shortcut + scaled residual
        return self.relu(x + self.scale * x_res)


# Inception_ResNet_C: BasicConv2d + residual addition
class Inception_ResNet_C(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3redX2, ch3x3X2_1, ch3x3X2_2, ch1x1ext, scale=1.0, activation=True):
        super(Inception_ResNet_C, self).__init__()
        # residual scaling factor
        self.scale = scale
        # whether to apply ReLU after the residual addition
        self.activation = activation
        # conv1x1(192)
        self.branch_0 = BasicConv2d(in_channels, ch1x1, 1)
        # conv1x1(192) + conv1x3(224) + conv3x1(256)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3redX2, 1),
            BasicConv2d(ch3x3redX2, ch3x3X2_1, (1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(ch3x3X2_1, ch3x3X2_2, (3, 1), stride=1, padding=(1, 0))
        )
        # conv1x1(2144): restore the channel count (the paper's figure says 2048)
        self.conv = BasicConv2d(ch1x1 + ch3x3X2_2, ch1x1ext, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        # concatenate the branches
        x_res = torch.cat((x0, x1), dim=1)
        x_res = self.conv(x_res)
        if self.activation:
            return self.relu(x + self.scale * x_res)
        return x + self.scale * x_res


# redutionA: BasicConv2d + MaxPool2d
class redutionA(nn.Module):
    def __init__(self, in_channels, k, l, m, n):
        super(redutionA, self).__init__()
        # conv3x3(n, stride 2, valid)
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, n, kernel_size=3, stride=2),
        )
        # conv1x1(k) + conv3x3(l) + conv3x3(m, stride 2, valid)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2)
        )
        # maxpool3x3(stride 2, valid)
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        # concatenate the branches
        outputs = [branch1, branch2, branch3]
        return torch.cat(outputs, 1)


# redutionB: BasicConv2d + MaxPool2d
class redutionB(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3_1, ch3x3_2, ch3x3_3, ch3x3_4):
        super(redutionB, self).__init__()
        # conv1x1(256) + conv3x3(384, stride 2, valid)
        self.branch_0 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_1, 3, stride=2, padding=0)
        )
        # conv1x1(256) + conv3x3(288, stride 2, valid)
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_2, 3, stride=2, padding=0),
        )
        # conv1x1(256) + conv3x3(288) + conv3x3(320, stride 2, valid)
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channels, ch1x1, 1),
            BasicConv2d(ch1x1, ch3x3_3, 3, stride=1, padding=1),
            BasicConv2d(ch3x3_3, ch3x3_4, 3, stride=2, padding=0)
        )
        # maxpool3x3(stride 2, valid)
        self.branch_3 = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        x0 = self.branch_0(x)
        x1 = self.branch_1(x)
        x2 = self.branch_2(x)
        x3 = self.branch_3(x)
        return torch.cat((x0, x1, x2, x3), dim=1)


class Inception_ResNetv2(nn.Module):
    def __init__(self, num_classes=1000, k=256, l=256, m=384, n=384):
        super(Inception_ResNetv2, self).__init__()
        blocks = []
        blocks.append(Stem(3))
        # 5 x Inception-resnet-A
        for i in range(5):
            blocks.append(Inception_ResNet_A(384, 32, 32, 32, 32, 48, 64, 384, 0.17))
        # Reduction-A
        blocks.append(redutionA(384, k, l, m, n))
        # 10 x Inception-resnet-B
        for i in range(10):
            blocks.append(Inception_ResNet_B(1152, 192, 128, 160, 192, 1152, 0.10))
        # Reduction-B
        blocks.append(redutionB(1152, 256, 384, 288, 288, 320))
        # 5 x Inception-resnet-C (the last one without the final ReLU)
        for i in range(4):
            blocks.append(Inception_ResNet_C(2144, 192, 192, 224, 256, 2144, 0.20))
        blocks.append(Inception_ResNet_C(2144, 192, 192, 224, 256, 2144, activation=False))
        self.features = nn.Sequential(*blocks)
        self.conv = BasicConv2d(2144, 1536, 1)
        self.global_average_pooling = nn.AdaptiveAvgPool2d((1, 1))
        # the paper uses dropout with keep probability 0.8, i.e. drop probability 0.2
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(1536, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.global_average_pooling(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Inception_ResNetv2().to(device)
    summary(model, input_size=(3, 299, 299))
summary prints the network structure and parameter counts, which makes it easy to inspect the assembled network.
Summary
This article has introduced, as simply and thoroughly as possible, what combining Inception with ResNet achieves and how Inception-ResNet does it, and has walked through the model structures and their PyTorch implementations.