수민 '-'

플오그래밍

제가 작성하는 모든 글은 절대 상업적인 이용이 아니며, 그저 개인적인 공부 용도로만 사용하는 것임을 밝힙니다.

timm ViT로 멀티 브랜치 분류기 만들기 - 4자리 숫자 이미지 분류

사전 학습 ViT 백본 + 자리별 분류 헤드

Vision Transformer(ViT)의 개념과 원리(패치, 토큰, Transformer Encoder, CLS 분류)는 앞선 글에서 다뤘다. 이번 글에서는 timm으로 사전 학습된 ViT 백본을 불러와, 이미지 한 장에서 여러 자리(예: 4자리 숫자) 를 동시에 예측하는 멀티 브랜치 분류기를 만드는 과정을 정리한다. 공통 ViT 위에 자리 수만큼 분류 브랜치를 두고, JSON 어노테이션·Albumentations 전처리·학습·추론까지 한 번에 따라 할 수 있도록 구성했다.


1. 태스크와 데이터 형식

한 이미지에 4자리 숫자가 쓰여 있고, 각 자리를 0~9 중 하나로 예측하는 멀티 라벨(자리별) 분류를 가정한다. 어노테이션은 annotations.json 형태로, 각 항목에 filename, labels(예: ["3", "1", "4", "2"])가 있다고 둔다.

import os
import json

data_root = "path/to/ViT"
with open(os.path.join(data_root, 'annotations.json'), 'r') as f:
    annotations = json.load(f)

image_dir = os.path.join(data_root, 'images')

2. Dataset: 이미지 로드 + 자리별 타깃

labels가 자리별 정수 리스트이면, __getitem__에서 이미지와 함께 자리 개수만큼의 타깃을 반환한다.

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torch

class JsonDataset(Dataset):
    def __init__(self, image_dir, annotations, transform=None):
        self.image_dir = image_dir
        self.annotations = annotations
        self.transform = transform
        self.class_list = list(range(10))
        self.num_classes = len(self.class_list)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annot = self.annotations[idx]
        image_path = os.path.join(self.image_dir, annot['filename'])
        image = Image.open(image_path).convert('RGB')
        target_list = torch.tensor([int(c) for c in annot['labels']], dtype=torch.long)

        if self.transform:
            image = self.transform(image=np.array(image))['image']
        return image, target_list

3. 전처리: Albumentations + ImageNet 정규화

학습 시에는 회전·이동·스케일을 넣고, 검증 시에는 리사이즈·패딩·정규화만 적용한다. ToTensorV2로 HWC→CHW, 넘파이→텐서 변환한다.

import albumentations as A
from albumentations.pytorch import ToTensorV2

hyper_params = {
    'image_size': 224,
    'train_batch_size': 32,
    'val_batch_size': 16,
    'num_epochs': 3,
    'lr': 0.0001,
}

train_transform = A.Compose([
    A.ShiftScaleRotate(rotate_limit=15, shift_limit=0.05, scale_limit=0.1, p=0.5, border_mode=0),
    A.LongestMaxSize(max_size=hyper_params['image_size']),
    A.PadIfNeeded(min_height=hyper_params['image_size'], min_width=hyper_params['image_size'], border_mode=0),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0),
    ToTensorV2()
])

val_transform = A.Compose([
    A.LongestMaxSize(max_size=hyper_params['image_size']),
    A.PadIfNeeded(min_height=hyper_params['image_size'], min_width=hyper_params['image_size'], border_mode=0),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0),
    ToTensorV2()
])

학습/검증용 어노테이션을 나눈 뒤 JsonDataset에 넣고, DataLoader로 배치를 만든다.

len_annot = len(annotations)
train_annot = annotations[:int(len_annot * 0.9)]
val_annot = annotations[int(len_annot * 0.9):]

