Redian新闻
>
如何从头开始编写LoRA代码,这有一份教程

如何从头开始编写LoRA代码,这有一份教程

公众号新闻

选自 lightning.ai

作者:Sebastian Raschka

机器之心编译

编辑:陈萍

作者表示:在各种有效的 LLM 微调方法中,LoRA 仍然是他的首选。

LoRA(Low-Rank Adaptation)作为一种用于微调 LLM(大语言模型)的流行技术,最初由来自微软的研究人员在论文《 LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS 》中提出。不同于其他技术,LoRA 不是调整神经网络的所有参数,而是专注于更新一小部分低秩矩阵,从而大大减少了训练模型所需的计算量。


由于 LoRA 的微调质量与全模型微调相当,很多人将这种方法称之为微调神器。自发布以来,相信很多人都对这项技术感到好奇,想要从头开始编写代码从而更好的理解该研究。以前苦于没有合适的文档说明,现在,教程来了。


这篇教程的作者是知名机器学习与 AI 研究者 Sebastian Raschka,他表示在各种有效的 LLM 微调方法中,LoRA 仍然是自己的首选。为此,Sebastian 专门写了一篇博客《Code LoRA From Scratch》,从头开始构建 LoRA,在他看来,这是一种很好的学习方法。



简单来说,本文通过从头编写代码的方式来介绍低秩自适应(LoRA),实验中 Sebastian 对 DistilBERT 模型进行了微调,并用于分类任务。


LoRA 与传统微调方法的对比结果显示,使用 LoRA 方法在测试准确率上达到了 92.39%,这与仅微调模型最后几层相比(86.22% 的测试准确率)显示了更好的性能。


Sebastian 是如何实现的,我们接着往下看。


从头开始编写 LoRA


用代码的方式表述一个 LoRA 层是这样的:



其中,in_dim 是想要使用 LoRA 修改的层的输入维度,与此对应的 out_dim 是层的输出维度。代码中还添加了一个超参数即缩放因子 alpha,alpha 值越高意味着对模型行为的调整越大,值越低则相反。此外,本文使用随机分布中的较小值来初始化矩阵 A,并用零初始化矩阵 B。


值得一提的是,LoRA 发挥作用的地方通常是神经网络的线性(前馈)层。举例来说,对于一个简单的 PyTorch 模型或具有两个线性层的模块(例如,这可能是 Transformer 块的前馈模块),其前馈(forward)方法可以表述为:



在使用 LoRA 时,通常会将 LoRA 更新添加到这些线性层的输出中,又得到代码如下:




如果你想通过修改现有 PyTorch 模型来实现 LoRA ,一种简单方法是将每个线性层替换为 LinearWithLoRA 层:



以上这些概念总结如下图所示:


为了应用 LoRA,本文将神经网络中现有的线性层替换为结合了原始线性层和 LoRALayer 的 LinearWithLoRA 层。


如何上手使用 LoRA 进行微调


LoRA 可用于 GPT 或图像生成等模型。为了简单说明,本文采用一个用于文本分类的小型 BERT(DistilBERT) 模型来说明。



由于本文只训练新的 LoRA 权重,因而需要将所有可训练参数的 requires_grad 设置为 False 来冻结所有模型参数:




接下来,使用 print (model) 检查一下模型的结构:



由输出可知,该模型由 6 个 transformer 层组成,其中包含线性层:



此外,该模型有两个线性输出层:



通过定义以下赋值函数和循环,可以选择性地为这些线性层启用 LoRA:



使用 print (model) 再次检查模型,以检查其更新的结构:



正如上面看到的,线性层已成功地被 LinearWithLoRA 层取代。


如果使用上面显示的默认超参数来训练模型,则会在 IMDb 电影评论分类数据集上产生以下性能:


  • 训练准确率:92.15%

  • 验证准确率:89.98%

  • 测试准确率:89.44%


在下一节中,本文将这些 LoRA 微调结果与传统微调结果进行了比较。


与传统微调方法的比较


在上一节中,LoRA 在默认设置下获得了 89.44% 的测试准确率,这与传统的微调方法相比如何?


为了进行比较,本文又进行了一项实验,以训练 DistilBERT 模型为例,但在训练期间仅更新最后 2 层。研究者通过冻结所有模型权重,然后解冻两个线性输出层来实现这一点:



