ML & DL

경량화 | QAT, Quantization Aware Training, 딥러닝 경량화, 모델 경량화, 네트워크 경량화, yolov8+QAT 코드

토오오끼 2025. 4. 20. 22:26
728x90
반응형

 

딥러닝 모델 경량화 기법 중 하나인 양자화(Quantization)는 딥러닝 모델의 숫자 표현을 줄여 모델을 더 작고, 빠르고, 효율적으로 만드는 것으로 flat32를 int8로 바꿔 모델을 경량화 할 수 있다.

 

양자화의 종류에는 크게 두가지가 있다.

1. PTQ(Post-Training Static Quantization)

- 학습 없이 양자화가 가능하며 calibration 데이터로 scale/zero point를 측정한다. 빠르고 간단히 양자화가 가능하지만 정확도 손실 가능성이 높다.

2. QAT(Quantization Aware Training)

- 학습 중에 양자화 효과를 반영하는 것으로 훈련으로 손실이 보정이 되어 정확도 손실이 적다.

 

그 중 QAT를 적용 해 볼 일이 있어 간단히 알아보고 적용 해 보았다.

 

◼️ QAT의 목적

학습 중에 양자화의 영향을 시뮬레이션(fake quantization) 하면서 모델을 학습하여, 실제 inference에서 사용하는 정수 (int8) 연산 기반 모델의 정확도 손실을 최소화하는 것이다.

 

◼️ QAT 동작 방식

QAT는 학습 도중 양자화를 시뮬레이션(fake quantization)하는 것으로 pytorch native quantization API를 사용하여 prepare_qat -> train -> convert 순으로 QAT를 구현 할 수 있다.

 

수식으로 보면

quantized = round((float_value / scale) + zero_point)
dequantized = (quantized - zero_point) * scale

이 과정을 학습 중에 계속 반복하며 backpropagation도 가능하도록 설계 되어 있다.

 

◾ 주요 인자

QuantStub / DeQuantStub

  • QuantStub: 모델 입력을 양자화 시뮬레이션하는 레이어
  • DeQuantStub: 양자화된 출력을 다시 FP32로 되돌리는 레이어
  • 이걸 통해 모델 전체의 양자화-비양자화 범위를 설정
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()

 

FakeQuantization Module

  • torch.fake_quantize.FakeQuantize 클래스 기반
  • 실제로 값은 FP32인데, 훈련 중에는 int8처럼 clamping + rounding 처리하여 backpropagation이 가능한 상태로 유지
  • forward 시
# 예: [0, 255] 범위로 quantization
x = clamp(x, min_val, max_val)
x = round((x - min_val) / scale) * scale + min_val

 

Observer

  • MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver 등
  • QAT 중 또는 calibration 단계에서 각 텐서의 최소/최대 범위 (min/max) 를 추적해서 이를 기반으로 quantization scale과 zero point를 결정

 

 

 ◾동작 방식

QConfig 설정

from torch.quantization import get_default_qat_qconfig

model.qconfig = get_default_qat_qconfig("qnnpack")
  • get_default_qat_qconfig() 호출 시, per_channel_weight_observer 사용 여부에 따라 양자화 방식이 선택 됨(정밀도 기준)
  • 주요 파라미터:
    • activation: observer type (예: MovingAverageMinMaxObserver)
      • weight 전체 텐서에 대해 하나의 scale/zero_point
    • weight: weight quantization 방식 (예: PerChannelMinMaxObserver) ← per_channel_weight_observer
      • weight의 각 채널에 대해 scale/zero_point 별도 적용
from torch.quantization import QConfig, default_observer, default_weight_observer

custom_qconfig = QConfig(
    activation=default_observer,  # MinMaxObserver
    weight=default_weight_observer  # PerChannelMinMaxObserver
)
  • fake_quant: 어떤 방식으로 양자화 시뮬레이션을 할지

 

모델 준비

torch.quantization.prepare_qat(model, inplace=True)
  • 여기서 PyTorch가 모델 내부 layer를 찾아서 FakeQuantize 모듈을 자동 삽입
  • Conv, Linear 등은 내부적으로 weight fake quant module도 삽입됨

 

QAT 학습

  • 일반적인 학습처럼 train loop 수행
  • 하지만 각 layer마다 FakeQuant 적용된 forward가 호출되어 양자화 시뮬레이션을 반영하면서 학습됨

 

convert() 호출

model.eval()
torch.quantization.convert(model, inplace=True)
  • FakeQuant → Quantized Conv2d 등으로 교체됨
  • int8 weight, bias, activation으로 구성된 모델이 완성됨
  • PyTorch의 quantized backend는 대부분 CPU 전용 구현임 (예: bgemm, qnnpack 등은 CPU에서만 동작)
    • backend를 qnnpack으로 두고 convert를 했다면 저장되는 모델은 cpu 연산만 가능함.

 


참고 원본 코드 : https://github.com/mmsori/yolov8-QAT

위 코드를 사용하여 yolov8 모델에 QAT 적용 해 보려고 했지만 아무래도 오래 된 코드이다 보니 이런저런 오류가 발생했다.

자잘한 오류를 수정하고 multi gpu를 사용하여 학습이 가능하도록 코드를 수정하였으며 pre-trained model을 로드 할 때 생기는 오류를 해결하기 위해 코드 수정을 했다.

추론 코드는 추론 시 dequantization 파트를 빼 먹었는지 추론 결과가 완전히 엉망으로 출력 돼서(... ㅠ) 수정이 필요하지만 일단 동작은 되는(ㅠㅠ...) 코드로 올려 두었다..! 누군가에겐 도움이 되길..

https://github.com/YOOHYOJEONG/yolov8-QAT

 

GitHub - YOOHYOJEONG/yolov8-QAT: QAT implementation on YOLOv8

QAT implementation on YOLOv8. Contribute to YOOHYOJEONG/yolov8-QAT development by creating an account on GitHub.

github.com

 

728x90
반응형