Llama 3+与Mamba强强联合!通过蒸馏到线性RNN,推理速度提升1.6倍。
编辑日期:2024年09月11日
Mamba 的提出者 Tri Dao 参与其中
将 Llama 3 蒸馏至 Mamba,推理速度最高可提升 1.6 倍!
不仅如此,性能不降反升,甚至优于原始模型。
这是来自 Together AI 的最新成果,通过蒸馏技术将 Transformer 和 Mamba 模型结合起来,并设计了专门的推理加速算法。
Mamba 架构的提出者、FlashAttention 的作者 Tri Dao 也参与了该项目。
Together AI 的创始人兼 CEO 表示,Transformer 与 Mamba 的结合代表了未来大规模模型的发展方向。
在正式蒸馏之前,需要先进行从 Transformer 到线性 RNN 的初始化。
作者发现 Transformer 的注意力机制与 RNN 的计算之间存在一定的相似性。
因此,可以通过线性化 Transformer 的注意力机制来建立两者的联系。
利用这种对应关系,可以将预训练的 Transformer 模型参数复制到 Mamba 模型中。
在完成参数初始化后,作者采用了三阶段的蒸馏流程以进一步提升Mamba模型的性能,帮助其更好地学习Transformer的知识。第一阶段是基于伪标签的蒸馏:利用预训练的Transformer教师模型为无标签数据生成伪标签,再用这些伪标签训练Mamba学生模型。
此过程中的损失函数综合了KL散度损失和交叉熵损失,前者用于模仿教师模型的输出分布,后者则用于拟合伪标签。
第二阶段是对指令数据集进行监督微调,即利用带有标签的指令数据集(例如OpenHermes 2.5)进行训练。
第三阶段则是利用人类反馈数据进行优化。作者收集了人们对模型输出的反馈,并据此构建了一个奖励模型,进而使用强化学习算法(如PPO)来优化模型在该奖励模型下的表现。
整个蒸馏过程在一个混合模型中进行,在配备8块80GB A100 GPU的环境下,耗时不超过五天。
通过上述蒸馏流程,作者最终获得了Transformer-Mamba混合模型,并提出了一种名为“推测解码”的算法以加速推理过程。推测解码的核心思想是使用一个轻量级的Draft模型来预测多个token,再由验证模型(Verifier)来确认这些预测结果。
这种方法显著提高了解码过程的并行性,加快了生成速度。Draft模型通常是一个较小的Transformer,能够根据当前上下文预测出接下来的K个token。对于这K个token,Transformer层可以直接并行处理,计算其隐状态;而Mamba层则需按顺序逐个处理每个token,先计算当前token的隐状态,并与之前的隐状态进行对比。
如果序列中的所有K个token都被接受,则将它们添加到输出序列中,并继续预测下一组token。如果有token被拒绝,则从第一个被拒绝的token处截断预测序列,并返回初始步骤从该位置开始重新预测。
测试结果显示,在单轮(AlpacaEval)和多轮(MT-Bench)聊天对话任务上,混合模型的表现与Llama-3持平甚至更优。
此外,还测试了不同混合比例下的模型表现,发现1:1的比例时模型表现最佳。
在零样本的通用NLP任务评测中,混合模型的平均成绩优于同等规模的RNN模型。
在少样本的OpenLLM Leaderboard榜单上,混合模型的表现与最优的开源RNN模型相当,并在GSM8K和CRUX任务上超过了对应的Instruct模型。
除了评估模型性能外,作者还测试了推测解码算法带来的加速效果。首先是在纯Mamba模型上的测试,结果表明在2.8B和7B模型上,新解码方式使推理速度提升了1.7至2.6倍。
进一步地,在蒸馏的Zephyr和Llama混合模型上进行测试,结果表明Zephyr混合模型的推理速度提升了1.8倍以上,而Llama混合模型也有大约1.6倍的加速效果。
Llama 3+与Mamba强强联合!
论文地址:https://www.together.ai/blog/the-mamba-in-the-llama-distilling-and-accelerating-hybrid-models
不仅如此,这不仅仅是大模型公司的合作。
量力而行,“小心AGI觉醒后要求欺诈补偿”。
参数越大,模型越“聪明”。
GPT-4等九大LLM“无一幸免”。
系列模型中的三个将被开源。
吞吐量相比FP16提升了2.65倍。