import torch
import torch.nn as nn
class ResNet50(nn.Module):
    """ResNet-50-style image classifier that returns log-probabilities.

    Layout, in execution order:

    * stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2
      max pooling;
    * four stages of 3, 4, 6 and 3 ``Block`` modules working at 256, 512,
      1024 and 2048 channels, with a 1x1 stride-2 convolution between
      consecutive stages to change the channel count;
    * head: global average pooling, a 2048 -> 1000 linear layer with ReLU,
      and a 1000 -> ``output_dim`` linear layer followed by
      ``log_softmax``.

    ``Block`` and ``GlobalAvgPool2d`` are not defined inside this class;
    both names must be resolvable at module level when the model is
    instantiated.

    Args:
        output_dim: Number of output features of the final linear layer,
            i.e. the number of classes scored by ``forward``.
        in_channels: Number of channels of the input images. Defaults to
            1 (single-channel input).
    """

    def __init__(self, output_dim: int, in_channels: int = 1) -> None:
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(in_channels, 64,
                               kernel_size=(7, 7),
                               stride=(2, 2),
                               padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3),
                                  stride=(2, 2),
                                  padding=1)

        # Stage 1: one block widening 64 -> 256, then two 256 -> 256 blocks.
        self.block0 = self._building_block(256, channel_in=64)
        self.block1 = nn.ModuleList([
            self._building_block(256) for _ in range(2)
        ])
        # 1x1 stride-2 convolution: 256 -> 512 channels.
        self.conv2 = nn.Conv2d(256, 512,
                               kernel_size=(1, 1),
                               stride=(2, 2))

        # Stage 2: four 512 -> 512 blocks.
        self.block2 = nn.ModuleList([
            self._building_block(512) for _ in range(4)
        ])
        # 1x1 stride-2 convolution: 512 -> 1024 channels.
        self.conv3 = nn.Conv2d(512, 1024,
                               kernel_size=(1, 1),
                               stride=(2, 2))

        # Stage 3: six 1024 -> 1024 blocks.
        self.block3 = nn.ModuleList([
            self._building_block(1024) for _ in range(6)
        ])
        # 1x1 stride-2 convolution: 1024 -> 2048 channels.
        self.conv4 = nn.Conv2d(1024, 2048,
                               kernel_size=(1, 1),
                               stride=(2, 2))

        # Stage 4: three 2048 -> 2048 blocks.
        self.block4 = nn.ModuleList([
            self._building_block(2048) for _ in range(3)
        ])

        # Head. GlobalAvgPool2d must be defined at module level.
        self.avg_pool = GlobalAvgPool2d()
        self.fc = nn.Linear(2048, 1000)
        self.out = nn.Linear(1000, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return log-softmax scores.

        Args:
            x: Input images. Presumably a batch shaped
                ``(N, in_channels, H, W)`` -- confirm against the caller.

        Returns:
            ``log_softmax`` over the last dimension of the final linear
            layer's output, whose last dimension has size ``output_dim``.
        """
        # Stem.
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu1(h)
        h = self.pool1(h)

        # Stage 1.
        h = self.block0(h)
        for block in self.block1:
            h = block(h)

        # Stage 2, entered through the 256 -> 512 convolution.
        h = self.conv2(h)
        for block in self.block2:
            h = block(h)

        # Stage 3, entered through the 512 -> 1024 convolution.
        h = self.conv3(h)
        for block in self.block3:
            h = block(h)

        # Stage 4, entered through the 1024 -> 2048 convolution.
        h = self.conv4(h)
        for block in self.block4:
            h = block(h)

        # Head.
        h = self.avg_pool(h)
        h = self.fc(h)
        h = torch.relu(h)
        h = self.out(h)
        y = torch.log_softmax(h, dim=-1)
        return y

    def _building_block(self,
                        channel_out,
                        channel_in=None) -> nn.Module:
        """Create one ``Block`` with the given channel counts.

        Args:
            channel_out: Second argument passed to ``Block``.
            channel_in: First argument passed to ``Block``. When ``None``
                it defaults to ``channel_out``.

        Returns:
            A new ``Block(channel_in, channel_out)`` instance. ``Block``
            is defined outside this class.
        """
        if channel_in is None:
            channel_in = channel_out
        return Block(channel_in, channel_out)
ResNet50では、最後の全結合層へ接続する直前に global average pooling（各チャネルの特徴マップ全体の平均を取る処理）を行います。これを GlobalAvgPool2d クラスとして実装してみましょう。コードは以下のとおりです。
import torch
import torch.nn as nn
import torch.nn.functional as F
class GlobalAvgPool2d(nn.Module):
    """Average each channel over its whole spatial extent.

    Args:
        device: Unused; the value is ignored.
    """

    def __init__(self, device: str = 'cpu') -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` with a window covering every dimension after the first two.

        Args:
            x: Tensor to pool. Dimension 1 is treated as the channel
                dimension and dimensions 2 onward as the pooling window.

        Returns:
            The pooled tensor reshaped to ``(-1, x.size(1))``.
        """
        channels = x.size(1)
        # Using the full trailing size as the kernel leaves one value per channel.
        window = x.size()[2:]
        pooled = F.avg_pool2d(x, kernel_size=window)
        return pooled.view(-1, channels)