수민 '-'

플오그래밍

제가 작성하는 모든 글은 절대 상업적인 이용이 아니며, 그저 개인적인 공부 용도로만 사용하는 것임을 밝힙니다.

A2C - Advantage Actor-Critic 동기식 병렬 학습

A3C를 동기식으로 단순화한 Actor-Critic

A2C(Advantage Actor-Critic)는 이름 그대로 Advantage 기반 Actor-Critic 구조를 사용하면서, A3C처럼 완전히 비동기로 학습하지 않고 동기식(synchronous) 으로 여러 환경에서 수집한 데이터를 한 번에 batch로 처리하는 강화학습 알고리즘이다.
여러 worker 환경을 병렬로 돌리면서도, 일정 step마다 모든 경험을 모아 한 번에 모델을 업데이트하기 때문에 학습이 안정적이고 GPU로 처리하기도 쉽다. 이 구조는 이후 PPO 등 많은 실전 알고리즘의 기반이 된다.

이 글에서는 A3C와의 차이, A2C 구조 및 장점, 그리고 CartPole-v1 환경을 대상으로 한 동기식 병렬 Actor-Critic 구현 코드까지 정리한다.


1. A3C vs A2C 한눈에 비교

먼저 A3C와 A2C의 가장 큰 차이를 그림으로 정리해보자.

1-1. A3C – 비동기 업데이트

Worker1 -> gradient -> Global
Worker2 -> gradient -> Global
Worker3 -> gradient -> Global
...
  • 각 worker가 자기 local 모델 기준으로 gradient를 계산하고,
  • 그 결과를 전역 모델(global network) 에 바로 적용한다.
  • 서로를 기다리지 않기 때문에 비동기(asynchronous) 이고, 그만큼 빠르지만 gradient가 서로 덮어쓰는 문제가 생길 수 있다.

1-2. A2C – 동기식 업데이트

Worker1 -> 환경 탐험
Worker2 -> 환경 탐험
Worker3 -> 환경 탐험
Worker4 -> 환경 탐험
        ↓
    데이터 모음 (batch)
        ↓
Actor + Critic 동기 업데이트
  1. 여러 worker(환경)가 동시에 rollout을 수행해 데이터를 모은다.
  2. 일정 step마다 모든 worker의 데이터를 모아 하나의 batch 로 만든다.
  3. 그 batch를 이용해 Actor와 Critic을 한 번에 동기적으로 업데이트한다.

덕분에:

  • gradient가 서로 충돌하지 않고,
  • 큰 batch를 만들어 GPU로 돌리기 좋으며,
  • 한 번의 업데이트가 더 많은 데이터를 보고 학습하므로 신호가 더 안정적이다.

2. A3C의 한계와 A2C가 필요한 이유

2-1. Gradient 충돌(Interference)

여러 worker가 동시에 글로벌 모델을 업데이트하면 다음과 같은 일이 생길 수 있다.

  • worker1이 막 gradient를 적용해 global 파라미터를 바꾼 직후,
  • worker2가 조금 전에 백업해둔 파라미터 기준으로 계산한 gradient를 그대로 덮어써 버린다.

이런 현상을 gradient interference(충돌) 라고 부르고, 이 때문에 학습이 비효율적이거나 불안정해질 수 있다.

2-2. 학습 불안정과 재현성 문제

  • 비동기 업데이트 특성상 학습 순서가 계속 바뀐다.
  • global 파라미터가 예측하기 어려운 타이밍에 바뀌기 때문에, 결과가 run마다 조금씩 달라지기 쉽다.

연구용 실험이라면 괜찮을 수 있지만, 실무에서는 “동일 설정이면 비슷한 결과가 나와야 한다” 는 재현성이 중요하기 때문에 단점이 된다.

2-3. GPU 활용 어려움

  • A3C는 구조적으로 CPU 멀티프로세싱 에 초점을 맞춘 구조다.
  • GPU는 큰 batch를 한 번에 처리할수록 효율이 좋은데,
  • A3C처럼 작은 gradient 업데이트를 자주 보내는 패턴은 GPU를 제대로 활용하기 어렵다.

이 세 가지 이유 때문에, “A3C의 멀티 환경 병렬 수집”이라는 장점은 유지하면서도, “업데이트는 한 번에 동기식으로” 처리하는 A2C가 등장했다.


