728x90
딥러닝 모델이 커지고, 학습 시간이 수십~수백 시간에 달하면서
중단에 대비하지 않은 학습은 매우 큰 리스크를 안게 됩니다.
이때 반드시 필요한 기술이 **Checkpointing(체크포인팅)**입니다.
이는 모델 학습의 중간 결과를 주기적으로 저장하여
중단 시 해당 시점부터 재시작이 가능하게 하는 전략입니다.
✅ Checkpointing이란?
Checkpoint는 학습 도중 모델 상태 및 옵티마이저 정보를 디스크에 저장하는 파일입니다.
→ 학습이 중단되더라도 저장된 시점에서 다시 이어서 학습할 수 있게 됩니다.
| 포함 내용 |
설명 |
| 모델 파라미터 | weight, bias 등 학습된 매개변수 |
| 옵티마이저 상태 | momentum, learning rate 등 |
| 학습 메타 정보 | 현재 Epoch, Step, Seed, RNG 상태 등 |
| 기타 | 스케줄러 상태, 로그 정보 등 |
✅ 왜 중요한가?
- 시간 절약: 40시간 학습 중 39시간에 중단되었을 때 처음부터 다시 시작할 필요 없음
- 비용 절감: 클라우드 GPU 사용 시 과금 최소화
- Fault-tolerance 향상: 학습 실패 → 자동 재시작 가능
- 성능 튜닝 반복 시에도 유용: 여러 실험 버전 간 동일 지점에서 시작 가능
✅ 저장 주기 전략
| 기준 | 예시 | 특징 |
| Epoch 단위 | every_n_epochs=1 | 모델 수렴에 따른 구분 명확 |
| Step 단위 | every_n_steps=1000 | 보다 세밀한 복구 가능 |
| Validation 성능 기준 | save_best_only=True | 가장 성능 좋은 지점만 보관 |
| 시간 기준 | every_n_minutes=30 | 외부 요인 대비용으로 활용 |
✅ PyTorch 기준 Checkpoint 코드 예시
# 저장
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, 'checkpoint.pt')
# 로딩
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
✅ 고급 적용 전략
| 전략 | 설명 |
| Sharded Checkpoint | 모델을 여러 GPU에 분할 저장 (예: ZeRO) |
| Async Checkpointing | 연산과 동시에 백그라운드 저장 |
| Remote Checkpointing | S3, NFS 등에 저장하여 장애 복구 가능 |
| Versioned Checkpointing | 실험/롤백을 위한 버전 관리 구조 |
| Auto Resume | 스크립트 실행 시 자동으로 가장 최근 체크포인트 탐지 및 재시작 |
✅ Serving과의 연계
Checkpoint는 학습 뿐 아니라 모델 추론 서빙에도 활용됩니다.
- 학습 완료된 모델 weight를 .pt, .bin, .ckpt 형식으로 export
- 서빙 인프라에서 해당 모델을 로딩하여 API로 제공
- ONNX, TorchScript 등으로 변환하여 가속기별 최적화 적용 가능
✅ 마무리
Checkpointing은 단순 저장 기능이 아닙니다.
AI 인프라의 회복력(Resilience)을 보장하는 핵심 구성요소입니다.
불안정한 클러스터 환경, 예측 불가한 학습 실패 상황에서
“처음부터 다시”가 아니라 “이어서 학습”이 가능해야 합니다.
728x90