7.7亿参数,超越5400亿PaLM!UW谷歌提出「分步蒸馏」,只需80%训练数据|ACL 2023
来源:51CTO.COM
时间:2023-10-11 10:00:17 180浏览 收藏
在IT行业这个发展更新速度很快的行业,只有不停止的学习,才不会被行业所淘汰。如果你是科技周边学习者,那么本文《7.7亿参数,超越5400亿PaLM!UW谷歌提出「分步蒸馏」,只需80%训练数据|ACL 2023》就很适合你!本篇内容主要包括##content_title##,希望对大家的知识积累有所帮助,助力实战开发!
大型语言模型在性能方面表现出色,能够通过零样本或少样本提示来解决新任务。然而,在实际应用部署中,LLM却不太实用,因为它的内存利用效率低,同时需要大量的计算资源
比如运行一个1750亿参数的语言模型服务至少需要350GB的显存,而目前最先进的语言模型大多已超过5000亿参数量,很多研究团队都没有足够的资源来运行,在现实应用中也无法满足低延迟性能。
也有一些研究使用人工标注数据或使用LLM生成的标签进行蒸馏来训练较小的、任务专用的模型,不过微调和蒸馏需要大量的训练数据才能实现与LLM相当的性能。
为了解决大型模型对资源的需求问题,华盛顿大学与谷歌合作提出了一种名为「分步蒸馏」(Distilling Step-by-Step)的新蒸馏机制。通过分步蒸馏,经过蒸馏后的模型尺寸相较于原模型来说更小,但性能却更优,而且在微调和蒸馏过程中所需的训练数据也更少
请点击以下链接查看论文:https://arxiv.org/abs/2305.02301
分布蒸馏机制把LLM中抽取出的预测理由(rationale)作为在多任务框架内训练小模型的额外监督信息。
经过在4个NLP基准上进行实验后,我们发现:
1. 与微调和蒸馏相比,该机制用更少的训练样本实现了更好的性能;
相较于少样本提示LLM,该机制利用更小尺寸的模型实现了更出色的性能
3. 同时降低模型尺寸和数据量也可以实现优于LLM的性能。
实验中,微调后770M的T5模型在基准测试中仅使用80%的可用数据就优于少样本提示的540B的PaLM模型,而标准微调相同的T5模型即使使用100%的数据集也难以匹配。
蒸馏方法
分布蒸馏的关键思想是逐步抽取出信息丰富且用自然语言描述的预测理由,即中间推理步骤,以解释输入问题与模型输出之间的联系,并通过这些数据来更高效地训练小模型
分布蒸馏主要包括两个阶段:
1. 从LLM中提取原理(rationale)
研究人员利用少样本思维链(CoT)提示从LLM中提取预测中间步骤。
在确定目标任务之后,首先在LLM输入提示中准备几个样例。每个样例都由一个三元组组成,包括输入、原理和输出
输入提示后,LLM能够模仿三元组演示以生成其他新问题的预测原理,例如,在常识问答案任务中,给定输入问题:
Sammy想去人群聚集的地方。他会选择哪里呢?选项有:(a)人口稠密地区,(b)赛道,(c)沙漠,(d)公寓,(e)路障
(Sammy wanted to go to where the people are. Where might he go? Answer Choices: (a) populated areas, (b) race track, (c) desert, (d) apartment, (e) roadblock)
通过逐步提炼后,LLM可以给出问题的正确答案「(a)人口稠密地区」,并且提供回答问题的理由「答案必须是一个有很多人的地方,在上述选择中,只有人口稠密的地区有很多人。」 经过逐步提炼,LLM能够得出正确答案为「(a)人口稠密地区」,并提供了解答问题的理由「答案必须是一个有很多人的地方,在上述选择中,只有人口稠密的地区有很多人。」
通过在提示中提供与基本原理配对的CoT示例,上下文学习能力可以让LLM为未曾遇到的问题类型生成相应的回答理由
2. 训练小模型
通过将训练过程构建为多任务问题,可以将预测理由抽取出来,并将其纳入训练小模型中
除了标准标签预测任务之外,研究人员还使用新的理由生成任务来训练小模型,使得模型能够学习生成用于预测的中间推理步骤,并且引导模型更好地预测结果标签。
通过在输入提示中加入任务前缀「label」和「rationale」来区分标签预测和理由生成任务。
实验结果
在实验中,研究人员选择5400亿参数量的PaLM模型作为LLM基线,使用T5模型作为任务相关的下游小模型。
在这项研究中,我们对四个基准数据集进行了实验,这四个数据集分别是e-SNLI和ANLI用于自然语言推理,CQA用于常识问答,以及SVAMP用于算术数学应用题。我们在这三个不同的NLP任务中进行了实验
更少的训练数据
分步蒸馏方法在性能上比标准微调更出色,而且只需较少的训练数据
在e-SNLI数据集上,当使用完整数据集的12.5%时就实现了比标准微调更好的性能,在ANLI、CQA和SVAMP上分别只需要75%、25%和20%的训练数据。
与使用220M T5模型对不同大小的人工标记数据集进行标准微调相比,分布蒸馏在所有数据集上使用更少的训练示例时,优于在完整数据集上进行标准微调
更小的部署模型尺寸
与少样本CoT提示的LLM相比,分布蒸馏得到的模型尺寸要小得多,但性能却更好。
在e-SNLI数据集上,使用220M的T5模型实现了比540B的PaLM更好的性能;在ANLI上,使用770M的T5模型实现了比540B的PaLM更好的性能,模型尺寸仅为1/700
更小的模型、更少的数据
在减小模型尺寸和训练数据的同时,我们成功地实现了超越少样本PaLM的性能
在ANLI中,使用770M T5模型的性能超过了540B PaLM,而且只使用了完整数据集的80%
经观察可知,即使使用完整的100%数据集,标准微调也无法达到PaLM的性能水平,这表明通过分步蒸馏可以同时减小模型尺寸和训练数据量,从而实现超越LLM的性能
好了,本文到此结束,带大家了解了《7.7亿参数,超越5400亿PaLM!UW谷歌提出「分步蒸馏」,只需80%训练数据|ACL 2023》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多科技周边知识!
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
501 收藏
-
234 收藏
-
465 收藏
-
100 收藏
-
307 收藏
-
280 收藏
-
121 收藏
-
194 收藏
-
417 收藏
-
430 收藏
-
315 收藏
-
319 收藏
-
170 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 立即学习 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 立即学习 508次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 立即学习 497次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 立即学习 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 立即学习 484次学习