3. A2C 구조: 동기식 Advantage Actor-Critic

A2C의 아이디어는 단순하다.

경험 수집은 여럿이 병렬로 하되,
정책/가치 업데이트는 한 번에 동기식으로 하자.

정리하면 다음과 같다.

  1. 여러 환경(worker)이 동시에 rollout을 진행한다.
  2. 일정 step마다 (state, action, reward, done, next_state) 를 모두 모아 큰 batch를 만든다.
  3. 이 batch로 Actor와 Critic을 함께 업데이트한다.

Actor-Critic 구조 자체는 A3C와 동일하고, 변경되는 것은 업데이트 타이밍과 방식(동기/비동기) 이다.


4. A2C의 장점 정리

4-1. 학습 안정성

  • 모든 gradient가 같은 시점의 파라미터 기준으로 계산된다.
  • A3C처럼 worker 간 gradient가 서로 덮어쓰는 문제가 없다.

4-2. GPU 활용 용이

  • 여러 환경에서 모은 데이터를 하나의 큰 batch로 묶어 GPU에서 처리할 수 있다.
  • 대규모 네트워크, 복잡한 환경에서도 효율적인 학습이 가능하다.

4-3. 구현 단순성

  • 공유 메모리, 락(lock), 비동기 통신 같은 난이도 높은 부분을 많이 줄일 수 있다.
  • “여러 환경을 하나의 벡터화된 환경처럼 다루는 패턴”만 구현하면 되기 때문에 코드 이해도 더 쉽다.

4-4. 재현성 증가

  • 업데이트 순서가 고정된 동기식 방식이라, 같은 시드·같은 설정이라면 결과가 훨씬 더 잘 재현된다.

5. CartPole-v1 A2C 구현 – ParallelEnv와 메인 학습 루프

이번에는 CartPole-v1 환경을 대상으로, A2C 구조를 어떻게 코드로 옮기는지 살펴본다.

  • ParallelEnv 클래스로 여러 환경을 한 번에 처리하고,
  • 메인 프로세스에서는 이를 하나의 벡터화된 환경처럼 다루면서,
  • Actor-Critic 모델을 동기식 batch 업데이트로 학습하는 구조다.

5-1. Actor-Critic 모델 정의

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
import numpy as np

# Hyperparameters
n_train_processes = 3
learning_rate = 0.0002
update_interval = 5
gamma = 0.98
max_train_steps = 60000
PRINT_INTERVAL = update_interval * 100


class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)

    def pi(self, x, softmax_dim=1):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v
  • 상태 차원 4(CartPole) → 은닉층 256 →
    • 정책 네트워크: 행동 2개(왼쪽/오른쪽)에 대한 확률
    • 가치 네트워크: 스칼라 (V(s))

6. worker 프로세스 – 환경을 대신 step 해주기

각 worker는 실제 gym 환경을 가지고, 메인 프로세스로부터 받은 명령에 따라 step을 수행한다.

def worker(worker_id, master_end, worker_end):
    # worker 프로세스에서는 master_end를 사용하지 않음
    master_end.close()
    env = gym.make("CartPole-v1")

    # worker마다 다른 시드로 시작 (경험 다양성)
    obs, _ = env.reset(seed=worker_id)

    while True:
        # 메인 프로세스로부터 명령(step, reset, close 등) 수신
        cmd, data = worker_end.recv()

        if cmd == "step":
            obs, reward, terminated, truncated, info = env.step(int(data))
            done = terminated or truncated

            if done:
                obs, _ = env.reset()

            # step 결과를 메인 프로세스로 전달
            worker_end.send((obs, reward, done, info))

        elif cmd == "reset":
            obs, _ = env.reset()
            worker_end.send(obs)

        elif cmd == "close":
            env.close()
            worker_end.close()
            break

        elif cmd == "get_spaces":
            worker_end.send((env.observation_space, env.action_space))

        else:
            raise NotImplementedError(f"Unknown command: {cmd}")
  • 메인 프로세스는 환경을 직접 건드리지 않고, worker에게 명령을 보내는 역할만 한다.
  • step, reset, close 등을 파이프를 통해 전달하고, worker는 그 결과를 다시 돌려준다.

7. ParallelEnv – 여러 환경을 한꺼번에 다루기

ParallelEnv는 여러 개의 worker를 통합해서, 마치 하나의 벡터화된 환경처럼 사용할 수 있게 해준다.

