Regressing the Relative Future: Efficient Policy Optimization for Multi-turn RLHF

Zhaolin Gao, Wenhao Zhan, Jonathan D. Chang, Gokul Swamy, Kiante Brantley, Jason D. Lee, Wen Sun.

This repo contains the implementation of our paper REFUEL.

[Figure: front page]

Environment

torch>=2.1.0
transformers>=4.34
accelerate>=0.23
peft==0.6.2
bitsandbytes>=0.41.1
deepspeed>=0.10.3
vllm
tyro
scipy
rouge
shortuuid
jsonlines
rich
wandb
tensorboard
pandas
evaluate

Setting One

Dataset Generation

In this setting, at each iteration we first generate dialogues for the entire UltraInteract dataset using our policy as the assistant and Llama-3.1-70B-it as the user. We use Llama-3-8B-it as our initial policy. You can directly use our processed dataset or generate it from scratch:

  1. We first generate the dialogues for the entire dataset
python ./setting_one/generate.py --model POLICY --output_dir OUTPUT_DIR --output_repo OUTPUT_REPO

You can also set num_data to a small number to test out the generation process.
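
To make the rollout procedure concrete, below is a hedged sketch of an alternating-role generation loop in the spirit of generate.py, not the script itself: the policy produces assistant turns while a larger model simulates the user by answering a role-flipped copy of the conversation. Model names, the prompt, and the role-flipping trick are illustrative assumptions.

```python
# Hypothetical sketch (NOT ./setting_one/generate.py): roll out a multi-turn
# dialogue where the policy plays the assistant and a larger model simulates
# the user by answering a role-flipped copy of the conversation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

POLICY = "meta-llama/Meta-Llama-3-8B-Instruct"    # the assistant / current policy
USER_SIM = "meta-llama/Llama-3.1-70B-Instruct"    # user simulator (swap in a smaller model to test)

def load(name):
    tok = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, device_map="auto")
    return tok, model

def reply(tok, model, messages, max_new_tokens=512):
    # Render the chat history with the model's template and sample a continuation.
    input_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    out = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.8)
    return tok.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True)

policy_tok, policy = load(POLICY)
user_tok, user_sim = load(USER_SIM)

dialogue = [{"role": "user", "content": "Solve 2x + 3 = 11 step by step."}]
for _ in range(3):  # number of assistant turns to roll out
    dialogue.append({"role": "assistant", "content": reply(policy_tok, policy, dialogue)})
    # Flip roles so the user simulator generates the next *user* message.
    flipped = [{"role": "system", "content": "Continue this conversation as the human user."}] + [
        {"role": "user" if m["role"] == "assistant" else "assistant", "content": m["content"]}
        for m in dialogue
    ]
    dialogue.append({"role": "user", "content": reply(user_tok, user_sim, flipped)})
```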

  2. We generate the rewards for all the dialogues using ArmoRM as the reward model.
python ./setting_one/rank.py --prompts INPUT_REPO --output_repo OUTPUT_REPO

We assign a reward of -99999 to trajectories that do not have a valid reward. INPUT_REPO is the OUTPUT_REPO from step 1.
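
If you post-process the reward-annotated data yourself, the sentinel entries can be dropped with the datasets library. A minimal sketch, assuming a scalar column named reward (check the actual schema produced by rank.py); step 3 below performs this filtering for you:

```python
# Hypothetical sketch: drop trajectories whose reward is the -99999 sentinel.
# The repo id and the "reward" column name are placeholders / assumptions.
from datasets import load_dataset

INVALID_REWARD = -99999

ds = load_dataset("OUTPUT_REPO_FROM_STEP_2", split="train")
ds = ds.filter(lambda row: row["reward"] > INVALID_REWARD)
print(f"{len(ds)} trajectories with valid rewards")
```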

  3. The dataset then goes through a filtering process: we filter out dialogues that are longer than 2048 tokens, that have the same set of responses, or that do not produce a valid reward score. We then tokenize each dialogue and generate a mask for it.
python ./setting_one/tokenize_masks.py --input_repo INPUT_REPO --output_repo OUTPUT_REPO

INPUT_REPO is the OUTPUT_REPO from step 2.
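
The mask marks which tokens are trained on (assistant turns) and which are context (user turns). Below is a hedged sketch of one way to build such a per-token mask with the Llama-3 chat template; tokenize_masks.py may construct it differently:

```python
# Hypothetical sketch of a per-token loss mask: 1 for tokens of assistant
# turns, 0 otherwise. Assumes the chat template is concatenative, i.e. each
# dialogue prefix tokenizes to a prefix of the full dialogue (true for Llama-3).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

dialogue = [
    {"role": "user", "content": "What is 7 * 8?"},
    {"role": "assistant", "content": "7 * 8 = 56."},
    {"role": "user", "content": "And 56 / 4?"},
    {"role": "assistant", "content": "56 / 4 = 14."},
]

input_ids, mask = [], []
for i, msg in enumerate(dialogue):
    # Tokens introduced by message i = tokens of prefix[:i+1] minus tokens already seen.
    prefix_ids = tok.apply_chat_template(dialogue[: i + 1], add_generation_prompt=False)
    new_ids = prefix_ids[len(input_ids):]
    input_ids.extend(new_ids)
    mask.extend([1 if msg["role"] == "assistant" else 0] * len(new_ids))

assert len(input_ids) == len(mask)
print(sum(mask), "assistant tokens out of", len(mask))
```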

Training

Now, we can train our policy by running:

accelerate launch \
    --config_file accelerate_cfgs/ds_config2.yaml \
    --num_processes 8 \
    ./setting_one/refuel.py \
        --task.query_dataset DATASET_REPO \
        --task.cluster CLUSTER \
        --task.total_length 2048 \
        --task.temperature 0.8 \
        --lr 3e-7 \
        --rebel.eta 1e3 \
        --warmup_ratio 0.1 \
        --total_episodes 64000 \
        --output_dir OUTPUT_DIR \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --per_device_eval_batch_size 1 \
        --print_sample_output_freq 100 \
        --base_model PREV_POLICY

task.query_dataset: the repo of the generated dataset.

task.cluster: the cluster name. We found that different GPU/CPU/CUDA configurations can result in different logprobs; setting task.cluster allows the logprobs to be recomputed automatically when running on a new cluster.

output_dir: local save directory.

base_model: the policy from the previous iteration. At the first iteration, we use meta-llama/Meta-Llama-3-8B-Instruct.
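
For intuition on what refuel.py optimizes, here is a heavily hedged sketch of a REBEL/REFUEL-style square-loss update on a pair of trajectories: the scaled difference of log-probability ratios (new policy vs. the previous iterate) is regressed onto the difference of the trajectories' rewards. Tensor names are hypothetical, and the real objective, masking, and batching live in ./setting_one/refuel.py:

```python
# Hedged sketch of a REBEL/REFUEL-style loss on one trajectory pair (A, B).
# logprob_* are summed, mask-weighted log-probabilities of the sampled tokens.
import torch

def refuel_style_loss(
    logprob_new_a, logprob_old_a,   # trajectory A under the new / previous policy
    logprob_new_b, logprob_old_b,   # trajectory B under the new / previous policy
    reward_a, reward_b,             # scalar rewards (e.g. from ArmoRM)
    eta: float = 1e3,               # matches --rebel.eta above
):
    ratio_diff = (logprob_new_a - logprob_old_a) - (logprob_new_b - logprob_old_b)
    reward_diff = reward_a - reward_b
    # Least-squares regression of (1/eta) * ratio_diff onto reward_diff.
    return ((1.0 / eta) * ratio_diff - reward_diff).pow(2).mean()

# Toy call with made-up numbers:
loss = refuel_style_loss(
    torch.tensor(-120.0), torch.tensor(-123.0),
    torch.tensor(-150.0), torch.tensor(-149.0),
    torch.tensor(0.71), torch.tensor(0.42),
)
print(loss.item())
```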

Datasets and Models

Below we list our trained models and processed datasets, along with their winrates w.r.t. the initial policy meta-llama/Meta-Llama-3-8B-Instruct. REFUEL outperforms Llama-3.1-70B-it on dialogues with more than three turns.

Winrate at turn h:

| Method | Dataset | h = 1 | h = 2 | h = 3 | h = 4 | h = 5 | avg |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Llama-3.1-70B-it | N/A | 70.4 | 66.4 | 61.0 | 53.0 | 55.4 | 61.24 |
| REFUEL (iter 1) | REFUEL-Ultrainteract-Llama-3-Armo-iter_1 | 54.6 | 53.6 | 57.8 | 56.2 | 59.4 | 56.32 |
| REFUEL (iter 2) | REFUEL-Ultrainteract-Llama-3-Armo-iter_2 | 55.2 | 53.4 | 58.8 | 57.2 | 58.6 | 56.64 |
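
To try a trained checkpoint, the standard transformers chat workflow applies; the model id below is a placeholder, so substitute the repo linked in the table:

```python
# Hypothetical usage sketch: chat with a trained REFUEL checkpoint.
# "ORG/REFUEL-Llama-3-Armo-iter_2" is a placeholder, not a real model id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ORG/REFUEL-Llama-3-Armo-iter_2"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

messages = [{"role": "user", "content": "Explain the difference between BFS and DFS."}]
input_ids = tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
output = model.generate(input_ids, max_new_tokens=256, do_sample=True, temperature=0.8)
print(tok.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True))
```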

Setting Two

Anthropic HH

First, we process the Anthropic HH dataset by filtering out dialogues with more than 5 turns, prompts longer than 128 tokens, and responses longer than 512 tokens.

python ./setting_two/preprocess_hh.py

The processed dataset is available at REFUEL-hh-setting-two.
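
For reference, the filters described above roughly correspond to a predicate like the one below; the message structure and turn-counting convention are assumptions, and preprocess_hh.py defines the exact logic:

```python
# Hypothetical sketch of the Setting Two filters: at most 5 turns, prompt
# <= 128 tokens, every response <= 512 tokens. Dialogue format and the
# turn-counting convention are assumptions; see preprocess_hh.py.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

def keep(dialogue):
    """dialogue: list of {"role": "user"/"assistant", "content": str} messages."""
    if len(dialogue) > 2 * 5:  # more than 5 user/assistant exchanges
        return False
    prompt_len = len(tok(dialogue[0]["content"])["input_ids"])
    if prompt_len > 128:
        return False
    responses = [m["content"] for m in dialogue if m["role"] == "assistant"]
    return all(len(tok(r)["input_ids"]) <= 512 for r in responses)
```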

Then, we train Llama-3-8B-it with the FsfairX reward model by running:

accelerate launch \
    --config_file accelerate_cfgs/deepspeed_config.yaml \
    --num_processes 8 \
    ./setting_two/anthropic_hh/refuel.py \
        --base_model meta-llama/Meta-Llama-3-8B-Instruct \
        --task.query_dataset DATASET_REPO \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 4 \
        --per_device_eval_batch_size 1 \
        --lr 3e-7 \
        --eps 1e-8 \
        --weight_decay 1e-6 \
        --reward.kl_coef 0.05 \
        --rebel.eta 1.0 \
        --output_dir OUTPUT_DIR \
        --task.penalty_reward_value -10 \
        --print_sample_output_freq 200 \
        --task.response_length 512 \
        --offload

task.query_dataset: the repo of the processed dataset.

output_dir: local save directory.

Ultrainteract

First, we process the UltraInteract dataset by filtering out dialogues with more than 5 turns, as well as prompts and responses that exceed the lengths listed in Table 5 of the paper.

python ./setting_two/preprocess_ultrainteract_diff_len.py

The processed dataset is available at REFUEL-UltraInteract-setting-two.

Then, we train Llama-3-8B-it with the FsfairX reward model by running:

accelerate launch \
    --config_file accelerate_cfgs/deepspeed_config.yaml \
    --num_processes 8 \
    ./setting_two/ultrainteract/refuel.py \
        --base_model meta-llama/Meta-Llama-3-8B-Instruct \
        --task.query_dataset DATASET_REPO \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 4 \
        --per_device_eval_batch_size 1 \
        --wandb_project_name multiturn \
        --lr 3e-7 \
        --eps 1e-8 \
        --weight_decay 1e-6 \
        --reward.kl_coef 0 \
        --rebel.eta 1.0 \
        --output_dir OUTPUT_DIR \
        --task.penalty_reward_value -4 \
        --print_sample_output_freq 200 \
        --offload

task.query_dataset: the repo of the processed dataset.

output_dir: local save directory.

Cite

Please cite our paper if you use this implementation in your own work:

@misc{gao2024regressingrelativefutureefficient,
      title={Regressing the Relative Future: Efficient Policy Optimization for Multi-turn RLHF}, 
      author={Zhaolin Gao and Wenhao Zhan and Jonathan D. Chang and Gokul Swamy and Kianté Brantley and Jason D. Lee and Wen Sun},
      year={2024},
      eprint={2410.04612},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.04612}, 
}
