Recent advances in outcome-supervised reinforcement learning (RL), exemplified by OpenAI's o1 and DeepSeek's R1, have demonstrated remarkable improvements in the reasoning capabilities of large language models (LLMs). Integrating outcome-supervised RL with search engines is another promising avenue for boosting LLM reasoning. However, outcome-supervised RL often struggles with sparse rewards, training instability, and inefficient exploration.
To address these limitations, process-supervised RL offers a compelling alternative for Agentic RAG, since it provides fine-grained rewards. We introduce ReasonRAG, a process-supervised method designed to refine the strategic preferences of Agentic RAG.
Our approach consists of three key steps:
- We leverage Monte Carlo Tree Search (MCTS) to generate process-supervised rollouts, yielding rich data on process-level strategic preferences.
- We then employ Direct Preference Optimization (DPO) to effectively optimize these strategic preferences within the Agentic RAG framework.
- Finally, we construct an Agentic RAG pipeline that enables the LLM to autonomously generate queries, extract evidence, and formulate answers (a minimal sketch of this loop follows the links below). The dataset we constructed and our trained models are available here:
- RAG_ProGuide Dataset: https://huggingface.co/datasets/reasonrag/RAG_ProGuide
- Trained model: Qwen2.5-7B-Instruct-ReasonRAG
- Trained LoRA model: Qwen2.5-7B-Instruct-RAG-Lora
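As a rough illustration of the last step, the sketch below shows a minimal agentic RAG loop that interleaves query generation, retrieval, evidence extraction, and answering. It is a simplified mock, not the exact ReasonRAG implementation: `llm` and `retriever` are hypothetical callables standing in for the policy model and the FlashRAG retriever, and the prompt/action format is illustrative.

```python
from typing import Callable, List

def agentic_rag(question: str,
                llm: Callable[[str], str],
                retriever: Callable[[str], List[str]],
                max_steps: int = 8) -> str:
    """Minimal agentic RAG loop: the policy model alternates between issuing
    search queries, distilling evidence from retrieved passages, and answering."""
    evidence: List[str] = []
    for _ in range(max_steps):
        state = f"Question: {question}\nEvidence so far: {evidence}"
        # Ask the policy model for its next action: a new query or a final answer.
        decision = llm(state + "\nNext action (QUERY: ... or ANSWER: ...):").strip()
        if decision.upper().startswith("ANSWER:"):
            return decision.split(":", 1)[1].strip()
        query = decision.split(":", 1)[1].strip() if ":" in decision else decision
        passages = retriever(query)
        # Distill the retrieved passages into a short piece of evidence.
        evidence.append(llm(f"Extract the evidence relevant to '{question}' from: {passages}"))
    # Step budget exhausted: answer with whatever evidence was gathered.
    return llm(f"Question: {question}\nEvidence: {evidence}\nFinal answer:")
```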
ReasonRAG achieves superior performance on five benchmark datasets using only 5k training instances, significantly fewer than the 90k training instances required by Search-R1.
We employ process-supervised RL to enhance Agentic RAG capabilities in three stages:
- Process-supervised reward data generation
- Policy preference optimization (the objective is sketched below this list)
- Agentic RAG inference
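For reference, the policy preference optimization stage builds on the standard DPO objective (Rafailov et al., 2023). In the statement below, $y_w$ and $y_l$ denote the preferred and dispreferred responses (here, process-level actions) for the same prompt $x$, $\pi_{\mathrm{ref}}$ is the frozen reference policy, and $\beta$ weights the implicit KL constraint:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})
= -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}
\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}
- \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]
$$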
We randomly sample data from PopQA, HotpotQA, and 2WikiMultihopQA to generate process-supervised preference data, then use GPT-4o as the policy model to generate rollout data. The resulting process-supervised dataset, RAG-ProGuide, is available at: https://huggingface.co/datasets/reasonrag/RAG_ProGuide
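The rollout and scoring logic lives in data_generation.py and preference_data_generation.py. As a rough, simplified illustration of the idea (not the exact procedure used by ReasonRAG): each candidate intermediate action can be scored by the fraction of MCTS rollouts from that action that reach a correct final answer, and the best- and worst-scoring actions at the same state then form a (prompt, chosen, rejected) pair for preference optimization. The function `rollout_from` below is a hypothetical stand-in for a GPT-4o rollout that reports whether the completed trajectory answered correctly.

```python
from typing import Callable, Dict, List, Tuple

def score_actions(state: str,
                  candidate_actions: List[str],
                  rollout_from: Callable[[str, str], bool],
                  n_rollouts: int = 8) -> Dict[str, float]:
    """Estimate a process-level value for each candidate action as the fraction
    of rollouts (sampled continuations) that end in a correct answer."""
    values: Dict[str, float] = {}
    for action in candidate_actions:
        wins = sum(rollout_from(state, action) for _ in range(n_rollouts))
        values[action] = wins / n_rollouts
    return values

def build_preference_pair(state: str, values: Dict[str, float]) -> Tuple[str, str, str]:
    """Turn per-action values at one state into a (prompt, chosen, rejected) example."""
    ranked = sorted(values, key=values.get, reverse=True)
    return state, ranked[0], ranked[-1]
```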
Set up the FlashRAG environment:
conda create --name reasonrag python=3.10.16
conda activate reasonrag
pip install flashrag-dev --pre
pip install "flashrag-dev[full]"
pip install "vllm>=0.4.1"
pip install deepspeed
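Optionally, check that the core packages import cleanly before building the index (module names inferred from the commands used later in this README):

```bash
python -c "import flashrag, vllm, deepspeed; print('environment ok')"
```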
Download the Wikipedia dump used as the retrieval corpus and build the index:
# Download wikidump
wget https://archive.org/download/enwiki-20181220/enwiki-20181220-pages-articles.xml.bz2
# Build index
python -m flashrag.retriever.index_builder \
--retrieval_method bge \
--model_path /BAAI/bge-base-en-v1.5 \
--corpus_path indexes/wiki18.jsonl \
--save_dir indexes/ \
--use_fp16 \
--max_length 512 \
--batch_size 256 \
--pooling_method mean \
--faiss_type Flat
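Note that the index builder expects a JSONL corpus at indexes/wiki18.jsonl, so the raw .xml.bz2 dump has to be preprocessed into that format first (see the FlashRAG documentation for corpus preparation). The snippet below is a small sanity check that assumes FlashRAG's documented per-line schema of `id` and `contents` fields; adjust it if your FlashRAG version uses a different layout.

```python
import json

# Inspect the first few corpus lines and check the assumed FlashRAG schema
# (one JSON object per line with "id" and "contents" fields).
with open("indexes/wiki18.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i >= 3:
            break
        doc = json.loads(line)
        assert "id" in doc and "contents" in doc, f"unexpected fields: {list(doc)}"
        print(doc["id"], doc["contents"][:80])
```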
Download the QA datasets from Hugging Face: RUC-NLPIR/FlashRAG_datasets.
Note: the following scripts generate policy preference data. You can use the RAG-ProGuide dataset linked above directly, run these scripts to generate your own data, or adapt them as needed.
python data_generation.py --dataset_name popqa --model gpt-4o
python data_generation.py --dataset_name hotpotqa --model gpt-4o
python data_generation.py --dataset_name 2wikimultihopqa --model gpt-4o
python preference_data_generation.py
# Install LLaMA Factory
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"
# Set the dataset path before preference optimization
llamafactory-cli train training_config/qwen_dpo.yaml
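Here, "set the dataset path" means registering the preference data with LLaMA-Factory: add an entry to LLaMA-Factory's data/dataset_info.json and reference its key via the `dataset` field in training_config/qwen_dpo.yaml. The entry below is an illustrative sketch; the key, file name, and column names are placeholders, and the exact schema may vary across LLaMA-Factory versions.

```json
{
  "rag_proguide_dpo": {
    "file_name": "rag_proguide_dpo.json",
    "ranking": true,
    "columns": {
      "prompt": "prompt",
      "chosen": "chosen",
      "rejected": "rejected"
    }
  }
}
```

Finally, run inference on the evaluation datasets with the trained model: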
python inference.py --dataset_name hotpotqa --model $MODEL_NAME