プログラミング

2020/03/13

ResNet (Residual Network) の実装

ResNet (Residual Network) の実装

画像認識タスクにおいて、高い予測性能をもつ ResNet。ImageNetのSOTAランキングでも、EfficientNetと並び、応用モデルが上位にランクインしています。ライブラリ等を用いれば事前学習済のResNetは簡単に読み込めますが、モデルの構造をきちんと実装しようとすると、どうなるでしょうか?今回は、このResNetをPyTorchを用いて実装していきたいと思います。

様々な応用モデルが存在するResNetですが、もともとは2015年に Deep Residual Learning for Image Recognition という論文で提案された手法になります。大きな特長は、skip connection あるいは residual connection と呼ばれる、層を飛び越えた結合を持つことです。

下図がモデルの概要図です(図は元論文より引用)。

ResNetのモデル概要
  • 【画像をクリックすると拡大します】
引用元:Deep Residual Learning for Image Recognition

図の右に飛び出ている結合が skip connection になります。また、この図に “34-layer residual” とあるように、ResNetは層の深さが異なる複数のモデルがあり、論文内では18層、34層、50層、101層、152層のモデル構造が提案されています。

それぞれの構造は下表の通りです(表は元論文より引用)。

 
ResNetのモデル構造
  • 【画像をクリックすると拡大します】
引用元:Deep Residual Learning for Image Recognition

今回は、このうち50層のモデル (ResNet50) を実装してみましょう。上表にある [ ] で囲まれた部分は building blocks と呼ばれるモジュールで、グラフィカル表現にしたものが下図になります(図は元論文より引用)。

 
Building block 構造
  • 【画像をクリックすると拡大します】
引用元:Deep Residual Learning for Image Recognition

1×1 あるいは 3×3 のカーネルサイズを持つ畳み込み層を積み重ねた構造で、チャネル数は出力層に近づくにつれ大きくなっていきます。

まずは、このブロック構造を実装してみましょう。下記がコードになります。ResNet50(以降のモデル)では、skip connection を行うために最後にチャネル数を調整しなくてはならない場合がある(すなわち、ブロックの入力のチャネル数を、出力のチャネル数に合わせる必要がある)ので、これを shortcut として実装しています。中身はシンプルで、1×1 の畳み込みを差し込むことによるチャネル数の調整になります。

import pytorch.nn as nn


class Block(nn.Module):
    def __init__(self, channel_in, channel_out):
        super().__init__()
        channel = channel_out // 4

        # 1x1 の畳み込み
        self.conv1 = nn.Conv2d(channel_in, channel,
                               kernel_size=(1, 1))
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu1 = nn.ReLU()

        # 3x3 の畳み込み
        self.conv2 = nn.Conv2d(channel, channel,
                               kernel_size=(3, 3),
                               padding=1)
        self.bn2 = nn.BatchNorm2d(channel)
        self.relu2 = nn.ReLU()

        # 1x1 の畳み込み
        self.conv3 = nn.Conv2d(channel, channel_out,
                               kernel_size=(1, 1),
                               padding=0)
        self.bn3 = nn.BatchNorm2d(channel_out)

        # skip connection用のチャネル数調整        
        self.shortcut = self._shortcut(channel_in, channel_out)
        
        self.relu3 = nn.ReLU()

    def forward(self, x):
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu1(h)
        h = self.conv2(h)
        h = self.bn2(h)
        h = self.relu2(h)
        h = self.conv3(h)
        h = self.bn3(h)
        shortcut = self.shortcut(x)
        y = self.relu3(h + shortcut)  # skip connection
        return y

    def _shortcut(self, channel_in, channel_out):
        if channel_in != channel_out:
            return self._projection(channel_in, channel_out)
        else:
            return lambda x: x

    def _projection(self, channel_in, channel_out):
        return nn.Conv2d(channel_in, channel_out,
                         kernel_size=(1, 1),
                         padding=0)

この Block クラスを利用して、ResNet50 クラスの実装をしてみましょう。ResNet50はブロック構造が4つ並んでいるだけとも言えるので、コード自体は非常にシンプルにまとまります。下記が実装の内容です(ただし、実装中に出てくる GlobalAvgPool2d に関しては後述します)。

import torch
import torch.nn as nn


