OneGen框架解析:单次推理实现LLM生成与检索一体化
1. 项目概述OneGen让LLM一次推理同时完成生成与检索如果你正在研究或应用大语言模型尤其是涉及检索增强生成技术那你一定对RAG的流程不陌生先根据用户问题去向量数据库里搜一圈把找到的相关文档拼接到提示词里再交给LLM生成最终答案。这个“检索-拼接-生成”的流水线模式虽然有效但存在一个明显的效率瓶颈它需要两次独立的前向传播。第一次是编码查询语句以进行检索第二次才是真正的答案生成。这不仅增加了延迟也意味着你无法利用生成过程中的KV缓存来加速因为两次推理的上下文是完全割裂的。今天要聊的OneGen项目就直指这个痛点。它来自浙江大学知识引擎实验室提出了一种单次前向传播的统一生成与检索框架。简单来说它训练LLM学会在生成答案的“途中”自己决定什么时候该去“查资料”并且把“查资料”这个动作也编码成模型可以理解和生成的特殊标记。最终模型在一次推理过程中就能动态地、自主地完成检索和生成无需外部干预或额外的推理步骤。我花了一些时间研读他们的论文和代码并动手复现了部分实验。OneGen的核心思想非常巧妙它没有引入复杂的额外模块而是通过重新定义模型中每个token的“角色”并设计相应的训练目标让模型内生地掌握了检索能力。这对于需要低延迟、高吞吐量服务的RAG应用场景比如在线问答、对话系统无疑是一个很有吸引力的解决方案。接下来我会带你深入拆解OneGen的设计思路、实现细节并分享我在复现过程中踩过的坑和总结的经验。2. 核心思路拆解Token角色划分与一体化训练要理解OneGen关键在于理解它提出的“Token角色”概念。在传统的LLM视角里token无非就是输入和输出。但OneGen认为在RAG任务中token可以承担三种不同的职能这直接决定了模型该如何处理它们。2.1 三种Token角色的定义与作用生成角色这是最常见的角色对应那些需要模型预测的下一个token。例如在回答“法国的首都是哪里”时模型逐字生成“巴黎”的过程这些输出token的角色就是GEN。训练时对这类token使用标准的交叉熵损失目标是让模型生成正确的文本。上下文角色这类token为生成提供背景信息但它们本身不是模型需要预测的目标。通常这包括用户的问题、系统指令以及检索到的文档内容。在OneGen的框架里这些token被标记为CTX。它们不直接参与损失计算但通过注意力机制影响着GEN和RETtoken的生成。检索角色这是OneGen的“灵魂”。模型会在生成过程中在某些位置主动输出一个特殊的ret标记。这个标记本身不构成最终答案但它代表了一个需要被检索的“句子”或“语义单元”。例如模型在回答一个复杂问题时可能会先生成“根据 记载...”这里的ret就是一个检索token。它的向量表示即该token位置对应的隐层状态会被拿出来作为查询向量去向量数据库中寻找最相关的文档片段。注意ret是一个在词表中预先定义好的特殊token。在训练和推理前需要将它加入到模型的tokenizer中。它的作用类似于一个“占位符”或“触发器”告诉模型“在这个位置我需要一个检索到的外部知识来支撑后续的生成。”2.2 一体化训练机制如何让模型学会“何时检索”让模型学会在合适的地方插入ret标记是训练的核心挑战。OneGen采用了一种基于数据构造的监督式训练方法。以单跳问答任务为例训练数据中的每个样本除了问题和答案还包含了支撑答案的相关文档。OneGen的预处理过程会做这样一件事在答案文本中将与某个支撑文档最相关的句子替换成ret标记。举个例子原始样本问题谁写了《哈利·波特》支撑文档《哈利·波特》系列小说由英国作家J.K.罗琳创作。答案J.K.罗琳写了《哈利·波特》。OneGen处理后的训练样本输入问题 支撑文档输出ret写了《哈利·波特》。模型在训练时看到的问题是“谁写了《哈利·波特》”和支撑文档它需要生成的输出序列是“ret写了《哈利·波特》”。在这个过程中对于ret这个token它的角色是RET。OneGen对RETtoken应用对比学习损失。具体来说将ret对应的隐层状态作为锚点与正样本支撑文档的向量和负样本其他不相关文档的向量计算对比损失。这个损失的目标是让ret的向量表示与它应该检索到的正确文档在语义上尽可能接近。对于“写了《哈利·波特》”这些token它们的角色是GEN应用标准的交叉熵损失确保生成文本的流畅性和正确性。通过大量这样的样本训练模型逐渐学会了两个能力第一在需要外部知识支撑时输出ret标记第二让ret标记的向量表示能够精准地指向正确的知识片段。2.3 与现有方案的对比优势论文中的对比图清晰地展示了OneGen的效率优势。这里我结合自己的理解再展开说一下传统Pipeline RAG如开篇所述查询编码和答案生成是两次独立的前向传播无法共享计算延迟高。GritLM它虽然也是单一模型但需要在因果注意力用于生成和双向注意力用于检索之间切换本质上仍然不是完全统一的流程。并且它同样需要先对查询进行编码。OneGen它的检索动作是通过生成ret标记触发的是生成流程中的一个自然环节。这意味着单次前向传播从输入问题到最终答案只需一次模型推理。支持KV缓存由于是严格的自回归生成过程可以充分利用KV缓存技术来加速后续token的生成这对于生成长文本尤其有利。动态检索模型可以根据生成上下文动态决定检索的时机和次数而不是固定地在开头检索一次。这种设计在理论上显著降低了推理延迟和计算成本特别适合对实时性要求高的应用。3. 环境搭建与数据准备实操理论很美妙但能不能跑起来是关键。我按照官方仓库的说明在本地环境进行了搭建。这里分享一些步骤和注意事项。3.1 基础环境配置官方推荐使用Python 3.9和Conda环境兼容性最好。# 1. 克隆代码仓库 git clone https://github.com/zjunlp/OneGen cd OneGen # 2. 创建并激活Conda环境 conda create -n onegen python3.9 -y conda activate onegen # 3. 安装依赖 pip install -r requirements.txt实操心得依赖文件requirements.txt里主要包含了torch,transformers,datasets,deepspeed,faiss-gpu等核心库。如果你的网络环境导致pip install较慢或失败可以考虑使用国内镜像源例如在命令后添加-i https://pypi.tuna.tsinghua.edu.cn/simple。安装faiss-gpu时务必确认你的CUDA版本。你可以通过nvcc --version或torch.cuda.is_available()来验证。如果安装失败可以尝试先安装CPU版本faiss-cpu但后续推理性能会受影响。我使用的环境是单张A10040GB对于7B模型的训练和推理内存是足够的。如果你使用消费级显卡如RTX 4090 24GB在运行全参数训练时可能会遇到OOM内存溢出问题需要考虑使用deepspeed的ZeRO优化或后面会提到的LoRA微调虽然官方TODO中标注尚未支持。3.2 数据下载与处理OneGen在三个任务上进行了实验实体链接、单跳问答和多跳问答。论文中使用的数据可以从Google Drive下载。# 假设你已经将 train_data.tar.gz 和 eval_data.tar.gz 下载到当前目录 tar -xzvf train_data.tar.gz tar -xzvf eval_data.tar.gz # 解压后会得到 train_data 和 eval_data 文件夹 # 按照官方建议将它们移动到项目根目录下的 data 文件夹中 mv train_data data/ mv eval_data data/注意事项官方也提到训练数据实际上托管在Hugging Face Datasets上。如果你直接运行训练脚本脚本应该会自动从HF下载数据。提前下载tar包主要是为了备选和离线使用。对于实体链接任务有一个额外的关键文件预计算的实体向量。这是因为在实体链接中ret标记需要检索的是知识库中的实体。这些实体的嵌入已经预先计算好并保存为.pkl文件。你必须从提供的链接下载这个OneGen-EntityLinking-Llama2-7B-Embedding.pkl文件并在推理时指定其路径否则模型无法进行检索。这是实体链接任务与其他问答任务的一个重要区别。数据集的格式是定制化的JSONL文件每一行包含input问题上下文、output包含ret标记的答案以及可能的positive/negative文档列表。如果你想在自己的数据上训练OneGen需要将数据预处理成这种格式。4. 模型训练与推理全流程解析OneGen提供了从训练到推理、评估的完整脚本。我们以Llama2-7B为基础模型分别看看三个任务如何操作。4.1 训练流程与配置详解训练脚本统一使用deepspeed进行加速。配置文件位于workflow/目录下按任务和模型区分。# 训练实体链接模型 deepspeed train.py --workflow workflow/entity_linking/llama2.json # 训练单跳问答模型 deepspeed train.py --workflow workflow/self_rag/llama2.json # 训练多跳问答模型 deepspeed train.py --workflow workflow/multi_hop_qa/llama2.json关键的一步是理解并修改配置文件。以workflow/entity_linking/llama2.json为例我们看看几个核心参数{ “info-model”: { “model_path”: “meta-llama/Llama-2-7b-hf” “tokenizer_path”: “meta-llama/Llama-2-7b-hf” “add_ret_token”: true } “info-data”: { “dataset_name”: “zjunlp/OneGen-EntityLinking” “max_length”: 1024 } “train_args”: { “output_dir”: “./output/entity_linking_llama2” “num_train_epochs”: 3 “per_device_train_batch_size”: 2 “gradient_accumulation_steps”: 8 “learning_rate”: 2e-5 “deepspeed”: “./configs/ds_config.json” } “retriever_args”: { “n_pos_per_sent”: 2 “n_neg_per_pos”: 8 } }info-model:model_path和tokenizer_path默认指向Hugging Face模型库。如果你已经将模型下载到本地或者想使用其他模型如Qwen、Baichuan需要修改为本地路径或对应的HF模型ID。add_ret_token必须设为true这会在tokenizer中添加ret特殊token。train_args: 这里定义了训练超参数。per_device_train_batch_size是每张GPU的批大小。官方配置是针对8张A80080GB设置的所以batch_size为2通过gradient_accumulation_steps8实现等效总批大小16。如果你GPU内存较小首要任务是降低per_device_train_batch_size比如设为1。其次可以尝试减小max_length和retriever_args中的n_pos_per_sent每个句子正样本数、n_neg_per_pos每个正样本对应的负样本数这些都会影响显存占用。retriever_args: 这些参数控制着对比学习的强度。n_pos_per_sent和n_neg_per_pos越大对比学习任务越难可能效果更好但显存和计算开销也越大。在资源有限时适当调小是可行的。训练避坑指南OOM问题如果遇到CUDA out of memory不要慌。按上述顺序调整参数先降batch_size再降max_length最后考虑调整对比学习参数。也可以尝试启用deepspeed配置文件中的offload_optimizer或offload_param将优化器状态和参数卸载到CPU内存。学习率2e-5对于LLM微调是一个常见的起点。你可以根据损失曲线进行调整。如果损失下降很慢或震荡可以适当增大如果损失爆炸或NaN则需减小。日志与保存训练过程中的日志和模型检查点会保存在output_dir指定的目录。建议使用TensorBoard或Weights Biases来监控训练过程。4.2 推理流程如何运行训练好的模型训练完成后或者你直接下载了官方发布的预训练模型就可以进行推理了。推理脚本为eval.py需要指定对应的配置文件。# 实体链接推理 (需要GPU和预计算的实体向量) python eval.py --config config/eval_config/entity_linking/llama2_wo_pkl.json # 多跳问答推理 (需要GPU) python eval.py --config config/eval_config/multi_hop_qa/llama2.json配置文件解析以llama2_wo_pkl.json为例wo_pkl意思是“without pkl”即这个配置文件不指定预计算的实体向量文件你需要自己在配置中或代码里指定。{ “model”: { “model_path”: “zjunlp/OneGen-EntityLinking-Llama2-7B” “tokenizer_path”: “zjunlp/OneGen-EntityLinking-Llama2-7B” “add_ret_token”: true } “inference”: { “file”: “data/eval_data/entity_linking/test.jsonl” “output_file_path”: “./results/el_results.jsonl” “use_faiss”: true “embedding_path”: “path/to/your/OneGen-EntityLinking-Llama2-7B-Embedding.pkl” // 关键 } }model_path: 这里可以替换成你训练好的模型本地路径如./output/entity_linking_llama2/checkpoint-1000。embedding_path: 对于实体链接任务这个路径必须正确指向你下载的实体向量pkl文件。推理时模型每生成一个ret就会用其隐状态向量在这个向量库中进行最近邻搜索找到最相关的实体。use_faiss: 设置为true可以极大加速向量检索过程。Faiss是Facebook开源的向量相似度搜索库支持GPU加速。确保你安装的是faiss-gpu版本。单跳问答推理的特别说明官方仓库的Quick Start中提到单跳问答的推理和评估是结合在一起的使用了Self-RAG的评估脚本。这是因为单跳问答任务的评估方式如准确率、召回率与生成和检索的交互逻辑紧密相关。运行命令较为复杂需要指定模型路径、保存标签、检索文档数量等参数。如果你只是想看看模型生成效果可以参照其他任务的eval.py自己写一个简单的生成脚本。4.3 评估脚本量化模型性能推理完成后会生成一个包含模型预测结果的jsonl文件。接下来需要使用官方提供的评估脚本来计算各项指标。# 评估实体链接任务 bash scripts/eval_el.sh el /your/path/to/result.jsonl # 评估多跳问答任务 (HotpotQA数据集) bash scripts/eval_multi_hop_qa.sh /your/path/to/result.jsonl hotpotqa # 评估多跳问答任务 (2Wiki数据集) bash scripts/eval_multi_hop_qa.sh /your/path/to/result.jsonl 2wiki这些脚本会根据任务特定的评估指标如实体链接的精确度、多跳问答的F1值、准确率等进行计算并输出最终分数。你可以将结果与论文中报告的数据进行对比以验证复现的成功率。5. 实战经验与疑难排查在复现和实验过程中我遇到了一些典型问题这里整理出来希望能帮你少走弯路。5.1 常见问题与解决方案速查表问题现象可能原因解决方案导入错误No module named ‘onegen’项目根目录未加入Python路径在运行脚本前在终端执行export PYTHONPATH“$PYTHONPATH:$(pwd)”或将项目根目录路径添加到你的IDE运行配置中。训练时损失为NaN或突然爆炸学习率过高梯度爆炸数据中存在异常值1. 降低学习率如从2e-5降至1e-5。2. 在deepspeed配置中启用梯度裁剪 (gradient_clipping)。3. 检查训练数据格式确保ret标记被正确添加和处理。推理时模型不生成ret标记1. 模型未正确学习到检索行为。2. 推理时温度参数过高导致采样随机性大。1. 检查训练是否充分或尝试使用官方预训练模型。2. 在推理生成配置中将temperature调低如设为0.1或直接使用贪婪解码 (do_sampleFalse)。实体链接任务检索结果完全错误1. 实体向量文件路径错误或未加载。2. 推理时使用的模型与生成实体向量的模型不一致。1. 确认embedding_path配置正确且文件可读。2.确保推理用的模型和生成实体向量的模型是同一个。不同模型的向量空间不同直接混用会导致检索失效。使用Faiss时提示找不到GPUFaiss未正确编译或安装GPU版本。1. 确认安装的是faiss-gpu而非faiss-cpu。2. 可以尝试在代码中强制指定使用CPU索引index faiss.IndexFlatIP(dimension)但速度会慢很多。评估脚本运行报错结果文件格式与评估脚本预期不符缺少必要的评估依赖包。1. 仔细对照官方提供的示例结果文件格式检查你的result.jsonl文件结构。2. 安装评估所需的额外包例如对于HotpotQA评估可能需要hotpot-evaluate工具。5.2 性能优化与扩展思考推理加速OneGen支持KV缓存这是其一大优势。在实际部署时你可以结合vLLM或TGI这类高性能推理引擎进一步优化吞吐量。虽然官方TODO列出了支持vLLM但当前版本可能需要一些适配工作。自定义数据训练如果你想在自己的领域数据上应用OneGen最关键的一步是构建训练数据。你需要准备问题 证据文档 答案三元组。在答案中确定哪些部分可以直接由证据文档中的某句话支撑并将该部分替换为ret标记。这个过程可能需要一些启发式规则或借助一个轻量级模型来完成自动标注。构建负样本对于每个ret位置需要准备若干不相关的文档作为对比学习的负样本。与参数高效微调结合目前OneGen进行的是全参数微调。对于更大的模型如70B全量微调成本很高。可以期待官方未来支持LoRA或QLoRA这将大大降低训练门槛。你也可以尝试自己修改代码将LoRA适配器插入到基座模型中只训练这部分新增参数。检索器的扩展目前检索依赖于预计算的Faiss索引。对于需要实时更新知识库的场景可以考虑将检索部分替换为支持动态更新的向量数据库如Milvus、Qdrant或Weaviate。这需要修改推理代码中的检索调用部分。OneGen提出了一种优雅且高效的一体化生成-检索框架思路。它通过赋予token不同的角色巧妙地将检索动作融入自回归生成流中。从我复现的结果看在保持与Pipeline RAG相近效果的前提下其推理效率的提升是实实在在的。当然它也有其局限比如训练数据构造相对复杂对“何时检索”的学习完全依赖于监督数据。对于那些检索需求不明确或难以标注的开放域任务可能需要更精巧的设计。这个项目代码结构清晰论文阐述也相当详细非常适合作为深入理解RAG前沿技术和LLM训练技巧的一个实践案例。如果你正在构建需要低延迟知识检索的LLM应用OneGen的设计思想非常值得借鉴。