SwinTransformerは、VisonTransformer(通称:ViT)の性能を改良したモデルです。
SwinTransformerは、計算コストを削減して性能を向上させる革新的な技術として、注目されています。
今回は、胸部X線の性別分類のデータを使ってSwinTransformerモデルの実装を行ってみようと思います。
実際に、SwinTransformerを試してみてViTやCNNより良い精度が出るのか見ていきたいと思います。
実装環境の準備(Google Colaboratory)
この記事では、Google Colaboratoryの環境を用いて実装していきます。
GPUが使えて、環境構築が簡単なため幅広く使用されています。
また、コードについては、PyTorchで実装したvit-pytorchを使用しています。
データセットの準備
今回使用するデータセットは、下記サイトから取ってきたものになります。
http://imgcom.jsrt.or.jp/minijsrtdb/
日本放射線技術学会 胸部X線画像の性別の2クラス分類データセット Gender01
こちらのデータセットは、胸部正面像が性別ごとに分けられており、どの性別かを予測する事に使用します。
▼ データセットはこちらの下記リンクからダウンロードしてください。
Google colaboratoryを前提に実装していきますが、ローカルのJupyter環境でも問題ありません。
ダウンロードが出来たらGoogleDriveにzipファイルのままアップロードしておいてください。

female:女性の胸部X線画像
male:男性の胸部X線画像
データセットの構造は、下図のようになっています。

trainには学習用データセット、validationには検証用のデータセット、testにはテスト用データセットが含まれています。
一般的に、このようにtrain、validation、testの3つに分けておくと都合がよいです。
医療画像を用いたSwinTransformerによる2クラス分類
それでは、これから実際にSwinTransformerを使った学習を行うためのコードを書いていきます。
前準備
SwinTransformerをPyTorchで実装したvit-pytorchを使うために、git cloneを行います。
!git clone https://github.com/lucidrains/vit-pytorch.git
git cloneが出来たら、GoogleColaboratory上にインストールします。
!pip install vit_pytorch
!pip install timm
ここまで出来たら前準備は完了です。
ライブラリのインポート
まずは自分のGoogleDriveにアップロードしたデータを使用するために、GoogleDriveと連携させる(Googleドライブのマウント)事を行います。
下記コードを実行する事で、Googleドライブのマウントは完了です。
# Googleドライブのマウント
from google.colab import drive
drive.mount('/content/drive')
続いて、先程アップロードしたzipファイルの解凍を行います。
解凍してからアップロードしてもいいですが、データ量が多い場合には時間がかかるので、zipファイルでアップロードして、コードで解凍する方がおすすめです。
from zipfile import ZipFile
# Zipファイルの解凍
file_name = '/content/drive/My Drive/Gender01.zip'
with ZipFile(file_name, 'r') as zip:
zip.extractall()
次に、必要なライブラリをインポートしておきます。
from __future__ import print_function
import glob
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT
from pathlib import Path
import seaborn as sns
import timm
from pprint import pprint
import copy
from tqdm import tqdm
学習条件の設定をしていきます。条件を変更したい場合は、ここの数値を適宜変更してください。
# Training settings
epochs = 50
lr = 3e-5
gamma = 0.7
seed = 42
シードの設定をしていきます。ここは、特に変更せずそのままで大丈夫です。
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
seed_everything(seed)
学習用データセットの設定
cudaの設定と、データセットの保存してあるパスを指定します。
device = 'cuda'
train_dataset_dir = Path('/content/Gender01/train')
val_dataset_dir = Path('/content/Gender01/validation')
test_dataset_dir = Path('/content/Gender01/test')
データセットの中身を確認します。
画像を表示してみて、データセットがうまく読み込めているか確認します。
files = glob.glob('/content/Gender01/*/*/*.png')
random_idx = np.random.randint(1, len(files), size=9)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))
for idx, ax in enumerate(axes.ravel()):
img = Image.open(files[idx])
ax.imshow(img)