只训练最后两层得到的分类性能如下:


  • 训练准确率:86.68%

  • 验证准确率:87.26%

  • 测试准确率:86.22%


结果显示,LoRA 的表现优于传统微调最后两层的方法,但它使用的参数却少了 4 倍。微调所有层需要更新的参数比 LoRA 设置多 450 倍,但测试准确率只提高了 2%。


优化 LoRA 配置


前面讲到的结果都是 LoRA 在默认设置下进行的,超参数如下:



假如用户想要尝试不同的超参数配置,可以使用如下命令:



不过,最佳超参数配置如下:



在这种配置下,得到结果:


  • 验证准确率:92.96%

  • 测试准确率:92.39%


值得注意的是,即使 LoRA 设置中只有一小部分可训练参数(500k VS 66M),但准确率还是略高于通过完全微调获得的准确率。


原文链接:https://lightning.ai/lightning-ai/studios/code-lora-from-scratch?continueFlag=f5fc72b1f6eeeaf74b648b2aa8aaf8b6





© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:[email protected]

微信扫码关注该文公众号作者

戳这里提交新闻线索和高质量文章给我们。
相关阅读
黄暴高能,这部18禁复仇剧从头爽到尾使用 Kate 编写文档 | Linux 中国Rust 编写的 Zed 编辑器开源:约 27 万行代码、主打“高性能”全面推行AI写代码,阿里云未来20%代码由通义灵码编写;阿尔特曼被取消OpenAI风投部门控制权丨AIGC日报纽约琐事(二)事事难料外企社招丨Dräger德尔格,行业全球领导者,15薪,六险一金,多样福利,偏爱留学生固定收益 | 从编写大纲看PPP特许经营要点——评《政府和社会资本合作项目特许经营方案编写大纲(2024年试行版)》Spring Boot 玩一玩代码混淆,防止反编译代码泄露寒假期间如何与孩子和睦相处?这有一份实用操作指南常客:不必从头开始~哪些航司会籍提供匹配呢?董宇辉卖农产品!罗永浩卖云!这有什么问题?阿里云:以后公司20%代码由通义灵码编写Windows版Bun将于本月发布,Zig编写的JavaScript运行时代码屎山噩梦加速来袭,都是AI生成代码的锅?LLM会写代码≠推理+规划!AAAI主席揭秘:代码数据质量太高|LeCun力赞分析了1.5亿行代码发现:AI编程助手降低代码质量我的健康厨房 - 我是如何控制和管理血糖的一家之煮:当Pecan决定分手时全球代码质量骤降,罪魁祸首竟是AI!1.53亿行代码深度分析报告出炉LLM巫师,代码预训练是魔杖!UIUC华人团队揭秘代码数据三大好处AI提示词要怎样编写贾扬清的500行代码,掀翻了Perplexity5.2亿的桌子?劳伦斯:久别重逢刘苏里Rust编写的Zed编辑器开源:约27万行代码、主打“高性能”离职后可以删除自己编写的软件吗?全网独一份!GPT+AI大模型教程资源……(待会删)罕见一幕:巨头开始抱团取暖?阿里1号AI「员工」上岗,007写代码助攻大厂程序员!炸掉祖传屎山代码,Java丝滑改Python《早晨的故乡》&《橱窗》OpenAI官宣开源Transformer Debugger!不用写代码,人人可以破解LLM黑箱Redis 之父自曝用 AI 写代码,锐评:LLM 有望取代 99% 的程序员!一航班空中挂“7700”代码,紧急返航!AI也造代码屎山!研究发现GitHub Copilot代码可维护性差,偏爱“无脑重写”而非重构复用已有代码AI正在使全球代码质量下降!1.53亿行代码深度分析报告出炉如何倒床就睡?这有一份失眠自救指南!
logo
联系我们隐私协议©2024 redian.news
Redian新闻
Redian.news刊载任何文章,不代表同意其说法或描述,仅为提供更多信息,也不构成任何建议。文章信息的合法性及真实性由其作者负责,与Redian.news及其运营公司无关。欢迎投稿,如发现稿件侵权,或作者不愿在本网发表文章,请版权拥有者通知本网处理。