PyTorchによるSwinTransformerを用いた医療画像2クラス分類【SwinTransformer】

スポンサーリンク

SwinTransformerは、VisonTransformer(通称:ViT)の性能を改良したモデルです。

SwinTransformerは、計算コストを削減して性能を向上させる革新的な技術として、注目されています。

今回は、胸部X線の性別分類のデータを使ってSwinTransformerモデルの実装を行ってみようと思います。

実際に、SwinTransformerを試してみてViTやCNNより良い精度が出るのか見ていきたいと思います。

実装環境の準備(Google Colaboratory)

この記事では、Google Colaboratoryの環境を用いて実装していきます。

GPUが使えて、環境構築が簡単なため幅広く使用されています。

また、コードについては、PyTorchで実装したvit-pytorchを使用しています。

データセットの準備

今回使用するデータセットは、下記サイトから取ってきたものになります。

http://imgcom.jsrt.or.jp/minijsrtdb/

日本放射線技術学会 胸部X線画像の性別の2クラス分類データセット Gender01

こちらのデータセットは、胸部正面像が性別ごとに分けられており、どの性別かを予測する事に使用します。

▼ データセットはこちらの下記リンクからダウンロードしてください。

データセットのダウンロードはこちら

Google colaboratoryを前提に実装していきますが、ローカルのJupyter環境でも問題ありません。

ダウンロードが出来たらGoogleDriveにzipファイルのままアップロードしておいてください。

female:女性の胸部X線画像
male:男性の胸部X線画像

データセットの構造は、下図のようになっています。

trainには学習用データセット、validationには検証用のデータセット、testにはテスト用データセットが含まれています。

一般的に、このようにtrain、validation、testの3つに分けておくと都合がよいです。

医療画像を用いたSwinTransformerによる2クラス分類

それでは、これから実際にSwinTransformerを使った学習を行うためのコードを書いていきます。

前準備

SwinTransformerをPyTorchで実装したvit-pytorchを使うために、git cloneを行います。

!git clone https://github.com/lucidrains/vit-pytorch.git

git cloneが出来たら、GoogleColaboratory上にインストールします。

!pip install vit_pytorch
!pip install timm

ここまで出来たら前準備は完了です。

ライブラリのインポート

まずは自分のGoogleDriveにアップロードしたデータを使用するために、GoogleDriveと連携させる(Googleドライブのマウント)事を行います。

下記コードを実行する事で、Googleドライブのマウントは完了です。

# Googleドライブのマウント

from google.colab import drive
drive.mount('/content/drive')

続いて、先程アップロードしたzipファイルの解凍を行います。

解凍してからアップロードしてもいいですが、データ量が多い場合には時間がかかるので、zipファイルでアップロードして、コードで解凍する方がおすすめです。

from zipfile import ZipFile

# Zipファイルの解凍

file_name = '/content/drive/My Drive/Gender01.zip'
with ZipFile(file_name, 'r') as zip:
    zip.extractall()

次に、必要なライブラリをインポートしておきます。

from __future__ import print_function

import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT
from pathlib import Path
import seaborn as sns
import timm
from pprint import pprint

import copy
from tqdm import tqdm

学習条件の設定をしていきます。条件を変更したい場合は、ここの数値を適宜変更してください。

# Training settings
epochs = 50
lr = 3e-5
gamma = 0.7
seed = 42

シードの設定をしていきます。ここは、特に変更せずそのままで大丈夫です。

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

学習用データセットの設定

cudaの設定と、データセットの保存してあるパスを指定します。

device = 'cuda'
train_dataset_dir = Path('/content/Gender01/train')
val_dataset_dir = Path('/content/Gender01/validation')
test_dataset_dir = Path('/content/Gender01/test')

データセットの中身を確認します。

画像を表示してみて、データセットがうまく読み込めているか確認します。

files = glob.glob('/content/Gender01/*/*/*.png')
random_idx = np.random.randint(1, len(files), size=9)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))

for idx, ax in enumerate(axes.ravel()):
    img = Image.open(files[idx])
    ax.imshow(img)

transformの設定

transformを使って、データセットの画像の前処理を行います。

train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

・画像のサイズを224x224にリサイズ
・Tensor型へデータ変更
・正規化

データセットのロード

元のフォルダを読み込んで、データセット(画像とラベルのセット)を作成します。

train_data = datasets.ImageFolder(train_dataset_dir,train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)
test_data = datasets.ImageFolder(test_dataset_dir, test_transforms)

データをバッチに分けます。

今回は、batch_sizeを32にしていますが、GPUメモリが不足している場合は減らしましょう。

これで学習データセットのロードは完了です。

train_loader = DataLoader(dataset = train_data, batch_size=32, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset = test_data, batch_size=32, shuffle=False)

SwinTransformerのモデルをロード

timmを使ってSwinTransformerのモデルをダウンロードし、さらに転移学習を実行していきます。

model_names = timm.list_models(pretrained=True)
pprint(model_names)