transformの設定
transformを使って、データセットの画像の前処理を行います。
train_transforms = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
)
val_transforms = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
)
test_transforms = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
)
・Tensor型へデータ変更
・正規化
データセットのロード
元のフォルダを読み込んで、データセット(画像とラベルのセット)を作成します。
train_data = datasets.ImageFolder(train_dataset_dir,train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)
test_data = datasets.ImageFolder(test_dataset_dir, test_transforms)
データをバッチに分けます。
今回は、batch_sizeを32にしていますが、GPUメモリが不足している場合は減らしましょう。
これで学習データセットのロードは完了です。
train_loader = DataLoader(dataset = train_data, batch_size=32, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset = test_data, batch_size=32, shuffle=False)
SwinTransformerのモデルをロード
timmを使ってSwinTransformerのモデルをダウンロードし、さらに転移学習を実行していきます。
model_names = timm.list_models(pretrained=True)
pprint(model_names)
コードを実行すると、ロードできるモデル一覧が表示されます。
今回は‘swin_base_patch4_window7_224.ms_in22k’を使用します。
timmからpretrained modelをダウンロードします。
model = timm.create_model('swin_base_patch4_window7_224.ms_in22k', pretrained=True, num_classes=2)
model = model.to(device)
pretrain=Trueを指定することで、重みの学習が行われます。
num_classesは分類するクラス数をしていするので、今回は「female,male」の2クラスを指定します。
SwinTransformerによる学習
損失関数、活性化関数の設定をします。
今回はクロスエントロピーとAdamを使用しますが自由に設定できます。
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
学習ループの設定、実行をしていきます。
また、最も良いモデルの重みの保存として、validationのlossが最も小さくなるような重みを保存します。
best_loss = None
# Accuracy計算用の関数
def calculate_accuracy(output, target):
output = (torch.sigmoid(output) >= 0.5)
target = (target == 1.0)
accuracy = torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
return accuracy
train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []
for epoch in range(epochs):
epoch_loss = 0
epoch_accuracy = 0
for data, label in tqdm(train_loader):
data = data.to(device)
label = label.to(device)
output = model(data)
loss = criterion(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (output.argmax(dim=1) == label).float().mean()
epoch_accuracy += acc / len(train_loader)
epoch_loss += loss / len(train_loader)
with torch.no_grad():
epoch_val_accuracy = 0
epoch_val_loss = 0
for data, label in valid_loader:
data = data.to(device)
label = label.to(device)
val_output = model(data)
val_loss = criterion(val_output, label)
acc = (val_output.argmax(dim=1) == label).float().mean()
epoch_val_accuracy += acc / len(valid_loader)
epoch_val_loss += val_loss / len(valid_loader)
print(
f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
)
train_acc_list.append(epoch_accuracy)
val_acc_list.append(epoch_val_accuracy)
train_loss_list.append(epoch_loss)
val_loss_list.append(epoch_val_loss)
if (best_loss is None) or (best_loss > val_loss):
best_loss = val_loss
model_path = '/content/drive/MyDrive/bestswinmodel.pth'
torch.save(model.state_dict(), model_path)
print()

このような感じで学習が進んでいく様子が分かります。
学習結果の可視化
学習結果として、Accuracyとlossの曲線を出力していきます。
device2 = torch.device('cpu')
train_acc = []
train_loss = []
val_acc = []
val_loss = []
for i in range(epochs):
train_acc2 = train_acc_list[i].to(device2)
train_acc3 = train_acc2.clone().numpy()
train_acc.append(train_acc3)
train_loss2 = train_loss_list[i].to(device2)
train_loss3 = train_loss2.clone().detach().numpy()
train_loss.append(train_loss3)
val_acc2 = val_acc_list[i].to(device2)
val_acc3 = val_acc2.clone().numpy()
val_acc.append(val_acc3)
val_loss2 = val_loss_list[i].to(device2)
val_loss3 = val_loss2.clone().numpy()
val_loss.append(val_loss3)
#取得したデータをグラフ化する
sns.set()
num_epochs = epochs
fig = plt.subplots(figsize=(12, 4), dpi=80)
ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')
ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()

・Validation loss:1.1015
と言った結果になりました。
ここまでで、学習のステップは完了です。
テストデータによる検証
学習した結果を元に、保存したモデルの重みを用いて、テストデータで検証をします。
汎化性能の有無を検証します。
model.load_state_dict(torch.load("/content/drive/MyDrive/bestswinmodel.pth", map_location=device))
model.eval() # モデルを評価モードにする
loss_sum = 0
correct = 0
with torch.no_grad():
for data, labels in test_loader:
# GPUが使えるならGPUにデータを送る
data = data.to(device)
labels = labels.to(device)
# ニューラルネットワークの処理を実施
outputs = model(data)
# 損失(出力とラベルとの誤差)の計算
loss_sum += criterion(outputs, labels)
# 正解の値を取得
pred = outputs.argmax(1)
# 正解数をカウント
correct += pred.eq(labels.view_as(pred)).sum().item()
print(f"Loss: {loss_sum.item() / len(test_loader)}, Accuracy: {100*correct/len(test_data)}% ({correct}/{len(test_data)})")
SwinTransformer vs ViT,CNNの比較
結局、画像認識においてSwinTransformerとViTやCNNどっちがいいの?というところも見ていきます。
そこで、今回は SwinTransformerのtestの精度とViT、EfficientNetのtestの精度を比較します。
ViTの実装はこちら、EfficientNetの実装はこちら
SwinTransformer | EfficientNet(CNN) | VisionTransformer(ViT) | |
Test Accuracy | 0.89(42/47) | 0.57(27/47) | 0.70(33/47) |
Test Loss | 0.35 | 2.17 | 0.69 |
一概に、 SwinTransformerの方が優秀だ!と結論付けることは出来ませんが、今後SwinTransformerによる検証を行ってみることは精度向上の可能性があります。
いろいろと試してみて、精度が上がるパターンを探してみてくださいね!
まとめ
今回の記事では、PyTorchによるSwinTransformerを用いた医療画像2クラス分類を実装してみました。
ViTの登場から様々なTransformerベースのモデルが登場しています。
ぜひ、いろいろなモデルを試してみて、実験の参考にしてみてください。