PyTorchによるCNNを用いた医療画像2クラス分類【EfficientNet】

スポンサーリンク

今回の記事では、CNN(EfficientNet)を使って、画像分類をやってみたいと思います。

実装環境の準備(Google Colaboratory)

この記事では、Google Colaboratoryの環境を用いて実装していきます。

GPUが使えて、環境構築が簡単なため幅広く使用されています。

データセットの準備

今回使用するデータセットは、下記サイトから取ってきたものになります。

http://imgcom.jsrt.or.jp/minijsrtdb/

日本放射線技術学会 胸部X線画像の性別の2クラス分類データセット Gender01

こちらのデータセットは、胸部正面像が性別ごとに分けられており、どの性別かを予測する事に使用します。

▼ データセットはこちらの下記リンクからダウンロードしてください。

データセットのダウンロードはこちら

Google colaboratoryを前提に実装していきますが、ローカルのJupyter環境でも問題ありません。

ダウンロードが出来たらGoogleDriveにzipファイルのままアップロードしておいてください。

female:女性の胸部X線画像
male:男性の胸部X線画像

データセットの構造は、下図のようになっています。

trainには学習用データセット、validationには検証用のデータセット、testにはテスト用データセットが含まれています。

一般的に、このようにtrain、validation、testの3つに分けておくと都合がよいです。

医療画像を用いたEfficientNetによる2クラス分類

それでは、これから実際にEfficientNetを使った学習を行うためのコードを書いていきます。

ライブラリのインポート

まずは自分のGoogleDriveにアップロードしたデータを使用するために、GoogleDriveと連携させる(Googleドライブのマウント)事を行います。

下記コードを実行する事で、Googleドライブのマウントは完了です。

# Googleドライブのマウント

from google.colab import drive
drive.mount('/content/drive')

続いて、先程アップロードしたzipファイルの解凍を行います。

解凍してからアップロードしてもいいですが、データ量が多い場合には時間がかかるので、zipファイルでアップロードして、コードで解凍する方がおすすめです。

from zipfile import ZipFile

# Zipファイルの解凍

file_name = '/content/drive/My Drive/Gender01.zip'
with ZipFile(file_name, 'r') as zip:
    zip.extractall()

次に、必要なライブラリをインポートしておきます。

from __future__ import print_function

import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from pathlib import Path
import seaborn as sns
import timm
from pprint import pprint

import copy
from tqdm import tqdm

学習条件の設定をしていきます。条件を変更したい場合は、ここの数値を適宜変更してください。

# Training settings
epochs = 50
lr = 3e-5
gamma = 0.7
seed = 42

シードの設定をしていきます。ここは、特に変更せずそのままで大丈夫です。

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

学習用データセットの設定

cudaの設定と、データセットの保存してあるパスを指定します。

device = 'cuda'
train_dataset_dir = Path('/content/Gender01/train')
val_dataset_dir = Path('/content/Gender01/validation')
test_dataset_dir = Path('/content/Gender01/test')

データセットの中身を確認します。

画像を表示してみて、データセットがうまく読み込めているか確認します。

files = glob.glob('/content/Gender01/*/*/*.png')
random_idx = np.random.randint(1, len(files), size=9)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))

for idx, ax in enumerate(axes.ravel()):
    img = Image.open(files[idx])
    ax.imshow(img)

transformの設定

transformを使って、データセットの画像の前処理を行います。

train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

・画像のサイズを224x224にリサイズ
・左右反転によるData Augmentation
・Tensor型へデータ変更
・正規化

データセットのロード

元のフォルダを読み込んで、データセット(画像とラベルのセット)を作成します。

train_data = datasets.ImageFolder(train_dataset_dir,train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)
test_data = datasets.ImageFolder(test_dataset_dir, test_transforms)

データをバッチに分けます。

今回は、batch_sizeを16にしていますが、GPUメモリが不足している場合は減らしましょう。

これで学習データセットのロードは完了です。

train_loader = DataLoader(dataset = train_data, batch_size=16, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=16, shuffle=False)
test_loader = DataLoader(dataset = test_data, batch_size=16, shuffle=False)

EfficientNetのモデルをロード

timmを使ってEfficientNetのモデルをダウンロードし、さらに転移学習を実行していきます。

model_names = timm.list_models(pretrained=True)
pprint(model_names)