class ParallelEnv:
    def __init__(self, n_train_processes):
        self.nenvs = n_train_processes
        self.waiting = False
        self.closed = False
        self.workers = []

        # 환경 개수만큼 파이프 생성
        master_ends, worker_ends = zip(*[mp.Pipe() for _ in range(self.nenvs)])
        self.master_ends = master_ends
        self.worker_ends = worker_ends

        # 각 환경마다 worker 프로세스 생성
        for worker_id, (master_end, worker_end) in enumerate(zip(master_ends, worker_ends)):
            p = mp.Process(target=worker, args=(worker_id, master_end, worker_end))
            p.daemon = True
            p.start()
            self.workers.append(p)

        # master는 worker_end를 사용하지 않으므로 닫음
        for worker_end in worker_ends:
            worker_end.close()
    def step_async(self, actions):
        # 여러 환경에 행동을 먼저 보내는 함수
        for master_end, action in zip(self.master_ends, actions):
            master_end.send(("step", int(action)))
        self.waiting = True

    def step_wait(self):
        # 모든 worker의 step 결과를 받아옴
        results = [master_end.recv() for master_end in self.master_ends]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return (
            np.stack(obs).astype(np.float32),
            np.array(rews, dtype=np.float32),
            np.array(dones, dtype=np.bool_),
            infos,
        )

    def reset(self):
        for master_end in self.master_ends:
            master_end.send(("reset", None))
        return np.stack([master_end.recv() for master_end in self.master_ends]).astype(np.float32)

    def step(self, actions):
        self.step_async(actions)
        return self.step_wait()
  • step(actions)를 호출하면,
    • 각 환경에 액션을 보내고,
    • 모든 worker에서 결과를 받아,
    • (nenvs, obs_dim) 형태의 관측값 배열과 reward, done 벡터를 반환한다.
  • 메인 코드 입장에서는 여러 환경이 벡터화된 하나의 환경처럼 보이게 된다.
    def close(self):
        if self.closed:
            return

        if self.waiting:
            _ = [master_end.recv() for master_end in self.master_ends]

        for master_end in self.master_ends:
            master_end.send(("close", None))

        for worker in self.workers:
            worker.join()

        self.closed = True
  • 학습이 끝난 뒤에는 모든 worker와 환경을 정리해 준다.

8. 테스트 함수 – 현재 정책 성능 확인

def test(step_idx, model):
    env = gym.make("CartPole-v1")
    score = 0.0
    num_test = 10

    for _ in range(num_test):
        s, _ = env.reset()
        done = False

        while not done:
            with torch.no_grad():
                prob = model.pi(torch.from_numpy(s).float(), softmax_dim=0)
            a = Categorical(prob).sample().item()

            s_prime, r, terminated, truncated, info = env.step(a)
            done = terminated or truncated

            s = s_prime
            score += r

    print(f"Step # : {step_idx}, avg score : {score / num_test:.1f}")
    env.close()
  • 일정 step마다 현재 Actor-Critic 모델의 성능을 테스트한다.
  • softmax_dim=0 으로 단일 상태에 대한 행동 확률을 구하고, 샘플링해서 episode를 진행한다.

9. n-step TD target 계산 – compute_target

def compute_target(v_final, r_lst, mask_lst):
    G = v_final.reshape(-1)  # 마지막 상태 가치 (배열)
    td_target = []

    # 보상과 마스크를 뒤에서부터 거꾸로 순회하면서 n-step return 계산
    for r, mask in zip(r_lst[::-1], mask_lst[::-1]):
        G = r + gamma * G * mask
        td_target.append(G)

    td_target.reverse()
    return torch.tensor(np.array(td_target), dtype=torch.float32)
  • v_final: rollout 마지막 상태들의 가치 (V(s_{\text{final}}))
  • r_lst: 각 step마다의 보상 벡터
  • mask_lst: done이면 0, 아니면 1
  • 뒤에서부터 누적하면서 (Gt = r_t + \gamma G{t+1}) 형태로 n-step Return을 만든 다음, 앞 방향 순서로 되돌린다.

10. 메인 학습 루프 – A2C의 핵심

