사전 학습 ViT 백본 + 자리별 분류 헤드
Vision Transformer(ViT)의 개념과 원리(패치, 토큰, Transformer Encoder, CLS 분류)는 앞선 글에서 다뤘다. 이번 글에서는 timm으로 사전 학습된 ViT 백본을 불러와, 이미지 한 장에서 여러 자리(예: 4자리 숫자) 를 동시에 예측하는 멀티 브랜치 분류기를 만드는 과정을 정리한다. 공통 ViT 위에 자리 수만큼 분류 브랜치를 두고, JSON 어노테이션·Albumentations 전처리·학습·추론까지 한 번에 따라 할 수 있도록 구성했다.
1. 태스크와 데이터 형식
한 이미지에 4자리 숫자가 쓰여 있고, 각 자리를 0~9 중 하나로 예측하는 멀티 라벨(자리별) 분류를 가정한다. 어노테이션은 annotations.json 형태로, 각 항목에 filename, labels(예: ["3", "1", "4", "2"])가 있다고 둔다.
import os
import json
data_root = "path/to/ViT"
with open(os.path.join(data_root, 'annotations.json'), 'r') as f:
annotations = json.load(f)
image_dir = os.path.join(data_root, 'images')
2. Dataset: 이미지 로드 + 자리별 타깃
labels가 자리별 정수 리스트이면, __getitem__에서 이미지와 함께 자리 개수만큼의 타깃을 반환한다.
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torch
class JsonDataset(Dataset):
def __init__(self, image_dir, annotations, transform=None):
self.image_dir = image_dir
self.annotations = annotations
self.transform = transform
self.class_list = list(range(10))
self.num_classes = len(self.class_list)
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
annot = self.annotations[idx]
image_path = os.path.join(self.image_dir, annot['filename'])
image = Image.open(image_path).convert('RGB')
target_list = torch.tensor([int(c) for c in annot['labels']], dtype=torch.long)
if self.transform:
image = self.transform(image=np.array(image))['image']
return image, target_list
3. 전처리: Albumentations + ImageNet 정규화
학습 시에는 회전·이동·스케일을 넣고, 검증 시에는 리사이즈·패딩·정규화만 적용한다. ToTensorV2로 HWC→CHW, 넘파이→텐서 변환한다.
import albumentations as A
from albumentations.pytorch import ToTensorV2
hyper_params = {
'image_size': 224,
'train_batch_size': 32,
'val_batch_size': 16,
'num_epochs': 3,
'lr': 0.0001,
}
train_transform = A.Compose([
A.ShiftScaleRotate(rotate_limit=15, shift_limit=0.05, scale_limit=0.1, p=0.5, border_mode=0),
A.LongestMaxSize(max_size=hyper_params['image_size']),
A.PadIfNeeded(min_height=hyper_params['image_size'], min_width=hyper_params['image_size'], border_mode=0),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0),
ToTensorV2()
])
val_transform = A.Compose([
A.LongestMaxSize(max_size=hyper_params['image_size']),
A.PadIfNeeded(min_height=hyper_params['image_size'], min_width=hyper_params['image_size'], border_mode=0),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0),
ToTensorV2()
])
학습/검증용 어노테이션을 나눈 뒤 JsonDataset에 넣고, DataLoader로 배치를 만든다.
len_annot = len(annotations)
train_annot = annotations[:int(len_annot * 0.9)]
val_annot = annotations[int(len_annot * 0.9):]
train_dataset = JsonDataset(image_dir=image_dir, annotations=train_annot, transform=train_transform)
val_dataset = JsonDataset(image_dir=image_dir, annotations=val_annot, transform=val_transform)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, num_workers=4, batch_size=hyper_params['train_batch_size'], shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset, num_workers=4, batch_size=hyper_params['val_batch_size']
)
4. ViT 멀티 브랜치 분류 모델
timm.create_model로 ViT를 불러올 때 num_classes=0으로 두면 분류 헤드를 제거하고 마지막 hidden 차원(embed_dim) 만 반환한다. 이 출력을 자리 수만큼의 브랜치에 넣어 각각 CrossEntropy로 학습한다.
import torch.nn as nn
import timm
class ViTMultiBranchClassifier(nn.Module):
def __init__(self, num_classes, num_branches, model_name='vit_base_patch16_224', pretrained=True):
super().__init__()
self.vit = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
self.branches = nn.ModuleList([
nn.Sequential(
nn.Linear(self.vit.embed_dim, 256),
nn.GELU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
)
for _ in range(num_branches)
])
def forward(self, x):
features = self.vit(x)
return [branch(features) for branch in self.branches]
num_branches=4, num_classes=train_dataset.num_classes(10)로 모델을 만들면 된다.
5. 학습 루프
outputs = model(images)→ 브랜치별 로짓 리스트.- 각 브랜치
branch_idx에 대해CrossEntropyLoss(output, targets[:, branch_idx])를 구해 합을 loss로 사용. optimizer.zero_grad()→total_loss.backward()→optimizer.step().
num_branches = 4
num_classes = train_dataset.num_classes
model = ViTMultiBranchClassifier(num_classes=num_classes, num_branches=num_branches)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params['lr'])
criterion = nn.CrossEntropyLoss()
for epoch in range(hyper_params['num_epochs']):
model.train()
for images, targets in train_dataloader:
images, targets = images.to(device), targets.to(device)
outputs = model(images)
total_loss = sum(criterion(output, targets[:, i]) for i, output in enumerate(outputs))
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
# 검증: 브랜치별 정확도 또는 4자리 전체 일치율 계산 후 best 모델 저장
검증 시에는 브랜치별 정확도를 따로 모아 평균을 내거나, 4자리 전체를 한 번에 맞춘 비율 등을 추가로 정의할 수 있다.
6. 추론과 시각화
학습된 모델을 로드한 뒤, 검증 배치에서 이미지와 예측 4자리를 뽑아 시각화한다.
model.eval()
class_list = [str(i) for i in range(10)]
with torch.no_grad():
for idx, (images, targets) in enumerate(val_dataloader):
images = images.to(device)
outputs = model(images)
pred_classes = []
for i, output in enumerate(outputs):
predicted = torch.argmax(output[0])
pred_classes.append(class_list[int(predicted)])
pred_str = "".join(pred_classes)
# input_image_list, pred_class_list 등에 저장 후 draw_images로 시각화
정규화된 텐서를 다시 이미지로 보려면 denormalize 후 to_pil_image로 복원하면 된다.
7. 정리
| 주제 | 핵심 포인트 |
|---|---|
| timm ViT | timm.create_model(..., num_classes=0)으로 백본만 쓰고, 위에 멀티 브랜치 분류기를 올릴 수 있다. |
| 멀티 브랜치 | 한 이미지에서 여러 자리(숫자 등)를 동시에 예측할 때, 공유 ViT + 자리별 Linear(또는 MLP)로 구현한다. |
| 데이터 | JSON 어노테이션(filename, labels) + 이미지 디렉터리로 JsonDataset 구성. |
| 전처리 | Albumentations로 리사이즈·패딩·정규화(ImageNet) 후 ToTensorV2로 텐서 변환. |
마치며
- ViT 개념·원리는 앞선 글에서, 실전 구현은 이 글에서 정리했다.
- timm ViT 백본에 태스크에 맞는 분류 헤드(단일/멀티 브랜치) 를 올리는 패턴은 다른 비전 태스크에도 그대로 적용할 수 있다.
'AI·머신러닝 > 딥러닝·비전' 카테고리의 다른 글
| WLASL 수화 인식 실습 - VideoDataset, R3D, 학습과 추론 (0) | 2026.02.03 |
|---|---|
| 동영상 데이터와 수화 인식 - 시공간 분석, 3D CNN, WLASL 데이터셋 (0) | 2026.01.28 |
| Vision Transformer(ViT) - 개념과 원리, 패치부터 분류까지 (0) | 2026.01.26 |
| 이안류 CCTV와 스타벅스 세그멘테이션 - AI Hub·COCO에서 YOLOv8까지 실습 (0) | 2026.01.08 |
| Object Detection - 개념, 전통 기법, YOLO, Pascal VOC 2007, mAP (0) | 2026.01.07 |