['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_224_in22k', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_base_patch16_224_in22k', 'beitv2_large_patch16_224', 'beitv2_large_patch16_224_in22k',・・・

コードを実行すると、ロードできるモデル一覧が表示されます。

今回は‘tf_efficientnetv2_s_in21ft1k’を使用します。

timmからpretrained modelをダウンロードします。

model = timm.create_model('tf_efficientnetv2_s_in21ft1k', pretrained=True, num_classes=2)
model = model.to(device)

pretrain=Trueを指定することで、重みの学習が行われます。

num_classesは分類するクラス数をしていするので、今回は「female,male」の2クラスを指定します。

EfficientNetによる学習

損失関数、活性化関数の設定をします。

今回はクロスエントロピーとAdamを使用しますが自由に設定できます。

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

学習ループの設定、実行をしていきます。

また、最も良いモデルの重みの保存として、validationのlossが最も小さくなるような重みを保存します。

best_loss = None

# Accuracy計算用の関数
def calculate_accuracy(output, target):
    output = (torch.sigmoid(output) >= 0.5)
    target = (target == 1.0)
    accuracy = torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
    return accuracy

train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)              

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    train_acc_list.append(epoch_accuracy)
    val_acc_list.append(epoch_val_accuracy)
    train_loss_list.append(epoch_loss)
    val_loss_list.append(epoch_val_loss)

    if (best_loss is None) or (best_loss > val_loss):
        best_loss = val_loss
        model_path = '/content/drive/MyDrive/bestViTmodel.pth'
        torch.save(model.state_dict(), model_path)
        
    print()

このような感じで学習が進んでいく様子が分かります。

学習結果の可視化

学習結果として、Accuracyとlossの曲線を出力していきます。

device2 = torch.device('cpu')

train_acc = []
train_loss = []
val_acc = []
val_loss = []

for i in range(epochs):
    train_acc2 = train_acc_list[i].to(device2)
    train_acc3 = train_acc2.clone().numpy()
    train_acc.append(train_acc3)
    
    train_loss2 = train_loss_list[i].to(device2)
    train_loss3 = train_loss2.clone().detach().numpy()
    train_loss.append(train_loss3)
    
    val_acc2 = val_acc_list[i].to(device2)
    val_acc3 = val_acc2.clone().numpy()
    val_acc.append(val_acc3)
    
    val_loss2 = val_loss_list[i].to(device2)
    val_loss3 = val_loss2.clone().numpy()
    val_loss.append(val_loss3)

#取得したデータをグラフ化する
sns.set()
num_epochs = epochs

fig = plt.subplots(figsize=(12, 4), dpi=80)

ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()

・Validation accuracy:0.5000
・Validation loss:3.6753

と言った結果になりました。

ここまでで、学習のステップは完了です。

テストデータによる検証

学習した結果を元に、保存したモデルの重みを用いて、テストデータで検証をします。

汎化性能の有無を検証します。

model.eval()  # モデルを評価モードにする

loss_sum = 0
correct = 0

with torch.no_grad():
    for data, labels in test_loader:

        # GPUが使えるならGPUにデータを送る
        data = data.to(device)
        labels = labels.to(device)

        # ニューラルネットワークの処理を実施
        outputs = model(data)

        # 損失(出力とラベルとの誤差)の計算
        loss_sum += criterion(outputs, labels)

        # 正解の値を取得
        pred = outputs.argmax(1)
        # 正解数をカウント
        correct += pred.eq(labels.view_as(pred)).sum().item()

print(f"Loss: {loss_sum.item() / len(test_loader)}, Accuracy: {100*correct/len(test_data)}% ({correct}/{len(test_data)})")

Loss:2.1673596700032554 , Accuracy:57.4468085106383% (27/47)

まとめ

今回の記事では、PyTorchによるCNNを用いた医療画像2クラス分類を実装してみました。

今回は、特に精度を上げるための工夫などはしていないので、ぜひDataAugmentationなどで精度が上がる方法を探してみてくださいね!

もしもこの記事が気に入ったら上のハートのいいね!ボタンを押して頂けると嬉しいです!今後の励みになります!(ちなみに、押すとボタン周辺にハートが出てきて少し幸せな気持ちになれます笑)

スポンサーリンク

関連コンテンツ

  • この記事を書いた人

Ryusei

【経歴】診療放射線技師免許取得 ▶︎ 大学院生 ▶︎ AI Frontier 運営 ▶︎ AIをメインに情報発信します ▶︎ 23歳

-programming