def main():
    envs = ParallelEnv(n_train_processes)

    model = ActorCritic()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    step_idx = 0
    s = envs.reset()

    while step_idx < max_train_steps:
        s_lst, a_lst, r_lst, mask_lst = [], [], [], []

        for _ in range(update_interval):
            with torch.no_grad():
                prob = model.pi(torch.from_numpy(s).float(), softmax_dim=1)

            a = Categorical(prob).sample().numpy()
            s_prime, r, done, info = envs.step(a)

            s_lst.append(s.copy())
            a_lst.append(a.copy())
            r_lst.append(r / 100.0)
            mask_lst.append(1.0 - done.astype(np.float32))

            s = s_prime
            step_idx += 1
  • (nenvs, obs_dim) 형태의 상태 s에서 정책을 돌려, 각 환경마다 하나씩 행동을 샘플링한다.
  • envs.step(a) 로 모든 환경을 동시에 한 스텝 진행한다.
  • reward를 r / 100.0 으로 스케일링한다.
  • mask = 1 - done 으로, episode가 끝난 환경에 대해서는 이후 가치가 0이 되도록 만든다.
        s_final = torch.from_numpy(s_prime).float()
        with torch.no_grad():
            # 마지막 상태의 가치 (n-step return 시작점)
            v_final = model.v(s_final).cpu().numpy()

        td_target = compute_target(v_final, r_lst, mask_lst)

        td_target_vec = td_target.reshape(-1)
        # 상태 리스트를 텐서로 변환 (전체 샘플 수, 4)
        s_vec = torch.tensor(np.array(s_lst), dtype=torch.float32).reshape(-1, 4)
        # 행동을 (N, 1) 형태로 변환
        a_vec = torch.tensor(np.array(a_lst), dtype=torch.long).reshape(-1).unsqueeze(1)

        values = model.v(s_vec).reshape(-1)
        advantage = td_target_vec - values

        pi = model.pi(s_vec, softmax_dim=1)
        pi_a = pi.gather(1, a_vec).reshape(-1)

        loss = -(torch.log(pi_a + 1e-8) * advantage.detach()).mean() \
               + F.smooth_l1_loss(values, td_target_vec)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step_idx % PRINT_INTERVAL == 0:
            test(step_idx, model)

    envs.close()
  • s_lst, a_lst, r_lst 등을 모두 펼쳐서 하나의 큰 batch 로 만든다.
  • values = V(s), td_target_vec 을 이용해 Advantage를 계산한다.
  • 정책 손실: (-\log \pi(a|s) \cdot \text{Advantage}) 의 평균
  • 가치 손실: (V(s))와 TD target 사이의 Huber loss
  • 두 손실을 합쳐 backward와 optimizer step을 수행한다.
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    main()
  • spawn 모드는 맥OS·윈도우 등에서 멀티프로세싱을 안전하게 쓰기 위해 필수적인 설정이다.

11. A3C vs A2C 요약 정리

항목 A3C A2C
업데이트 방식 비동기(asynchronous) 동기(synchronous)
경험 수집 여러 worker가 환경 병렬 탐험 여러 worker가 환경 병렬 탐험
gradient 적용 각 worker가 global에 바로 적용 모든 worker 데이터를 모아서 한 번에 적용
안정성 gradient interference 가능 gradient 충돌 없음, 더 안정적
GPU 활용 CPU 멀티프로세싱 위주 큰 batch로 GPU 학습에 적합
구현 난이도 공유 메모리·비동기 처리 등으로 상대적 고난도 상대적으로 단순, 코드 구조가 명확
후속 알고리즘 A3C 자체 사용은 줄어드는 추세 PPO 등 많은 알고리즘이 A2C 스타일을 기반으로 발전

마치며

  • A3C는 여러 에이전트를 비동기로 돌려 global 모델을 업데이트하는 구조로, 데이터 효율과 탐험 성능을 크게 끌어올렸다.
  • 하지만 비동기 구조 특성상 gradient 충돌, 재현성, GPU 활용 측면에서 한계가 있었고, 이를 해결하기 위해 등장한 것이 A2C다.
  • A2C처럼 여러 환경에서 동시에 병렬로 경험을 수집하면서도, 정책/가치 업데이트는 동기식 batch 로 처리하는 패턴은 이후 PPO, A2C 변형, Curiosity 기반 탐험 등 다양한 실전 RL 코드에서 거의 공통으로 등장하는 기본 템플릿이라고 보면 된다.