最近有很多朋友都在部署deepseek模型,而且都用到了模型量化这个功能,目的是减少显存占用、提升推理速度。
量化算法流程
上图是w8a8量化算法流程,主要包含4步:
①,使用昇腾 msmodelslim 仓库提供的量化接口对原始模型权重进行量化,生成int8格式的权重文件,以及后续在推理的时候要用到的激活值的量化参数和 matmul 结果的反量化参数;
②,推理执行过程中,把Matmul的激活值(也就是输入X)进行int8量化;
③,执行int8格式的Matmul计算;
④,把int8的乘法结果进行反量化。
这篇文章讲解第①步的内容。msmodelslim提供的deepseek模型量化的参考脚本的链接如下:
Ascend/msitgitee.com/ascend/msit/tree/br_noncom_MindStudio_8.0.0_POC_20251231/msmodelslim/example/DeepSeek编辑
入口脚本 quant_deepseek_w8a8.py 的代码内容如下(
br_noncom_MindStudio_8.0.0_POC_20251231分支,commit 06a6e8):
#Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import argparse
import functools
import json
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from msmodelslim.tools.convert_fp8_to_bf16 import auto_convert_model_fp8_to_bf16, OpsType
from msmodelslim.tools.copy_config_files import copy_config_files, modify_config_json
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim.tools.logger import set_logger_level
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help="model and tokenizer path"),
parser.add_argument('--save_path', type=str, help="save path"),
parser.add_argument('--layer_count', type=int, default=0)
parser.add_argument('--anti_dataset', type=str, default="./anti_prompt.json")
parser.add_argument('--calib_dataset', type=str, default="./calib_prompt.json")
parser.add_argument('--fp8', action='store_true')
parser.add_argument('--bf16', action='store_true')
return parser.parse_args()
def custom_hook(model_config):
model_config["mla_quantize"] = "w8a8"
args = parse_args()
set_logger_level("warning")
pbar = tqdm(total=4, position=0, desc="Total Process")
model_path = args.model_path
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_path, trust_remote_code=True)
config.num_hidden_layers = args.layer_count if args.layer_count != 0 else config.num_hidden_layers
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_path,
config=config,
trust_remote_code=True,
use_fast=True,
add_eos_token=True)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_path,
config=config,
trust_remote_code=True,
device_map="auto",
torch_dtype="auto",
max_memory={
0: "50GiB",
"cpu": "1500GiB"
},
attn_implementation='eager')
auto_convert_model_fp8_to_bf16(model, model_path, OpsType.get_ops_type(args.bf16, args.fp8))
pbar.update(1)
def get_anti_dataset(tokenizer, calib_list, device="npu"):
calib_dataset = []
max_len = 0
for calib_data in calib_list:
inputs = tokenizer(calib_data, return_tensors='pt')
calib_dataset.append(inputs.data['input_ids'].to(device))
max_len = max(max_len, inputs.data['input_ids'].size(1))
for i in range(len(calib_dataset)):
calib_dataset[i] = F.pad(calib_dataset[i], (0, max_len - calib_dataset[i].size(1)), value=0)
return torch.cat(calib_dataset)
def get_calib_dataset(tokenizer, calib_list, device="npu"):
calib_dataset = []
for calib_data in calib_list:
inputs = tokenizer(calib_data, return_tensors='pt').to(device)
calib_dataset.append([inputs.data['input_ids']])
return calib_dataset
with open(args.anti_dataset, "r") as file:
anti_prompt = json.load(file)
with open(args.calib_dataset, "r") as file:
calib_prompt = json.load(file)
anti_data = []
for i in range(len(anti_prompt)):
tmp = get_anti_dataset(tokenizer, anti_prompt[i])
anti_data.append(tmp)
anti_dataset = []
for data in anti_data:
anti_dataset.append([data])
dataset_calib = []
for i in range(len(calib_prompt)):
tmp = get_calib_dataset(tokenizer,calib_prompt[i])
dataset_calib += (tmp)
with torch.no_grad():
anti_config = AntiOutlierConfig(w_bit=8,
a_bit=8,
anti_method='m4',
dev_type='npu',
dev_id=model.device.index)
anti_outlier = AntiOutlier(model, calib_data=anti_dataset, cfg=anti_config)
anti_outlier.process()
pbar.update(1)
disable_names = []
for ids in range(config.num_hidden_layers):
disable_names.append("model.layers." + str(ids) + ".self_attn.kv_b_proj")
quant_config = QuantConfig(
a_bit=8,
w_bit=8,
disable_names=disable_names,
dev_type='npu',
dev_id=model.device.index,
act_method=1,
pr=1.0,
w_sym=True,
mm_tensor=False,
is_dynamic=True
)
calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level="L0")
calibrator.run()
pbar.update(1)
calibrator.save(args.save_path, save_type=["safe_tensor"], part_file_size=4)
custom_hooks = {
'config.json': functools.partial(modify_config_json, custom_hook=custom_hook)
}
copy_config_files(input_path=args.model_path, output_path=args.save_path, quant_config=quant_config, custom_hooks=custom_hooks)
pbar.update(1)
这篇文章会从入口脚本出发,对w8a8量化的技术原理和代码进行解析。
1. 算法原理
上面的代码涉及到2个类:AntiOutlier 和 Calibrator,AntiOutlier 代表的是激活异常值抑制,Calibrator 是激活值和权重量化。
1.1 激活异常值抑制
对于int8量化算法,浮点数量化后的取值是有限的(-128、-127、...、127),所以浮点数的分布范围越广的话,量化步长就越大,那么就有更多的浮点数会被量化成同一个数值,也就会引入更大的误差。而且大家发现,对于大模型里面的Matmul,激活值X和权重W的浮点数分布很不相同,X的分布范围更大,这样就会导致激活值X的量化误差较大。
为了解决激活值的量化误差问题,有人提出了 smoothQuant 算法。这个算法的原理很简单:让X除以一个值s,W乘以s,这样的话,X/s和W/s的分布就会更加平滑,同时(X/s)*(W*s)=X*W,保证乘积不变。s的计算公式如下:
其中 Xj 代表 X 的第 j 列, Wj 代表 W 的第 j 行, α 一般取0.5。示例如下:
在代码实现中,因为X是norm层的输出,所以会把X除以s的操作转移到norm层,让norm层的权重除以s,这样就不用在推理的过程中,再做一个除法。
1.2 w8a8量化
这部分没什么好说的,quant_deepseek_w8a8.py 中用到的就是 min-max 量化算法。对于一个权重tensor或者激活X的tensor来说,假如它的最大、最小值分别为max、min,那么首先可以求出 scale=(max-min)/255,然后得出量化公式为
x是tensor的每个元素。当然,除了以tensor粒度求解scale,还有以tensor的每个通道值分布求scale的,我们称作per-channel。
需要注意的是,权重在推理之前就是已知的,所以不需要做模型推理、直接对权重文件的数据进行量化即可;但是激活值X是在推理的时候才能获取的,所以我们还需要准备一些“校准数据集”,让模型做一些前向推理,以此确定激活值的量化参数。
算法理论部分到这里就结束了,比较简单,接下来我们看看代码层面是如何实现的。
2. 代码解析
量化入口脚本功能
上图是deepseek w8a8量化入口脚本,主要包含3个部分:异常值抑制、w8a8量化、保存量化权重和相关参数。
2.1 异常值抑制
anti_config包含的参数如下:
anti_config = AntiOutlierConfig(w_bit=8,
a_bit=8,
anti_method='m4',
dev_type='npu',
dev_id=model.device.index)
其中w_bit和a_bit代表权重量化位数和激活值量化位数,anti_method代表抑制算法,'m4'是 smooth_quant(m1) 的改进方法,相比于 smooth_quant 增加了量化层。dev_type和dev_id代表运行异常值抑制使用的设备。
anti_outlier的核心代码在 msmodelslim\msmodelslim\pytorch\llm_ptq\anti_outlier\anti_outlier.py。init()函数和process()函数的核心逻辑流程图如下:
步骤1 是在AntiOutlier类的init()函数中完成的,后续步骤是在process()函数中完成的。
初始化有向无环图是在init()函数的这个部分执行的:
try:
self.init_dag()
except Exception as e:
raise Exception("Please check your config, model and input!", e) from e
对于attention模型,构建DAG图的过程就是找出RMSNorm算子和它们连接的linear层。对于抑制算法"m1",我们做w8a8量化的目标层是qkv乘法和up、gate的全连接层;对于抑制算法"m4",还包含了O层和down层(O层的激活值scale转移到V层,down层的激活值scale转移到up层)。
init_dag()函数核心代码如下:
if self.norm_class_name is not None: # 可以手动指定norm层
norm_class = list(OrderedDict.fromkeys([m.__class__ for m in self.model.modules() if
self.norm_class_name.lower() == m.__class__.__name__.lower()]))
else:
# 查找包含“norm”字段的层
norm_class = list(
OrderedDict.fromkeys(
[m.__class__ for m in self.model.modules() if "norm" in m.__class__.__name__.lower()]))
norm_class = [norm_class[0]]
self.norm_class_name = norm_class[0].__name__.lower()
if ProcessHook.GET_NORM_LINEAR_SUBGRAPH not in self.hooks or self.hooks[
ProcessHook.GET_NORM_LINEAR_SUBGRAPH] is None:
# 调用extract_dag()获取DAG图
dag = extract_dag(self.model, dummy_input,
hook_nodes=norm_class, anti_method=self.cfg.anti_method)
self.norm_linear_subgraph = dag.get_norm_linear_subgraph()
if self.cfg.anti_method == 'm4':
self.linear_linear_subgraph = dag.get_linear_linear_subgraph()
self.norm_linear_subgraph.update(self.linear_linear_subgraph)
del dag
上面的代码中主要调用了 extract_dag() 函数获取DAG图,然后得到norm_linear_subgraph。extract_dag调用的又是TorchDAGAdapter类,这篇文章不做详细分析。norm_linear_subgraph 的格式如下所示:
norm_linear_subgraph{'model.layers0.input_layernorm': ['model.layers0.attn.q_proj', 'model.layers0.attn.k_proj', 'model.layers0.attn.j_proj'], 'model.layers0.post_attention_layernorm': ['model.layers0.mlp.gate_proj', 'model.layers0.mlp.up_proj'], ...}
anti_outlier的process的核心代码如下:
def _process(self):
...
# 给模型层注册hook,执行推理,记录每层的输入输出
act_stats = self.os_stats()
....
# 遍历需要做量化的层
for norm_name_group in tqdm(iterable=self.norm_linear_subgraph.keys(), desc="AntiOutlier Process", position=1):
linear_names = self.norm_linear_subgraph[norm_name_group]
if isinstance(norm_name_group, str):
norm_module = PatternProcess.get_module_by_name(self.model, norm_name_group)
...
stats = act_stats[linear_name]
is_expert = any("expert" in name.lower() for name in linear_names)
if (is_expert):
continue
self.logger.debug(f"smooth {norm_name_group} -> {linear_names}")
for name in linear_names:
mod = PatternProcess.get_module_by_name(self.model, name)
linear_modules.append(mod)
...
# 对权重进行smooth
if Multiplier is not None and norm_module is None:
norm_module = Multiplier(
torch.ones_like(stats[STAT_KEY_SMOOTH_SCALE]).to(linear_modules[0].weight.device)
)
prepare_list = [PrepareWeight(norm_module, post_force=True, post_recurse=True)]
prepare_list += [PrepareWeight(mod, post_force=True) for mod in linear_modules]
# 对norm层权重进行smooth
with ResListToRelease(*prepare_list):
if self.cfg.anti_method == 'm1' or self.cfg.anti_method == 'm5':
smooth_ln_fcs(self.cfg, norm_module, linear_modules, stats, alpha=self.cfg.alpha)
elif self.cfg.anti_method == 'm2':
os_ln_fcs(self.cfg, norm_module, linear_modules, stats, os_k=self.cfg.os_k)
elif self.cfg.anti_method == 'm3':
weight_aware(self.cfg, norm_module, linear_modules, stats)
elif self.cfg.anti_method == 'm4':
if 'scale_min' in inspect.signature(iter_smooth).parameters:
fusion_kwargs.update({"scale_min": scale_min})
if 'check_group_fusions' not in inspect.signature(iter_smooth).parameters:
fusion_kwargs.pop("check_group_fusions", None)
self.logger.debug(f"fusion_kwargs is {fusion_kwargs}")
iter_smooth(
self.cfg, norm_module, linear_modules, stats, num_attention_heads, **fusion_kwargs
)
if attach_op is not None and Multiplier is not None and isinstance(norm_module, Multiplier):
attach_op(self.model, norm_module, linear_modules, linear_names)
上面的代码,首先执行self.os_stats(),这个函数的功能是在模型层上注册hook,然后使用校准数据进行推理,收集每层的输入输出。
然后遍历 norm_linear_subgraph 的值,把norm层和对应的linear层找出来,先把权重乘以s,然后使用CANN包路径
/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/msmodelslim 下的so里面的方法做激活层的smooth,也就是对norm层的权重进行处理。
以上就是异常值处理的主要逻辑,完成异常值处理后,model里面的norm层和linear层的权重已经发生了变化,model会继续传给后续的calibrator处理。
2.2 w8a8量化代码
首先需要设置量化方法的参数:
quant_config = QuantConfig(
a_bit=8,
w_bit=8,
disable_names=disable_names,
dev_type='npu',
dev_id=model.device.index,
act_method=1,
pr=1.0,
w_sym=True,
mm_tensor=False,
is_dynamic=True
)
a_bit和w_bit代表量化bit数;disable_names代表不做量化层的名称;act_method代表激活值的量化方法,“1”代表min-max;pr是概率参数,非1时量化生成的参数带有随机性;w_sym是指权重是否做对称量化;mm_tensor=False代表权重是per-channel量化;is_dynamic=True代表激活量化使用动态量化,也就是量化参数是在推理的时候生成,is_dynamic=False代表静态量化,在调用calibrator量化权重的时候就把激活值的量化参数计算好,动态量化精度更高,但是性能更差。
再来看一下实例化calibrator:
calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level="L0")
传入了异常值抑制后的model、quant_config、校准数据集和disable_level。disable_level='“Ln”代表模型结构从最后一层往前数的n层不做量化。校准数据集一般采用模型实际应用场景的数据,而且在调试精度的时候,如果发现量化模型在某条数据上精度较差,可以把该条数据加入校准数据集,再进行校准量化。
再来看一下init()函数做了哪些事情:
def __init__(self, model,
cfg: QuantConfig,
calib_data=None,
disable_level='L0',
all_tensors=None):
...
# 获取校准数据集
self.calib_data = self.get_calib_data([]) if calib_data is None else self.get_calib_data(calib_data)
self.use_kvcache_quant = cfg.use_kvcache_quant # false
self.norm_class_name = cfg.norm_class_name
...
# 创建字典记录量化参数
self.quant_param_dict = AutoSaveDict(self.cfg, max_gb_size=1)
# 记录被量化module名称,相关的scale、offset等参数名称 key:weight的名称, value:scale、offset等参数的名称
self.quantized_module_param_dict = defaultdict(list)
self.fa_module_param_dict = defaultdict(list)
...
# 初始化模型权重json描述
self.quant_model_json_description = QuantModelJsonDescription(self.cfg.model_quant_type,
self.cfg.use_kvcache_quant,
self.cfg.use_fa_quant)
if not re.match(r'^L((?!0)\d+|0)