['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_224_in22k', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_base_patch16_224_in22k', 'beitv2_large_patch16_224', 'beitv2_large_patch16_224_in22k',・・・

コードを実行すると、ロードできるモデル一覧が表示されます。

今回は‘swin_base_patch4_window7_224.ms_in22k’を使用します。

timmからpretrained modelをダウンロードします。

model = timm.create_model('swin_base_patch4_window7_224.ms_in22k', pretrained=True, num_classes=2)
model = model.to(device)

pretrain=Trueを指定することで、重みの学習が行われます。

num_classesは分類するクラス数をしていするので、今回は「female,male」の2クラスを指定します。

SwinTransformerによる学習

損失関数、活性化関数の設定をします。

今回はクロスエントロピーとAdamを使用しますが自由に設定できます。

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

学習ループの設定、実行をしていきます。

また、最も良いモデルの重みの保存として、validationのlossが最も小さくなるような重みを保存します。

best_loss = None

# Accuracy計算用の関数
def calculate_accuracy(output, target):
    output = (torch.sigmoid(output) >= 0.5)
    target = (target == 1.0)
    accuracy = torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
    return accuracy

train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)              

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    train_acc_list.append(epoch_accuracy)
    val_acc_list.append(epoch_val_accuracy)
    train_loss_list.append(epoch_loss)
    val_loss_list.append(epoch_val_loss)

    if (best_loss is None) or (best_loss > val_loss):
        best_loss = val_loss
        model_path = '/content/drive/MyDrive/bestswinmodel.pth'
        torch.save(model.state_dict(), model_path)
        
    print()

このような感じで学習が進んでいく様子が分かります。

学習結果の可視化

学習結果として、Accuracyとlossの曲線を出力していきます。

device2 = torch.device('cpu')

train_acc = []
train_loss = []
val_acc = []
val_loss = []

for i in range(epochs):
    train_acc2 = train_acc_list[i].to(device2)
    train_acc3 = train_acc2.clone().numpy()
    train_acc.append(train_acc3)
    
    train_loss2 = train_loss_list[i].to(device2)
    train_loss3 = train_loss2.clone().detach().numpy()
    train_loss.append(train_loss3)
    
    val_acc2 = val_acc_list[i].to(device2)
    val_acc3 = val_acc2.clone().numpy()
    val_acc.append(val_acc3)
    
    val_loss2 = val_loss_list[i].to(device2)
    val_loss3 = val_loss2.clone().numpy()
    val_loss.append(val_loss3)

#取得したデータをグラフ化する
sns.set()
num_epochs = epochs

fig = plt.subplots(figsize=(12, 4), dpi=80)

ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()

・Validation accuracy:0.8862
・Validation loss:1.1015

と言った結果になりました。

ここまでで、学習のステップは完了です。

テストデータによる検証

学習した結果を元に、保存したモデルの重みを用いて、テストデータで検証をします。

汎化性能の有無を検証します。

model.load_state_dict(torch.load("/content/drive/MyDrive/bestswinmodel.pth", map_location=device))
model.eval()  # モデルを評価モードにする

loss_sum = 0
correct = 0

with torch.no_grad():
    for data, labels in test_loader:

        # GPUが使えるならGPUにデータを送る
        data = data.to(device)
        labels = labels.to(device)

        # ニューラルネットワークの処理を実施
        outputs = model(data)

        # 損失(出力とラベルとの誤差)の計算
        loss_sum += criterion(outputs, labels)

        # 正解の値を取得
        pred = outputs.argmax(1)
        # 正解数をカウント
        correct += pred.eq(labels.view_as(pred)).sum().item()

print(f"Loss: {loss_sum.item() / len(test_loader)}, Accuracy: {100*correct/len(test_data)}% ({correct}/{len(test_data)})")

Loss: 0.350030779838562, Accuracy: 89.36170212765957% (42/47)

SwinTransformer vs ViT,CNNの比較

結局、画像認識においてSwinTransformerとViTやCNNどっちがいいの?というところも見ていきます。

そこで、今回は SwinTransformerのtestの精度とViT、EfficientNetのtestの精度を比較します。

ViTの実装はこちら、EfficientNetの実装はこちら

SwinTransformerEfficientNet(CNN)VisionTransformer(ViT)
Test Accuracy0.89(42/47)0.57(27/47)0.70(33/47)
Test Loss0.352.170.69

一概に、 SwinTransformerの方が優秀だ!と結論付けることは出来ませんが、今後SwinTransformerによる検証を行ってみることは精度向上の可能性があります。

いろいろと試してみて、精度が上がるパターンを探してみてくださいね!

まとめ

今回の記事では、PyTorchによるSwinTransformerを用いた医療画像2クラス分類を実装してみました。

ViTの登場から様々なTransformerベースのモデルが登場しています。

ぜひ、いろいろなモデルを試してみて、実験の参考にしてみてください。

もしもこの記事が気に入ったら上のハートのいいね!ボタンを押して頂けると嬉しいです!今後の励みになります!(ちなみに、押すとボタン周辺にハートが出てきて少し幸せな気持ちになれます笑)

スポンサーリンク

関連コンテンツ

  • この記事を書いた人

Ryusei

【経歴】診療放射線技師免許取得 ▶︎ 大学院生 ▶︎ AI Frontier 運営 ▶︎ AIをメインに情報発信します ▶︎ 23歳

-programming