RAG 检索模型训练:三大损失函数机制解析与对比

探讨 RAG 检索模型训练中三种损失函数:成对余弦嵌入损失、三元组边距损失和 InfoNCE 损失的机制与适用场景。

原文标题:RAG 检索模型如何学习:三种损失函数的机制解析

原文作者:数据派THU

冷月清谈:

本文深入探讨了 RAG(Retrieval-Augmented Generation)检索模型中检索嵌入模型的学习方式,并对三种常见的损失函数进行了详细的介绍和比较。文章首先肯定了检索模型在 Agent 系统中的重要性,指出即使 Agent 技术不断发展,高质量的检索仍然是提高效率、降低成本的关键。接着,文章聚焦于检索嵌入模型的学习方式,介绍了 Pairwise cosine embedding loss(成对余弦嵌入损失)、Triplet margin loss(三元组边距损失)和 InfoNCE loss 三种方法,并分析了它们的原理和适用场景。其中,成对余弦嵌入损失通过比较文本对的嵌入向量来判断匹配程度;三元组边距损失则利用锚文本、正匹配和负匹配之间的关系进行学习。InfoNCE 损失则通过区分正样本和负样本列表来优化模型。文章最后指出,三种方法各有优劣,实际应用中需要根据具体场景、数据量和算力进行选择,实验表明InfoNCE 覆盖面最广,但只要实验充分,余弦嵌入损失也能达到类似效果。

怜星夜思:

1、在实际应用中,如果数据集中正负样本比例极不平衡,会对这三种损失函数的训练效果产生什么影响?有没有什么策略可以缓解这种问题?
2、文章中提到 InfoNCE 损失覆盖面最广,那么在算力资源有限的情况下,有没有办法优化 InfoNCE 损失的计算效率,使其能够在更小的硬件上运行?
3、除了文章中提到的三种损失函数,还有没有其他适用于 RAG 检索模型学习的损失函数?它们的优缺点是什么?

原文内容

图片
来源:DeepHub IMBA
本文约1000字,建议阅读5分钟
哪种方法最好?要看具体场景、数据量和算力。


Agent 系统发展得这么快那么检索模型还重要吗?RAG 本身都已经衍生出 Agentic RAG和 Self-RAG 这些更复杂的变体了。

答案是肯定的,无论 Agent 方法在效率和推理上做了多少改进,底层还是离不开检索。检索模型越准,需要的迭代调用就越少,时间和成本都能省下来,所以训练好的检索模型依然关键。讨论 RAG 怎么用的文章铺天盖地,但真正比较检索模型学习方式的内容却不多见。

检索系统包含多个组件:检索嵌入模型、索引算法(HNSW 之类)、向量搜索机制(余弦相似度等)以及重排序模型。这篇文章只聚焦检索嵌入模型的学习方式。

本文将介绍我实验过的三种方法:Pairwise cosine embedding loss(成对余弦嵌入损失)、Triplet margin loss(三元组边距损失)、InfoNCE loss。

成对余弦嵌入损失


正样本对示例

负样本对示例

输入是一对文本加一个标签,标签标明这对文本是正匹配还是负匹配。和 MNLI 数据集里的蕴含、矛盾关系类似。

损失函数用的是余弦嵌入损失,x 和 y 分别是文本对的嵌入向量。

三元组边距损失


输入变成三个文本:一个锚文本、一个正匹配、一个负匹配。

损失函数是 Triplet Margin Loss。公式里 a 代表锚文本嵌入,p 代表正样本嵌入,n 代表负样本嵌入。

InfoNCE 损失


输入包括一个查询、一个正匹配、一组负样本列表。

损失函数采用 InfoNCE,灵感来自 M3-Embedding 论文(arxiv:2402.03216)。公式中 p* 是正样本嵌入,P' 是负样本嵌入列表,q 是查询嵌入,s(.) 表示相似度函数,比如余弦相似度。

比较


哪种方法最好?要看具体场景、数据量和算力。从我的实验来看,InfoNCE 覆盖面最广。但只要实验做得够充分、训练数据比例调得够细,余弦嵌入损失也能达到差不多的效果。三元组边距损失我没有深入探索,不过它可能是介于另外两者之间的一个折中选项。

编辑:文婧



关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU


我想到一个比较“偷懒”的方法,就是知识蒸馏。先用一个大的、效果好的模型训练,然后用这个模型来指导一个小模型的训练。这样小模型也能学到很多有用的知识,而且推理速度更快。

楼上说得有道理!我补充一点,可以考虑使用一些数据增强技术来增加正样本的多样性,比如同义词替换、句子改写之类的。另外,也可以尝试一些专门用于处理不平衡数据的算法,比如SMOTE。

其实还有很多,比如对比学习里常用的SimCLR loss,它通过最大化同一样本不同增强视图之间的一致性来学习表征。优点是简单有效,缺点是需要设计好的数据增强策略。还有SupCon loss,它在SimCLR的基础上引入了标签信息,可以学习到更有区分性的表征。不过,这些损失函数可能更适用于无监督或半监督学习的场景。

算力有限啊,那确实得好好优化。可以试试这几个方向:1. 减少负样本的数量,但要注意保持负样本的多样性;2. 使用更小的模型,比如MobileBERT;3. 采用混合精度训练,减少显存占用;4. 使用梯度累积,在多个小batch上累积梯度,模拟更大的batch size。

学术一点来说,样本不平衡会导致模型学习到的决策边界偏移,降低泛化能力。解决这个问题,除了过采样、欠采样、调整损失函数权重外,还可以尝试使用集成学习的方法,比如Bagging或Boosting,它们可以有效地降低不平衡数据带来的影响。

我最近在看一篇论文,里面提到了一种叫做RankNet的损失函数,它主要用于排序任务,通过比较两个样本的排序关系来学习模型。感觉也可以用在 RAG 检索模型里,不过具体效果怎么样还得试一试。

正负样本比例不平衡确实是个大问题。如果负样本远多于正样本,模型可能会过于关注区分负样本,而忽略了学习真正的匹配关系。我的建议是:1. 对正样本进行过采样,增加正样本的权重;2. 尝试使用Focal Loss,它能让模型更关注难区分的样本;3. 生成一些伪正样本,但要注意质量,避免引入噪声。

在推荐系统里,有一种叫做BPR(Bayesian Personalized Ranking)的损失函数,它假设用户更喜欢正样本而不是负样本,通过最大化正负样本的排序差异来学习模型。我觉得这个思路也可以借鉴到 RAG 检索模型里,让模型更关注区分用户真正感兴趣的内容。

从算法层面看,可以考虑使用近似最近邻搜索(ANN)来加速负样本的采样过程。另外,也可以尝试一些更高效的相似度计算方法,比如局部敏感哈希(LSH)。