class ResNet50(nn.Module):
    def __init__(self, output_dim):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 64,
                               kernel_size=(7, 7),
                               stride=(2, 2),
                               padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3),
                                  stride=(2, 2),
                                  padding=1)

        # Block 1
        self.block0 = self._building_block(256, channel_in=64)
        self.block1 = nn.ModuleList([
            self._building_block(256) for _ in range(2)
        ])

        self.conv2 = nn.Conv2d(256, 512,
                               kernel_size=(1, 1),
                               stride=(2, 2))

        # Block 2
        self.block2 = nn.ModuleList([
            self._building_block(512) for _ in range(4)
        ])

        self.conv3 = nn.Conv2d(512, 1024,
                               kernel_size=(1, 1),
                               stride=(2, 2))

        # Block 3
        self.block3 = nn.ModuleList([
            self._building_block(1024) for _ in range(6)
        ])

        self.conv4 = nn.Conv2d(1024, 2048,
                               kernel_size=(1, 1),
                               stride=(2, 2))

        # Block 4
        self.block4 = nn.ModuleList([
            self._building_block(2048) for _ in range(3)
        ])

        self.avg_pool = GlobalAvgPool2d()  # TODO: GlobalAvgPool2d
        self.fc = nn.Linear(2048, 1000)
        self.out = nn.Linear(1000, output_dim)

    def forward(self, x):
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu1(h)
        h = self.pool1(h)
        h = self.block0(h)
        for block in self.block1:
            h = block(h)
        h = self.conv2(h)
        for block in self.block2:
            h = block(h)
        h = self.conv3(h)
        for block in self.block3:
            h = block(h)
        h = self.conv4(h)
        for block in self.block4:
            h = block(h)
        h = self.avg_pool(h)
        h = self.fc(h)
        h = torch.relu(h)
        h = self.out(h)
        y = torch.log_softmax(h, dim=-1)

        return y

    def _building_block(self,
                        channel_out,
                        channel_in=None):
        if channel_in is None:
            channel_in = channel_out
        return Block(channel_in, channel_out)

ResNet50では、最後の全結合層に接続する際に、global average poolingを行います。これを GlobalAvgPool2d クラスで実装してみましょう。下記がコードになります。

import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalAvgPool2d(nn.Module):
    def __init__(self,
                 device='cpu'):
        super().__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:]).view(-1, x.size(1))

以上で必要なレイヤーがすべて実装できました。それでは、Fashion MNIST を対象として、実際に学習・評価をしてみましょう。全体の実装は次の通りです(学習に時間がかかりますので、エポック数は 5 にしています)。

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score


class ResNet50(nn.Module):
    # 省略


class Block(nn.Module):
    # 省略


class GlobalAvgPool2d(nn.Module):
    # 省略


if __name__ == '__main__':
    np.random.seed(1234)
    torch.manual_seed(1234)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    '''
    データの読み込み
    '''
    root = os.path.join(os.path.dirname(__file__),
                        '..', 'data', 'fashion_mnist')
    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = \
        torchvision.datasets.FashionMNIST(root=root,
                                          download=True,
                                          train=True,
                                          transform=transform)
    mnist_test = \
        torchvision.datasets.FashionMNIST(root=root,
                                          download=True,
                                          train=False,
                                          transform=transform)

    train_dataloader = DataLoader(mnist_train,
                                  batch_size=100,
                                  shuffle=True)
    test_dataloader = DataLoader(mnist_test,
                                 batch_size=100,
                                 shuffle=False)

    '''
    モデルの構築
    '''
    model = ResNet50(10).to(device)

    '''
    モデルの学習・評価
    '''
    def compute_loss(label, pred):
        return criterion(pred, label)

    def train_step(x, t):
        model.train()
        preds = model(x)
        loss = compute_loss(t, preds)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss, preds

    def test_step(x, t):
        model.eval()
        preds = model(x)
        loss = compute_loss(t, preds)

        return loss, preds

    criterion = nn.NLLLoss()
    optimizer = optimizers.Adam(model.parameters(), weight_decay=0.01)
    epochs = 5

    for epoch in range(epochs):
        train_loss = 0.
        test_loss = 0.
        test_acc = 0.

        for (x, t) in train_dataloader:
            x, t = x.to(device), t.to(device)
            loss, _ = train_step(x, t)
            train_loss += loss.item()

        train_loss /= len(train_dataloader)

        for (x, t) in test_dataloader:
            x, t = x.to(device), t.to(device)
            loss, preds = test_step(x, t)
            test_loss += loss.item()
            test_acc += \
                accuracy_score(t.tolist(), preds.argmax(dim=-1).tolist())

        test_loss /= len(test_dataloader)
        test_acc /= len(test_dataloader)
        print('Epoch: {}, Valid Cost: {:.3f}, Valid Acc: {:.3f}'.format(
            epoch+1,
            test_loss,
            test_acc
        ))

以上を実行すると、次のような結果が得られ、きちんとResNetで学習ができていることが確認できます。

Epoch: 1, Valid Cost: 0.469, Valid Acc: 0.839
Epoch: 2, Valid Cost: 0.451, Valid Acc: 0.850
Epoch: 3, Valid Cost: 0.418, Valid Acc: 0.858
Epoch: 4, Valid Cost: 0.448, Valid Acc: 0.850
Epoch: 5, Valid Cost: 0.365, Valid Acc: 0.877

今回は、CNNのモデルのひとつであるResNetの実装を紹介しました。Skip connection は実装もシンプルであるにもかかわらず、高い予測性能につながりますので、しっかり内容について理解しておきましょう!

関連記事Related Posts