征程 6 | 工具链 QAT ObserverBase 源码解析
1. 概述
ObserverBase 是 horizon_plugin_pytorch 量化框架中所有 Observer 的抽象基类。它定义了量化校准器的统一接口和核心功能,为各种量化策略(MinMax、MSE、KL 等)提供了基础架构。
2. ABCMeta 深度解析
2.1 Python 元类机制
在 Python 中,类也是对象,类是由元类(metaclass)创建的,例如 type(int) 与 type(object) 的结果都是 type。
默认情况下,所有类都由内置元类 type 创建。当指定 metaclass=ABCMeta 时,类的创建过程改由 ABCMeta(type 的子类)控制。
示例如下:
```python
from abc import ABCMeta, abstractmethod


class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, x):
        pass
```
2.2 @abstractmethod 装饰器
```python
def abstractmethod(funcobj):
    """标记方法为抽象方法"""
    funcobj.__isabstractmethod__ = True  # 仅设置标志位
    return funcobj
```
装饰器本身只负责打标记;ABCMeta 在创建类时会把带有该标记的方法名收集到 __abstractmethods__ 中,只要该集合非空,类就无法被实例化。
2.3 ObserverBase 中的应用
```python
# 基类定义抽象方法
class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, x):
        pass

# ObserverBase.__abstractmethods__ = frozenset({'forward'})


# 子类实现
class MinMaxObserver(ObserverBase):
    def forward(self, x_orig):
        return x_orig

# MinMaxObserver.__abstractmethods__ = frozenset() → 可实例化
```
3. ObserverBase 完整源码
class ObserverBase(torch.nn.Module, metaclass=ABCMeta): r"""Base observer Module. Any observer implementation should derive from this class. Concrete observers should follow the same API. In forward, they will update the statistics of the observed Tensor. And they should provide a `calculate_qparams` function that computes the quantization parameters given the collected statistics. Args: averaging_constant: Averaging constant for min/max. ch_axis: Channel axis. dtype: Quantized data type. qscheme: Quantization scheme to be used. quant_min: Min quantization value. Will follow dtype if unspecified. quant_max: Max quantization value. Will follow dtype if unspecified. is_sync_quantize: If sync statistics when training with multiple devices. factory_kwargs: kwargs which are passed to factory functions for min_val and max_val. """ _version = 3 eps: torch.Tensor min_val: torch.Tensor max_val: torch.Tensor is_sync_quantize: Optional[bool] = True @typechecked def __init__( self, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: Union[torch.dtype, QuantDType] = qint8, qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: Optional[bool] = None, factory_kwargs: Dict = None, compute_scale_strategy=ComputeScaleStrategy.STATISTIC, ): super(ObserverBase, self).__init__() if qscheme == torch.per_channel_symmetric: assert ( ch_axis >= 0 ), "ch_axis should be non-negative when using per_channel_symmetric qcsheme" else: assert ( ch_axis < 0 ), "ch_axis should be negative when using per_tensor_symmetric qcsheme" dtype = get_horizon_quant_dtype(dtype) assert qscheme in ( torch.per_tensor_symmetric, torch.per_channel_symmetric, ), ( "only support per_tensor_symmetric and per_channel_symmetric " "qscheme" ) self.averaging_constant = averaging_constant self.ch_axis = ch_axis self.dtype = dtype self.qscheme = qscheme self._set_quant_min_max(self.dtype, quant_min, quant_max) if is_sync_quantize is not None: 
self.is_sync_quantize = is_sync_quantize self.compute_scale_strategy = compute_scale_strategy factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) self.register_buffer( "eps", torch.tensor([torch.finfo(torch.float32).eps], **factory_kwargs), ) self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) def _set_quant_min_max( self, dtype, quant_min=None, quant_max=None, ): if (quant_min is not None) and (quant_max is not None): assert quant_min < quant_max, ( "qmin must be strictly less than qmax for user-specified " "quantization range." ) assert ( quant_min <= 0 <= quant_max ), "Used-specified quantization range must include 0." assert qinfo(dtype).min <= quant_min, "quant_min out of bound" assert quant_max <= qinfo(dtype).max, "quant_max out of bound" self.quant_min, self.quant_max = quant_min, quant_max else: self.quant_min, self.quant_max = ( qinfo(self.dtype).min, qinfo(self.dtype).max, ) def reset_dtype(self, dtype): dtype = get_horizon_quant_dtype(dtype) if dtype == self.dtype: return self.dtype = dtype self._set_quant_min_max(self.dtype) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): # buffers has been renamed from min/max_vals to min/max_val buffer_name_mapping = {"min_vals": "min_val", "max_vals": "max_val"} for old_name in buffer_name_mapping: k = prefix + old_name if k in state_dict: v = state_dict.pop(k) state_dict[prefix + buffer_name_mapping[old_name]] = v eps_key = prefix + "eps" if eps_key not in state_dict: # eps was moved to a buffer in version 2 eps = torch.tensor([torch.finfo(torch.float32).eps]) state_dict[eps_key] = eps local_state = ["min_val", "max_val"] for name in local_state: key = prefix + name if key in state_dict: # if ndim=0, make it ndim=1 state_dict[key] = state_dict[key].reshape(-1) val = state_dict[key] # Custom handling to allow loading min_val or max_val # of size N into 
uninitialized buffers of size 0. The # buffers are resized here, and the values are copied in # the default state_dict loading code of the parent. if name == "min_val" and hasattr(self, "min_val"): self.min_val.resize_(val.shape) elif hasattr(self, "max_val"): self.max_val.resize_(val.shape) super(ObserverBase, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) def _load_from_state_dict_script( self, state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], prefix: str, local_metadata: Dict[str, torch.Tensor], strict: bool, missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str], ): self._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) def sync_minmax(self, min_val, max_val): if dist.is_initialized() and min_val.is_cuda: dist.all_reduce(min_val, op=dist.ReduceOp.MIN) dist.all_reduce(max_val, op=dist.ReduceOp.MAX) def calculate_qparams(self): r"""Calculate the quantization parameters. Returns: scales: Scales tensor of shape (#channels,) zero_points: Zero points tensor of shape (#channels,) """ if self.min_val.numel() == 0 or self.max_val.numel() == 0: warnings.warn( "Must run observer before calling calculate_qparams. " "Returning default scale and zero point. " "This is an expected behavior if you use KLObserver " "and set 1 < update_interval <= total steps. 
", ) return torch.tensor( [1.0], device=self.min_val.device ), torch.tensor([0], device=self.min_val.device) scale = _compute_scale_symmetric( self.min_val, self.max_val, self.quant_min, self.quant_max, self.eps, self.compute_scale_strategy, ) return scale, None def repr_msgs(self): msges = [] # only print minmax value for per tensor if hasattr(self, "min_val") and self.min_val.numel() == 1: msges.append("min_val={}".format(self.min_val.item())) if hasattr(self, "max_val") and self.max_val.numel() == 1: msges.append("max_val={}".format(self.max_val.item())) return msges def extra_repr(self): return ",".join(self.repr_msgs()) @abstractmethod def forward(self, x): pass with_args = classmethod(_with_args)
4. 核心属性详解
4.1 量化配置属性
```python
# 基础量化参数
self.averaging_constant: float   # 移动平均系数
self.ch_axis: int                # 通道轴 (per_channel量化时使用)
self.dtype: QuantDType           # 量化数据类型 (qint8, qint4等)
self.qscheme: torch.qscheme      # 量化方案 (per_tensor/per_channel)
self.quant_min: int              # 量化最小值
self.quant_max: int              # 量化最大值
self.is_sync_quantize: bool      # 多卡同步统计量
self.compute_scale_strategy      # scale计算策略 (STATISTIC/POT/FP16等)
```
4.2 统计量缓冲区
```python
self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))
self.register_buffer("min_val", torch.tensor([]))
self.register_buffer("max_val", torch.tensor([]))
```
使用 register_buffer 注册的原因:
- 不参与梯度计算:统计量不是模型参数
- 随模型迁移设备:model.cuda() 时自动迁移
- 可保存到 state_dict:校准结果可持久化
5. 核心方法详解
5.1 __init__ - 初始化
参数说明:
| 参数 | 默认值 | 说明 |
|---|---|---|
| averaging_constant | 0.01 | 移动平均系数,值越大当前 batch 权重越高 |
| ch_axis | -1 | 通道轴,负数表示 per_tensor,非负表示 per_channel |
| dtype | qint8 | 量化数据类型 |
| qscheme | per_tensor_symmetric | 量化方案 |
| quant_min/max | None | 自定义量化范围,None 时根据 dtype 自动设置 |
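| is_sync_quantize | None | 多卡训练时是否同步统计量;为 None 时沿用类属性默认值 True |
| factory_kwargs | None | 创建 eps、min_val、max_val 等 buffer 时传给工厂函数的参数 |
| compute_scale_strategy | ComputeScaleStrategy.STATISTIC | scale 计算策略 |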
关键校验逻辑:
```python
# per_channel 必须指定有效的 ch_axis
if qscheme == torch.per_channel_symmetric:
    assert ch_axis >= 0, "ch_axis should be non-negative"
else:
    assert ch_axis < 0, "ch_axis should be negative for per_tensor"

# 仅支持对称量化
assert qscheme in (
    torch.per_tensor_symmetric,
    torch.per_channel_symmetric,
)
```
5.2 forward - 更新统计信息(抽象方法)
设计意图:
- 子类必须实现此方法(由 ABCMeta 强制)
- 在校准阶段,每个 forward pass 收集激活值的统计信息
- 返回原始输入(不修改数据流)
典型实现模式:
```python
def forward(self, x_orig):
    # 1. 计算当前 batch 的统计量
    min_val_cur, max_val_cur = compute_statistics(x_orig)

    # 2. 多卡同步(可选)
    if self.is_sync_quantize:
        self.sync_minmax(min_val_cur, max_val_cur)

    # 3. 更新累计统计量(移动平均)
    self.min_val = update_statistics(self.min_val, min_val_cur)
    self.max_val = update_statistics(self.max_val, max_val_cur)

    return x_orig  # 原样返回,不干扰前向传播
```
5.3 calculate_qparams - 计算量化参数
核心计算逻辑(_compute_scale_symmetric):
```python
def _compute_scale_symmetric(min_val, max_val, quant_min, quant_max, eps, strategy):
    # 对称量化公式:scale = max(-min, max, 0) / ((quant_max - quant_min) / 2),且不小于 eps
    scale = (
        torch.max(-min_val, max_val)
        .clamp_min(0)
        .div(float(quant_max - quant_min) / 2)
        .clamp_min(eps)
    )

    # 可选的 scale 约束策略
    if strategy == ComputeScaleStrategy.KPOT:  # K-POT (可训练POT)
        scale = k_pot_scale(scale)
    elif strategy == ComputeScaleStrategy.POT:  # Power-of-Two
        scale = 2 ** torch.ceil(torch.log2(scale))
    elif strategy == ComputeScaleStrategy.FP16:  # FP16 精度
        scale = _get_fp16_scale(scale)

    return scale
```
5.4 sync_minmax - 多卡同步
```python
def sync_minmax(self, min_val, max_val):
    if dist.is_initialized() and min_val.is_cuda:
        dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
        dist.all_reduce(max_val, op=dist.ReduceOp.MAX)
```
原理:
- 使用 all_reduce 聚合多卡的统计量
- MIN 操作取所有卡的最小值
- MAX 操作取所有卡的最大值
- 确保多卡训练时校准结果一致
5.5 _load_from_state_dict - 状态加载
关键功能:
- 版本兼容(处理旧版名称 min_vals → min_val)
- 动态调整 buffer 大小
- 支持从校准模型加载参数到 QAT 模型
6. 类继承体系
```
ObserverBase (抽象基类)
│
├── MinMaxObserver        # 移动平均 min/max 统计
│   │
│   └── ClipObserver      # 带截断的 min/max 统计
│
├── FixedScaleObserver    # 固定 scale(不统计)
│
├── PercentileObserver    # 百分位统计
│
├── MSEObserver           # 最小化 MSE 搜索最优 scale
│
├── KLObserver            # KL 散度校准
│
├── MixObserver           # 混合多种方法
│
└── HistogramObserver     # 直方图统计(支持多种度量)
```
7. 设计亮点
- 统一接口:所有 Observer 遵循相同的 API,便于替换和扩展
- 抽象基类约束:通过 ABCMeta 强制子类实现 forward 方法
- 状态持久化:统计量作为 buffer 保存,支持校准结果复用
- 分布式支持:内置多卡同步机制
- 版本兼容:_load_from_state_dict 处理历史版本兼容
- 灵活配置:支持多种量化方案、数据类型、scale 策略
8. 与 PyTorch 原生 Observer 的对比
| 特性 | PyTorch ObserverBase | Horizon ObserverBase |
|---|---|---|
| 量化方案 | 支持非对称量化 | 仅支持对称量化 |
| scale 约束 | 无 | POT/FP16/KPOT 策略 |
| 分布式同步 | 需自行实现 | 内置 sync_minmax |
| 数据类型 | 标准 torch.dtype | 扩展 QuantDType (qint4 等) |
| 版本管理 | 同样基于 nn.Module 的 _version / _load_from_state_dict 机制 | _version = 3,并兼容旧 buffer 名(min_vals/max_vals)及缺失的 eps |