train_dataset = JsonDataset(image_dir=image_dir, annotations=train_annot, transform=train_transform)
val_dataset = JsonDataset(image_dir=image_dir, annotations=val_annot, transform=val_transform)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, num_workers=4, batch_size=hyper_params['train_batch_size'], shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, num_workers=4, batch_size=hyper_params['val_batch_size']
)

4. ViT 멀티 브랜치 분류 모델

timm.create_model로 ViT를 불러올 때 num_classes=0으로 두면 분류 헤드를 제거하고 마지막 hidden 차원(embed_dim) 만 반환한다. 이 출력을 자리 수만큼의 브랜치에 넣어 각각 CrossEntropy로 학습한다.

import torch.nn as nn
import timm

class ViTMultiBranchClassifier(nn.Module):
    def __init__(self, num_classes, num_branches, model_name='vit_base_patch16_224', pretrained=True):
        super().__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.vit.embed_dim, 256),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(256, num_classes)
            )
            for _ in range(num_branches)
        ])

    def forward(self, x):
        features = self.vit(x)
        return [branch(features) for branch in self.branches]

num_branches=4, num_classes=train_dataset.num_classes(10)로 모델을 만들면 된다.


5. 학습 루프

  • outputs = model(images) → 브랜치별 로짓 리스트.
  • 각 브랜치 branch_idx에 대해 CrossEntropyLoss(output, targets[:, branch_idx])를 구해 을 loss로 사용.
  • optimizer.zero_grad()total_loss.backward()optimizer.step().
num_branches = 4
num_classes = train_dataset.num_classes
model = ViTMultiBranchClassifier(num_classes=num_classes, num_branches=num_branches)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params['lr'])
criterion = nn.CrossEntropyLoss()

for epoch in range(hyper_params['num_epochs']):
    model.train()
    for images, targets in train_dataloader:
        images, targets = images.to(device), targets.to(device)
        outputs = model(images)
        total_loss = sum(criterion(output, targets[:, i]) for i, output in enumerate(outputs))
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    # 검증: 브랜치별 정확도 또는 4자리 전체 일치율 계산 후 best 모델 저장

검증 시에는 브랜치별 정확도를 따로 모아 평균을 내거나, 4자리 전체를 한 번에 맞춘 비율 등을 추가로 정의할 수 있다.


6. 추론과 시각화

학습된 모델을 로드한 뒤, 검증 배치에서 이미지와 예측 4자리를 뽑아 시각화한다.

model.eval()
class_list = [str(i) for i in range(10)]

with torch.no_grad():
    for idx, (images, targets) in enumerate(val_dataloader):
        images = images.to(device)
        outputs = model(images)
        pred_classes = []
        for i, output in enumerate(outputs):
            predicted = torch.argmax(output[0])
            pred_classes.append(class_list[int(predicted)])
        pred_str = "".join(pred_classes)
        # input_image_list, pred_class_list 등에 저장 후 draw_images로 시각화

정규화된 텐서를 다시 이미지로 보려면 denormalizeto_pil_image로 복원하면 된다.


7. 정리

주제 핵심 포인트
timm ViT timm.create_model(..., num_classes=0)으로 백본만 쓰고, 위에 멀티 브랜치 분류기를 올릴 수 있다.
멀티 브랜치 한 이미지에서 여러 자리(숫자 등)를 동시에 예측할 때, 공유 ViT + 자리별 Linear(또는 MLP)로 구현한다.
데이터 JSON 어노테이션(filename, labels) + 이미지 디렉터리로 JsonDataset 구성.
전처리 Albumentations로 리사이즈·패딩·정규화(ImageNet) 후 ToTensorV2로 텐서 변환.

마치며

  • ViT 개념·원리는 앞선 글에서, 실전 구현은 이 글에서 정리했다.
  • timm ViT 백본에 태스크에 맞는 분류 헤드(단일/멀티 브랜치) 를 올리는 패턴은 다른 비전 태스크에도 그대로 적용할 수 있다.