PyTorchによるViTを用いた医療画像2クラス分類【Vision Transformer】
Vision Transformer(通称:ViT)は、自然言語処理の分野のtransformerを画像認識の分野に応用したモデルです。
ViTは、CNNにとって代わる革新的な技術として、注目されています。
今回は、胸部X線の性別分類のデータを使ってViTモデルの実装を行ってみようと思います。
実際に、ViTを試してみてCNNより良い精度が出るのか見ていきたいと思います。
実装環境の準備(Google Colaboratory)
この記事では、Google Colaboratoryの環境を用いて実装していきます。
GPUが使えて、環境構築が簡単なため幅広く使用されています。
また、コードについては、timmのViTを使用しています。
データセットの準備
今回使用するデータセットは、下記サイトから取ってきたものになります。
http://imgcom.jsrt.or.jp/minijsrtdb/
日本放射線技術学会 胸部X線画像の性別の2クラス分類データセット Gender01
こちらのデータセットは、胸部正面像が性別ごとに分けられており、どの性別かを予測する事に使用します。
▼ データセットはこちらの下記リンクからダウンロードしてください。
Google colaboratoryを前提に実装していきますが、ローカルのJupyter等の環境でも問題ありません。
ダウンロードが出来たらGoogleDriveにzipファイルのままアップロードしておいてください。
female:女性の胸部X線画像
male:男性の胸部X線画像
データセットの構造は、下図のようになっています。
trainには学習用データセット、validationには検証用のデータセット、testにはテスト用データセットが含まれています。
一般的に、このようにtrain、validation、testの3つに分けておくと都合がよいです。
医療画像を用いたViTによる2クラス分類
それでは、これから実際にViTを使った学習を行うためのコードを書いていきます。
前準備
まず、モデルを使用するために、timmをGoogleColaboratory上にインストールします。
!pip install timm
ここまで出来たら前準備は完了です。
ライブラリのインポート
まずは自分のGoogleDriveにアップロードしたデータを使用するために、GoogleDriveと連携させる(Googleドライブのマウント)事を行います。
下記コードを実行する事で、Googleドライブのマウントは完了です。
# Googleドライブのマウント
from google.colab import drive
drive.mount('/content/drive')
続いて、先程アップロードしたzipファイルの解凍を行います。
解凍してからアップロードしてもいいですが、データ量が多い場合には時間がかかるので、zipファイルでアップロードして、コードで解凍する方がおすすめです。
from zipfile import ZipFile
# Zipファイルの解凍
file_name = '/content/drive/My Drive/Gender01.zip'
with ZipFile(file_name, 'r') as zip:
zip.extractall()
次に、必要なライブラリをインポートしておきます。
from __future__ import print_function
import glob
import os
import random
import cv2
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import v2
from tqdm.notebook import tqdm
from pathlib import Path
import seaborn as sns
import timm
from pprint import pprint
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, roc_curve, auc
import matplotlib.pyplot as plt
学習条件の設定をしていきます。条件を変更したい場合は、ここの数値を適宜変更してください。
# Training settings
epochs = 100
lr = 0.001
学習用データセットの設定
cudaの設定と、データセットの保存してあるパスを指定します。
device = 'cuda'
train_dataset_dir = Path('/content/Gender01/train')
val_dataset_dir = Path('/content/Gender01/validation')
test_dataset_dir = Path('/content/Gender01/test')
データセットの中身を確認します。
画像を表示してみて、データセットがうまく読み込めているか確認します。
files = glob.glob('/content/Gender01/*/*/*.png')
random_idx = np.random.randint(1, len(files), size=9)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))
for idx, ax in enumerate(axes.ravel()):
img = Image.open(files[idx])
ax.imshow(img)
transformの設定
transformを使って、データセットの画像の前処理を行います。
train_transforms = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
)
val_transforms = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
)
test_transforms = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
)
画像のサイズを224×224にリサイズ 左右反転によるData Augmentation Tensor型へデータ変更 正規化
データセットのロード
元のフォルダを読み込んで、データセット(画像とラベルのセット)を作成します。
train_data = datasets.ImageFolder(train_dataset_dir,train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)
test_data = datasets.ImageFolder(test_dataset_dir, test_transforms)
データをバッチに分けます。
今回は、batch_sizeを16にしていますが、GPUメモリが不足している場合は減らしましょう。
これで学習データセットのロードは完了です。
train_loader = DataLoader(dataset = train_data, batch_size=16, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=16, shuffle=False)
test_loader = DataLoader(dataset = test_data, batch_size=16, shuffle=False)
ViTのモデルをロード
timmを使ってViTのモデルをダウンロードし、さらに転移学習を実行していきます。
model_names = timm.list_models(pretrained=True)
pprint(model_names)
コードを実行すると、上記のようにロードできるモデル一覧が表示されます。
今回は‘vit_base_patch16_224_in21k’を使用します。
timmからpretrained modelをダウンロードします。
model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True, num_classes=2)
model = model.to(device)
pretrain=Trueを指定することで、重みの学習が行われます。
num_classesは分類するクラス数をしていするので、今回は「female,male」の2クラスを指定します。
ViTによる学習
損失関数、活性化関数の設定をします。
今回はクロスエントロピーとAdamを使用しますが自由に設定できます。
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
学習ループの設定、実行をしていきます。
また、最も良いモデルの重みの保存として、validationのlossが最も小さくなるような重みを保存します。
best_loss = None
# Accuracy計算用の関数
def calculate_accuracy(output, target):
output = (torch.sigmoid(output) >= 0.5)
target = (target == 1.0)
accuracy = torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
return accuracy
train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []
for epoch in range(epochs):
epoch_loss = 0
epoch_accuracy = 0
for data, label in tqdm(train_loader):
data = data.to(device)
label = label.to(device)
output = model(data)
loss = criterion(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (output.argmax(dim=1) == label).float().mean()
epoch_accuracy += acc / len(train_loader)
epoch_loss += loss / len(train_loader)
with torch.no_grad():
epoch_val_accuracy = 0
epoch_val_loss = 0
for data, label in valid_loader:
data = data.to(device)
label = label.to(device)
val_output = model(data)
val_loss = criterion(val_output, label)
acc = (val_output.argmax(dim=1) == label).float().mean()
epoch_val_accuracy += acc / len(valid_loader)
epoch_val_loss += val_loss / len(valid_loader)
print(
f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
)
train_acc_list.append(epoch_accuracy)
val_acc_list.append(epoch_val_accuracy)
train_loss_list.append(epoch_loss)
val_loss_list.append(epoch_val_loss)
if (best_loss is None) or (best_loss > val_loss):
best_loss = val_loss
model_path = '/content/drive/MyDrive/bestViTmodel.pth'
torch.save(model.state_dict(), model_path)
print()
このような感じで学習が進んでいく様子が分かります。
学習結果の可視化
学習結果として、Accuracyとlossの曲線を出力していきます。
device2 = torch.device('cpu')
train_acc = []
train_loss = []
val_acc = []
val_loss = []
for i in range(epochs):
train_acc2 = train_acc_list[i].to(device2)
train_acc3 = train_acc2.clone().numpy()
train_acc.append(train_acc3)
train_loss2 = train_loss_list[i].to(device2)
train_loss3 = train_loss2.clone().detach().numpy()
train_loss.append(train_loss3)
val_acc2 = val_acc_list[i].to(device2)
val_acc3 = val_acc2.clone().numpy()
val_acc.append(val_acc3)
val_loss2 = val_loss_list[i].to(device2)
val_loss3 = val_loss2.clone().numpy()
val_loss.append(val_loss3)
#取得したデータをグラフ化する
sns.set()
num_epochs = epochs
fig = plt.subplots(figsize=(12, 4), dpi=80)
ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')
ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()
ここまでで、学習のステップは完了です。
テストデータによる検証
学習した結果を元に、保存したモデルの重みを用いて、テストデータで検証をします。
汎化性能の有無を検証します。
def evaluate_model(model, test_loader):
model.load_state_dict(torch.load("/content/drive/MyDrive/bestViTmodel.pth"))
model.eval()
predictions = []
actuals = []
probas = []
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
probas.extend(outputs[:, 1].cpu().numpy()) # 二番目のクラス(陽性クラス)の確率
predictions.extend(predicted.cpu().numpy())
actuals.extend(labels.cpu().numpy())
# 指標の計算
confusion = confusion_matrix(actuals, predictions)
accuracy = accuracy_score(actuals, predictions)
recall = recall_score(actuals, predictions) # 感度
specificity = recall_score(actuals, predictions, pos_label=0)
precision = precision_score(actuals, predictions) # PPV
npv = precision_score(actuals, predictions, pos_label=0) # NPV
f1 = f1_score(actuals, predictions)
auc = roc_auc_score(actuals, probas)
#print("混同行列:\n", confusion)
print(f"Accuracy: {accuracy:.4f}")
print(f"感度: {recall:.4f}")
print(f"特異度: {specificity:.4f}")
print(f"PPV: {precision:.4f}")
print(f"NPV: {npv:.4f}")
print(f"F-Score: {f1:.4f}")
print(f"AUC: {auc:.4f}")
# Confusion matrixの描画
# 混同行列の可視化
plt.figure(figsize=(7, 6))
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix")
plt.ylabel("True")
plt.xlabel("Predicted")
plt.show()
# ROC曲線の描画
fpr, tpr, thresholds = roc_curve(actuals, probas)
plt.figure(figsize=(6,6))
plt.plot(fpr, tpr, color='blue', label='ROC curve (AUC = %0.2f)' % auc)
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.show()
evaluate_model(model, test_loader)
Accuracy: 0.8936
感度: 0.8462
特異度: 0.9524
PPV: 0.9565
NPV: 0.8333
F-Score: 0.8980
AUC: 0.9359
ViT vs CNNの比較
結局、画像認識においてViTとCNNどっちがいいの?という議論もされています。
そこで、今回はViTのtestの精度とEfficientNetのtestの精度を比較します。
EfficientNetの実装はこちら
EfficientNet | VisionTransformer | |
---|---|---|
Accuracy | 0.9149 | 0.8936 |
感度 | 0.9231 | 0.8462 |
特異度 | 0.9048 | 0.9524 |
PPV | 0.9231 | 0.9565 |
NPV | 0.9048 | 0.8333 |
F-Score | 0.9231 | 0.8980 |
AUC | 0.9634 | 0.9359 |
今回の、単純なデータセットでは、CNNで高い精度が出たので、ViTが上回る結果にはなりませんでした。
しかし、様々な工夫により、精度が上がるパターンがあるかもしれないので、ぜひ試してみてください。
まとめ
今回の記事では、PyTorchによるViTを用いた医療画像2クラス分類を実装してみました。
CNNからViTへ変わる可能性が示唆されており、様々な研究も増えています。
ぜひ、トレンドを抑えて実装して実験の参考にしてみてください。
コメント
コメント一覧 (6件)
CUDAを使わない場合での質問をしましたが、Google Colaboratoryを使い始めるのは簡単なのがわかったので、それで動かし、自分で調べることにしました。
遅れてすみません!GoogleColaboratoryだとGPU使えるので、そのままでも大丈夫ですが。使わない場合だと、cpuを使えば可能です。時間はその分かかります。
validation ディレクトリに関する質問に答えていただけると嬉しいです。5つのファイルとtestディレクトリからコピーして、動くようになりましたが、これでいいのでしょうか?
「5つのファイルとtestディレクトリからコピーして、」という状況が上手く読み取れず、申し訳ご合いません。この記事で紹介しているようなディレクトリの構造になっていて、動いていれば特に問題はないかと思います。もし宜しければ、記事上部のリンクからデータセットをダウンロードして頂けると良いかと思います。
腰、背骨の辺りのX線写真のデータセットはありませんか?
腰痛があり、整形外科でX線写真を撮り診断を受けると同時に画像をもらいました。その画像から診断が正しいか検証しようと思います。
腰や、背骨あたりのデータセットはあまり見かけないかもしれません。もしかしたら、kaggleなどで調べると出てくるかもしれません。