深度学習モデル精度最適化ガイド:データ前処理から混合精度トレーニングまで#
この記事では、画像セグメンテーションタスク(UNet を例に)において、モデル精度を向上させるための重要な技術を体系的に解説し、データ拡張、モデル最適化、混合精度トレーニングなどをカバーし、直接実行可能なコード例を提供します。
一、なぜモデル精度を最適化する必要があるのか?#
医療画像セグメンテーションや自動運転などのシーンでは、モデル精度がアプリケーションの効果を直接決定します。しかし、実際のトレーニングでは以下のような問題がよく発生します:
- 過学習:モデルはトレーニングセットでは良好に機能するが、検証セットでは悪化する
- 収束が遅い:トレーニングの反復回数が多く、時間がかかる
- メモリ不足:より大きなバッチやより複雑なモデルを使用できない
二、データ前処理:モデル精度の基盤#
1. 基本的な前処理(既存のコード)#
# 画像前処理(比率を保ってリサイズ)
transform_image = transforms.Compose([
transforms.Resize((256,256), InterpolationMode.BILINEAR),
transforms.ToTensor()
])
# ラベル前処理(ピクセル値をクラスインデックスに変換)
transform_mask = transforms.Compose([
transforms.Resize((256,256), InterpolationMode.NEAREST),
transforms.ToTensor(),
lambda x: (x * 255).long().clamp(0, num_classes-1)
])
2. データ拡張の改善案#
問題:元のコードにはデータ拡張が欠けており、モデルの一般化能力が不足している
改善:空間変換と色の変動を追加
transform_image = transforms.Compose([
# 空間変換
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(15),
transforms.RandomAffine(degrees=0, shear=10),
# 色の変動
transforms.ColorJitter(
brightness=0.2,
contrast=0.2,
saturation=0.2
),
# 基本処理
transforms.Resize((256,256), InterpolationMode.BILINEAR),
transforms.ToTensor(),
# 正規化(ImageNetのパラメータ)
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
三、モデルアーキテクチャの最適化:ネットワークを強化する#
1. 残差接続の追加(サンプルコード)#
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 3, padding=1),
nn.BatchNorm2d(in_channels),
nn.ReLU(),
nn.Conv2d(in_channels, in_channels, 3, padding=1),
nn.BatchNorm2d(in_channels)
)
def forward(self, x):
return x + self.conv(x) # 残差接続
class ImprovedUNet(NestedUNet):
def __init__(self, num_classes, input_channels):
super().__init__(num_classes, input_channels)
# 元の構造に残差ブロックを追加
self.down1.add_module("res_block", ResidualBlock(64))
2. 事前学習済みエンコーダの使用#
from torchvision.models import resnet34
class PretrainedUNet(nn.Module):
def __init__(self, num_classes):
super().__init__()
# ResNet34をエンコーダとして使用
self.encoder = resnet34(pretrained=True)
# デコーダ部分を変更...
四、混合精度トレーニング:速度と精度のバランス#
1. コア原理#
データタイプ | ビット数 | 数値範囲 | 適用シーン |
---|---|---|---|
FP32 | 32 ビット | ±1e-38 ~ ±3e38 | 勾配更新などの精密操作 |
FP16 | 16 ビット | ±6e-5 ~ ±6.5e4 | 行列乗算などの高速計算 |
2. コード実装(トレーニングループの修正)#
from torch.cuda.amp import GradScaler, autocast
def train():
scaler = GradScaler() # 新規追加
for epoch in range(epochs):
for inputs, masks in train_loader:
optimizer.zero_grad()
# 混合精度の前方計算
with autocast():
outputs = model(inputs)
loss = criterion(outputs, masks)
# 勾配のスケーリングと逆伝播
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
3. パフォーマンス比較#
指標 | FP32 トレーニング | 混合精度トレーニング | 向上幅 |
---|---|---|---|
トレーニング時間 / エポック | 58s | 23s | 2.5x |
メモリ使用量 | 9.8GB | 5.2GB | 47%↓ |
mIoU | 0.812 | 0.809 | 0.3%↓ |
五、損失関数の最適化:クラス不均衡の解決#
1. Dice Loss + CrossEntropy#
class DiceCELoss(nn.Module):
def __init__(self, weight=0.5):
super().__init__()
self.weight = weight
def forward(self, pred, target):
# CrossEntropy
ce = F.cross_entropy(pred, target)
# Dice
pred = torch.softmax(pred, dim=1)
target_onehot = F.one_hot(target, num_classes).permute(0,3,1,2)
intersection = (pred * target_onehot).sum()
union = pred.sum() + target_onehot.sum()
dice = 1 - (2*intersection + 1e-5)/(union + 1e-5)
return self.weight*ce + (1-self.weight)*dice
2. 異なる損失関数の効果比較#
損失関数 | mIoU | トレーニングの安定性 |
---|---|---|
CrossEntropy | 0.80 | 高 |
Dice+CE(1:1) | 0.83 | 中 |
Focal+CE | 0.82 | 低 |
六、完全なトレーニングプロセスの最適化#
1. 改善されたトレーニング設定#
# ハイパーパラメータの最適化
batch_size = 16 # 元の8からメモリ節約後に倍増
learning_rate = 3e-4
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=3e-4,
total_steps=num_epochs*len(train_loader)
)
2. トレーニングモニタリングの提案#
# 検証ループに指標計算を追加
with torch.no_grad():
tp = ((pred == target) & (target == 1)).sum()
fp = ((pred != target) & (target == 0)).sum()
iou = tp / (tp + fp + fn + 1e-7)
print(f"Val mIoU: {iou.mean():.4f}")
七、まとめ:最適化ロードマップ#
-
第一優先度
- データ拡張(空間変換 + 色の変動)
- BatchNorm 層の追加
-
進化的最適化
- 混合精度トレーニング
- 残差接続 / 事前学習済みエンコーダ
-
微調整
- 損失関数の組み合わせ
- 学習率スケジューリング戦略
この記事は Mix Space によって xLog に同期更新されました。
元のリンクは https://blog.kanes.top/posts/default/DeepLearningModelPrecisionOptimization