蒸馏技术的原理很简单:一个大的"教师"模型回答问题;一个较小的"学生"模型从这些答案以及其背后的概率中学习。你以很小的成本获得接近大模型的行为。在本教程中,我将通过Python和Transformers库带你了解它在实践中是如何运作的。
蒸馏是一种将大模型压缩成更小、更快模型的方法。一个大的"教师"产生输出和软标签;一个较小的"学生"被训练来匹配它们。在实践中,你保留了教师的大部分准确性,同时缩减了规模和成本。教师的概率分布(例如,"正面"为 0.88,"中性"为 0.09,"负面"为 0.03)比单一的硬标签携带更多信息。其他类别上的小概率质量帮助学生学习边界和不确定性,而不仅仅是最高预测。
1、蒸馏在流程中的位置?
基础大语言模型在大量文本上训练,能够生成流畅甚至有创意的输出。但开箱即用,它并不是分类或回归等任务的现成解决方案。微调将基础模型适配到你的任务上,通常只需要几百或几千个示例。标准微调在每一步更新所有参数,代价很高;参数高效微调(如 LoRA、适配器)只调整参数的一个子集,成本更低。无论哪种方式,微调后的模型通常与基础模型具有相同的参数数量。在实践中,许多参数对单一应用来说是无关的。
蒸馏构建了一个保留关键信息的更小模型:它运行更快、成本更低、使用更少资源,代价是少量准确率的折衷。
流程:基础 LLM → 微调(可选)→ 蒸馏 → 更小的可部署模型。
2、为什么要使用蒸馏?
大模型托管成本高、通过 API 调用费用昂贵、运行速度慢、对基础设施要求高。蒸馏解决了这三个问题:学生模型运行成本更低、推理速度更快、足够小以部署在移动端或边缘设备上。你牺牲一点准确率换取规模和速度。当成本、延迟或部署空间成为瓶颈时,这个权衡是值得的。
以下是我们将涵盖的内容:
- 为什么蒸馏?大模型的成本和延迟问题,以及蒸馏何时值得投入。
- 教师-学生蒸馏如何工作。软标签、温度,以及损失函数背后的直觉。
- 逐步实现。使用 Hugging Face Transformers 蒸馏模型的 Python 代码。
- 调优和扩展。温度、损失权重(α)、合成数据以及缓存教师输出。
- 常见陷阱。不平衡数据、教师错误、训练不稳定性以及如何避免它们。
到最后,你将理解整个流程,并拥有可以适配你自己教师和学生模型的可运行代码。让我们开始吧。
3、你需要准备什么
- Python 3.8+并安装 PyTorch 和 Hugging Face Transformers:
pip install torch transformers - GPU对于本演示是可选的(CPU 可以运行这个小型示例);对于使用大教师模型的真实蒸馏,你需要一个。
- 模型:我们将从 Hub 拉取
bert-base-uncased和distilbert-base-uncased(无需本地权重)。
推荐版本(已测试):torch>=2.0、transformers>=4.30。使用pip show torch transformers检查。
4、理解 LLM 蒸馏
在写代码之前,先清楚地理解这个概念很有帮助。把它想象成辅导。你有一个庞大、昂贵的 LLM(教师),你想要一个更小的模型(学生)来完成同样的工作。教师根据你的数据给出答案,并提供一个概率分布,而不仅仅是一个标签。学生学习匹配这个分布。这就是蒸馏。
学生几乎可以是任何你能在同一任务上训练的模型:一个小型 transformer(如 DistilBERT)、一个线性分类器(如逻辑回归),或另一个参数更少的基础模型。在最简单的设置中,你从无标签文本开始,通过教师获得标签或软目标,然后在那个合成标注的数据上训练学生。如果你有一些标注的示例,可以将它们与教师的软标签结合起来(我们的代码正是这样做的)。提示工程(如单样本或少样本提示)可以在不改变教师参数的情况下改善教师的输出;那些更好的输出然后输入到蒸馏中。
在许多设置中,"教师"是一个微调过的基础模型:与原始模型参数数量相同,但适配了你的任务。学生是一个较小的架构(如 DistilBERT 对比 BERT),被训练来模仿教师的输出。较大的模型通常比小模型预测得更好,因此蒸馏模型通常不如教师好,但运行更快且需要更少资源。这种权衡使蒸馏在移动端、边缘设备或成本敏感环境的部署中非常有用。
4.1 知识蒸馏(软目标)
我们在这里做的是经典意义上的知识蒸馏:学生从教师的概率分布(软目标)中学习,而不仅仅是从教师的最终文本或硬标签中学习。该分布比单一答案携带更多信息。像 DistilBERT 这样的模型就是用这种方式创建的:比 BERT 小大约 40%,同时保留了其大部分语言理解能力(在基准测试上通常被引用为约 97%)。当你只能获取教师的文本输出(例如从不返回 logits 的 API)时,你仍然可以通过将教师的答案作为标签来进行一种形式的蒸馏,但你会失去软目标的好处。
重要的部分:软概率携带额外信息。对于一个情感任务,教师可能输出 0.85 正面、0.12 中性、0.03 负面,而不是一个单一的"正面"标签。0.12 和 0.03 告诉学生"中性"和"负面"在某种程度上也是合理的;这是硬标签会隐藏的结构。这种细微差别帮助学生更好地泛化。
一个简单的心理图像:教师处理每个输入并产生一个输出概率分布。我们使用这些软标签来训练学生。教师保持固定;只有学生的权重会更新。在实践中,软目标中的额外信息帮助学生在相同数据上学到更多。
4.2 顶层流程
4.3 软标签 vs 硬标签
下面是教师给你的输出的一个最小示例。我们加载一个预训练模型并对一个句子进行分词。教师返回 logits;我们用softmax将它们转换为概率:
from transformers import AutoModelForSequenceClassification, AutoTokenizer import torch import torch.nn.functional as F # Teacher stays in eval mode; we only use it to generate labels teacher = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased') tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') inputs = tokenizer("Is this product good?", return_tensors="pt") with torch.no_grad(): teacher_logits = teacher(**inputs).logits # These probabilities are the "soft labels" the student will learn to match teacher_probs = F.softmax(teacher_logits, dim=1) print(teacher_probs) # e.g. [0.15, 0.85] for negative / positive对于一个情感任务,你可能会看到类似[0.15, 0.85]的结果:教师主要确信是正面的,但在"负面"上留了一点概率质量。在蒸馏中,学生试图匹配这些数字,而不仅仅是argmax。这种细微差别就是软标签有用的原因。
在实际部署中,你通常运行一次大教师模型(或对一部分数据运行)来生成训练数据,然后训练学生模型并部署它。推理时只使用小模型,因此成本和延迟大幅下降。当你不需要实时答案时,可以使用离线推理(批量或静态推理):批量运行教师(例如每周或每月)来产生标签或缓存预测,然后用学生模型为实时流量服务。繁重的工作只做一次;小模型在关键时刻处理请求。例如,搜索引擎可能会离线使用 LLM 构建一个包含多种语言的大型同义词或类别缓存集,然后使用蒸馏模型或缓存来服务实时查询。
既然基础概念已经清楚,让我们实现完整的训练循环。
5、训练学生模型
这是我们把概念转化为代码的地方。
我们将做三件事:获取教师输出(软标签)、定义学生模型,以及使用蒸馏损失和标准监督损失的混合来训练学生。
下面代码的指导原则:
- 设置:设置设备、随机种子和配置(温度 T、损失权重 α、训练轮数、学习率)。
- 数据:使用一个
Dataset来分词文本并返回input_ids、attention_mask和labels。对于生产环境,将内存中的列表替换为文件或从 Hugging Facedatasets加载。 - 模型:加载具有相同
num_labels的教师和学生(例如二分类情感分析为 2)。冻结教师;只训练学生。 - 训练循环:对于每个批次,获取教师 logits(无梯度)、学生 logits、计算 KL + CE 损失、反向传播,然后裁剪梯度并更新。如果输入重复,可选择缓存教师 logits。
- 评估:训练后,在几个示例上运行两个模型并比较预测;如果你有验证集,计算准确率。
- 保存:使用
student.save_pretrained(path)保存学生和分词器,以便你可以加载用于推理。
我们将使用Hugging Face Transformers。任何兼容的教师和学生都可以工作;这里我们使用bert-base-uncased作为教师、distilbert-base-uncased作为学生,这样你无需大型 GPU 就能运行。对于你自己的设置,替换为你微调过的 BERT、GPT 或其他编码器。选择一个在目标任务上已经表现良好的教师:学生的质量受教师质量的限制,因此强大的教师和有代表性的数据比花哨的损失技巧更重要。
数据选项。你可以在以下数据上训练学生:(1) 标注数据加教师软标签(我们的脚本),(2) 教师标注的无标签数据(合成标签),或 (3) 混合使用。数据集应能代表模型在生产中会看到的任务。数据增强有帮助:让教师生成更多示例(如改写或从其分布中采样),这样学生能看到更广泛的输入。
首先,我们需要数据。在这个演示中,我们将使用一个极小的内存数据集。在生产中,你会将真实输入通过教师运行并存储 logits(或概率),并可选地保留真实标签。损失是一个加权组合:
- 蒸馏损失:学生和教师软输出之间的 KL 散度(使用温度 T)。
- 分类损失:真实标签上的交叉熵(当你有标签时)。
所以:loss = α * distillation_loss + (1 - α) * classification_loss。2015 年 Hinton 等人的蒸馏论文使用了高达 20 的温度来软化教师的分布并暴露更多结构;我们将使用 T=2 作为起点。
5.1 训练流程和损失
import torch import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer # Config: tune T, alpha, lr, max_grad_norm for your task CONFIG = { "teacher_name": "bert-base-uncased", "student_name": "distilbert-base-uncased", "num_labels": 2, "max_length": 64, "batch_size": 8, "epochs": 3, "lr": 2e-5, "T": 2.0, "alpha": 0.5, "max_grad_norm": 1.0, } class TextDataset(Dataset): def __init__(self, texts, labels, tokenizer, max_length=64): self.texts = texts self.labels = labels self.tokenizer = tokenizer self.max_length = max_length def __len__(self): return len(self.texts) def __getitem__(self, idx): enc = self.tokenizer( self.texts[idx], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt' ) item = {k: v.squeeze(0) for k, v in enc.items()} item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long) return item # Load teacher and student (same num_labels; teacher frozen) teacher = AutoModelForSequenceClassification.from_pretrained( CONFIG["teacher_name"], num_labels=CONFIG["num_labels"] ) student = AutoModelForSequenceClassification.from_pretrained( CONFIG["student_name"], num_labels=CONFIG["num_labels"] ) tokenizer = AutoTokenizer.from_pretrained(CONFIG["teacher_name"]) texts = ["I love this product!", "This is terrible.", "Pretty good overall.", "Waste of money.", "Recommend."] labels = [1, 0, 1, 0, 1] # 1 = positive, 0 = negative dataset = TextDataset(texts, labels, tokenizer, CONFIG["max_length"]) dataloader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True) def train_distilled(teacher, student, data_loader, config): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") teacher.to(device).eval() student.to(device).train() optimizer = torch.optim.AdamW(student.parameters(), lr=config["lr"]) T, alpha, max_grad_norm = config["T"], config["alpha"], config["max_grad_norm"] for epoch in range(config["epochs"]): for batch in data_loader: inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"} labels_true = batch["labels"].to(device) with torch.no_grad(): teacher_logits = teacher(**inputs).logits / T student_logits = student(**inputs).logits loss_distill = F.kl_div( F.log_softmax(student_logits / T, dim=1), F.softmax(teacher_logits, dim=1), reduction="batchmean", ) * (T * T) loss_ce = F.cross_entropy(student_logits, labels_true) loss = alpha * loss_distill + (1 - alpha) * loss_ce optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm) optimizer.step() # Run training train_distilled(teacher, student, dataloader, CONFIG) # Evaluate: compare teacher vs student on same data def evaluate(model, loader, device): model.eval() correct, total = 0, 0 with torch.no_grad(): for batch in loader: inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"} labels = batch["labels"].to(device) preds = model(**inputs).logits.argmax(dim=1) correct += (preds == labels).sum().item() total += labels.size(0) return correct / total if total else 0.0 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") teacher.to(device) student.to(device) print(f"Teacher acc: {evaluate(teacher, dataloader, device):.2%}") print(f"Student acc: {evaluate(student, dataloader, device):.2%}") # Save student for deployment student.save_pretrained("./distilled_student") tokenizer.save_pretrained("./distilled_student") print("Student saved to ./distilled_student")5.2 端到端流程
上面的脚本:
- 设置配置,让你可以在一个地方调整T、α、学习率和梯度裁剪;
- 构建一个分词数据集(将
texts/labels替换为你自己的或从datasets加载); - 加载具有相同
num_labels的教师和学生,使 logit 形状匹配; - 使用组合的 KL + CE 损失进行训练并裁剪梯度以保持训练稳定;
- 在相同数据上评估两个模型以进行快速比较;
- 保存学生和分词器,以便你可以使用
AutoModelForSequenceClassification.from_pretrained("./distilled_student")加载进行推理。
从头到尾运行这个代码块;对于完整脚本,将最后一部分包装在if __name__ == "__main__":中,并调用train_distilled→evaluate→save_pretrained。
5.3 快速检查
训练后,在两个模型上运行几个输入并比较logits或预测标签。它们不会完全匹配,但学生的argmax通常应该与教师一致。对于实际项目,在留出的验证集上评估:与教师相比的准确率(或 F1 等)、推理速度和模型大小。如果你在受限硬件上部署,推理期间的资源利用率(CPU、内存、GPU)也很重要。这四个因素(准确率、速度、大小、资源)是你需要平衡的主要权衡。最佳实践是将来自教师的软目标与硬目标(真实标签)结合,这样学生既能模仿教师,又能基于你的任务保持准确。
我们将两个模型都移到设备上,冻结教师,并循环遍历批次。Logits 按1/T缩放以软化分布;我们计算学生和教师 softmax 之间的 KL 散度,并加上真实标签上的交叉熵。T * T因子保持当你改变 T 时梯度规模的一致性。训练后,学生应该在这个任务上的行为与教师相似。
在生产中,你会使用数千个示例(如产品评论、客服工单或你拥有的任何文本)。你也可以使用无标签文本:通过教师运行并使用那些输出作为软标签。(1 - α)项在你有真实标签时让学生基于真实标签保持准确。
最后一件事:我们只更新学生。教师保持冻结,这样它不会漂移,训练保持稳定。让学生保持小巧(DistilBERT、TinyBERT 或你自己的小型变体),这样推理保持快速。接下来我们将添加一些对我有很大帮助的优化。
6、优化和技巧
上面的基本循环可以工作。这些调整使我们的流程更快,学生模型更好。
优化杠杆:
1) 来自教师的合成数据:
如果你缺少标注数据,让教师标注更多。将无标签文本通过教师并存储软(或硬)标签。这给了你额外的训练信号,通常能提高泛化能力。我们为一个客服工单分类器做了这件事,准确率有了明显提升。
2) 温度:
我们使用了 T=2。尝试 5 甚至 10 以获得更软的分布,这样学生能看到更多教师的不确定性:
# Softer targets: more weight on non-argmax classes teacher_logits = teacher(**inputs).logits / 5.0 student_logits = student(**inputs).logits / 5.03) 平衡损失(α):
如果学生对教师过拟合,降低 α(如 0.3)。如果它忽略了教师,提高 α(如 0.7)。在验证集上调优。
4) 缓存教师输出:
如果相同的输入在多个训练轮次中出现(或跨运行),缓存教师 logits,这样你不用重复调用大模型:
cache = {} def get_teacher_logits(teacher, inputs, device): key = tuple(inputs['input_ids'].cpu().flatten().tolist()) if key not in cache: with torch.no_grad(): cache[key] = teacher(**{k: v.to(device) for k, v in inputs.items()}).logits return cache[key]5) 混合精度:
在支持它的 GPU 上,float16 训练(如 PyTorch AMP 或 Hugging Face Accelerate)可以缩短训练时间。对于较大的学生模型或大数据集最有用。
6) 纠正教师的怪癖:
如果学生过于紧密地复制教师,而教师有系统性错误,只在真实标签上微调学生几个轮次(设置 α=0)。这通常可以在不失去蒸馏好处的情况下修复小错误。
7) 进一步探索:
两个进阶想法值得了解。中间层蒸馏训练学生匹配教师的隐藏表示,而不仅仅是最终的 logits;当学生小得多时,这可以减少知识损失。多教师蒸馏让学生从多个教师(如不同的模型或检查点)学习,这样它能看到更广泛的信号。超参数调优至关重要:温度控制教师分布的软化程度(较高的 T 分散概率质量,可以帮助学生学习细微差别),学习率平衡收敛速度和稳定性。在验证集上实验并迭代。
通过这些优化,我们缩短了训练时间,获得了在生产中运行约快 3 倍、每请求成本降低约 70% 的学生模型。蒸馏模型非常适合聊天机器人、情感分析、摘要以及在移动或边缘设备上部署,因为规模和速度很重要。接下来我们看看可能出什么问题以及如何避免。
优化前后对比(概念图):
7、处理常见陷阱
蒸馏很强大,但如果你没准备好,一些陷阱会咬到你。学生不能超越教师:它只是镜像教师学到的东西,所以如果教师在你的任务上表现不佳,学生也会如此。你还需要足够的数据(标注数据或供教师标注的无标签数据)供学生从中学习;蒸馏并没有消除对数据的需求,它改变了你使用数据的方式。
1) 知识损失:
当压缩到更小的学生时,教师的一些细微差别会丢失。需要深度或专业知识的任务可能会出现明显的性能下降。缓解方法:使用中间层蒸馏让学生看到教师的内部表示,添加数据增强让学生看到更多样化的示例,或者运行多轮蒸馏(训练一个学生,然后再从那个学生蒸馏)来逐步改进小模型。
2) 不平衡或噪声数据:
当一个类别占主导时,学生可能会忽略稀有类别。在交叉熵项中使用类别权重,让学生更关注少数标签:
# If class 1 is rare, upweight it (use your batch labels variable) class_weights = torch.tensor([1.0, 5.0]).to(device) loss_ce = F.cross_entropy(student_logits, labels, weight=class_weights)你也可以对少数类过采样或平衡批次。
3) 教师错误:
学生不能超越教师的输出。如果教师确信自己是错的,学生会学到那个错误。保持一个可靠的(1 - α)项让真实标签仍然重要,或者过滤掉低置信度的教师预测。对于关键领域,我们有时会添加基于规则的检查或对样本进行小规模人工审核。记住:在最简单的设置中,学生镜像教师的表现,所以生产级别的准确性通常意味着首先改进教师(或数据)。
4) 训练不稳定:
大的蒸馏损失或高学习率可能导致梯度爆炸或 NaN。裁剪梯度并使用适中的学习率(通常比正常微调更小)。蒸馏通常需要比纯微调更低的学习率,因为 KL 项改变了损失 landscape。在验证集上一起调整温度和学习率;它们强烈影响学生从教师学习的效果。
torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)5) 架构或分词器不匹配:
确保学生的分词器和词表与教师的使用方式一致(如相同的特殊标记、相似的词表大小)。如果教师是生成式的而学生是分类器,你需要自定义损失。始终在留出集上评估。
6) 数据多样性不足:
如果学生只看到非常相似的输入,它不会很好地泛化。使用教师标注多样化的文本,并考虑简单的增强(如改写、同义词)来扩大覆盖范围。在一个项目中,我们用同义词增强用户查询,看到了对改写更好的鲁棒性。
7) 何时使用蒸馏:
当推理成本、延迟或部署空间是瓶颈时使用:高流量 API、移动或边缘设备,或计算资源有限的环境。当你不需要实时答案时使用离线推理:定期运行教师刷新标签或缓存,然后用学生模型或缓存结果服务。蒸馏在教师已经很小、你没有无标签数据且标签很少、或你的任务需要教师的全部能力且不能承受任何准确率损失时不太有用。
8) 应用场景:
蒸馏后的 LLM 用于聊天机器人和虚拟助手、文本摘要、机器翻译、情感分析和问答系统。在工业界,它们出现在医疗保健(如处理临床文本)、金融(欺诈检测、客户支持)和教育(辅导、评分)领域。共同的主线是需要速度和效率,同时不放弃太多准确性。
一旦你处理了这些问题,你就拥有了一个可以在生产中信任的蒸馏流程。让我们总结一下要点。
8、结束语
我们从一个庞大、昂贵的 LLM 开始,使用软标签和蒸馏与监督损失的混合训练了一个较小的学生来模仿它。以下是需要记住的:
- 教师-学生设置。大模型(教师)产生软标签;小模型(学生)被训练来匹配它们。软概率比硬标签携带更多信息,帮助学生更好地泛化。
- 损失和温度。将教师软输出上的 KL 散度与真实标签上的交叉熵结合。温度 T 软化分布;在验证集上调优 T 和 α。
- 影响。蒸馏模型可以运行得更快、更便宜。在我们的案例中,我们看到推理速度约快 3 倍,每请求成本降低约 70%。
- 实际改进。使用教师的合成数据、缓存教师 logits、尝试混合精度,以及可选地用仅真实标签的短期微调来纠正。
- 陷阱。注意类别不平衡、教师错误、训练不稳定以及分词器/架构不匹配。使用类别权重、梯度裁剪和验证来保持安全。学生受教师和你拥有的数据限制;更多参数通常意味着更好的预测,蒸馏用一些预测能力换取规模和速度。通过中间层蒸馏、数据增强或迭代蒸馏来缓解知识损失。
最佳实践:
实验不同的温度、α 和学习率;在验证集上持续评估胜过一次性调优。保持训练在软目标(来自教师)和硬目标(真实标签)之间平衡。保持对蒸馏研究的关注(如生成模型的逐步蒸馏或上下文蒸馏等技术),并复用适合你设置的方法。
蒸馏不是一键解决方案:你必须选择学生、数据和超参数。但当它成功时,感觉很好。我们缩小了模型并保持了接近教师的准确性,同时降低了成本和延迟。帮助 DeepSeek 和 Moonshot 等实验室在成本上竞争的同样想法也可以在你自己的技术栈中工作。
负责任地使用蒸馏(和基础)模型:它们可能反映训练或蒸馏数据中的偏见。公平、透明的使用以及在你自己数据上的评估与准确性和速度同样重要。
原文链接:三步蒸馏大语言模型 - 汇智网