Ocean-R1: An Open and Generalizable Large Vision-Language Model enhanced by Reinforcement Learning

🎯Overview

Inspired by the robust reasoning capabilities demonstrated by DeepSeek R1 in the text domain, we seek to extend the large-scale reinforcement learning (RL) techniques that have proven effective for large language models (LLMs) to multimodal scenarios.

Given the multifaceted nature of visual perception tasks, our focus centers on two critical components: visual recognition and localization, together with reasoning tasks. This approach is motivated by the complementary strengths of visual perception, which identifies and extracts visual information, and the advanced reasoning capabilities of LLMs, which are adept at problem-solving. By integrating these two capabilities, we aim to address complex multimodal reasoning challenges. To achieve this, we conducted the following experiments:

  • Text-Only Training: We trained Qwen2.5-VL-3B-Instruct on the Ocean-R1 Training Text Dataset using Group Relative Policy Optimization (GRPO) with a rule-based reward function.

  • Visual-Only Training: We trained Qwen2.5-VL-3B-Instruct on the Ocean-R1 Training Visual Dataset using GRPO with a rule-based reward function.

  • Multimodal Training: We are training Qwen2.5-VL-3B-Instruct on a combined dataset integrating both text and visual data. The results of these experiments will be released in the near future.

This systematic exploration aims to evaluate the efficacy of GRPO in enhancing multimodal reasoning capabilities and to provide insights into the interplay between visual and textual modalities in complex reasoning tasks.
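As a rough illustration of the rule-based reward used with GRPO, the sketch below combines an accuracy check on the extracted answer with a format check on the <think>/<answer> structure, weighted by the FORMAT_REWARD_FACTOR environment variable that appears in the training script. This is a minimal sketch, not the repo's exact implementation; the actual matching logic may differ.

import os
import re

# Weight for the format term; mirrors the FORMAT_REWARD_FACTOR variable in the training script below.
FORMAT_REWARD_FACTOR = float(os.getenv("FORMAT_REWARD_FACTOR", "1.0"))

def format_reward(completion):
    # 1.0 if the completion follows the <think>...</think><answer>...</answer> template, else 0.0
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return 1.0 if re.fullmatch(pattern, completion.strip(), flags=re.DOTALL) else 0.0

def accuracy_reward(completion, ground_truth):
    # 1.0 if the extracted answer matches the ground truth exactly, else 0.0
    match = re.search(r"<answer>(.*?)</answer>", completion, flags=re.DOTALL)
    predicted = (match.group(1) if match else completion).strip()
    return 1.0 if predicted == ground_truth.strip() else 0.0

def rule_based_reward(completion, ground_truth):
    return accuracy_reward(completion, ground_truth) + FORMAT_REWARD_FACTOR * format_reward(completion)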

🔥 We open-source our complete pipeline to foster further research in this area, releasing all of our code, models, and data.

Note

These data come from the open-source community and were obtained through cleaning and filtering.


🚀 News

  • 2025-03-10: We release the Ocean-R1 repo, including codebase, model, and training datasets.

🗞️ Our Findings


  • Excellent Cross-Modal Reasoning Ability: In our experiments, training exclusively with text-only data led to varying degrees of performance improvement on reasoning-related tasks, such as geometric reasoning and mathematical problem-solving. This highlights the potential of incorporating textual inference data to enhance the VLM's reasoning capabilities. Improvements were also observed on counting and general-purpose tasks, suggesting that the enhanced reasoning abilities generalize to broader applications. However, this approach came at a cost: performance on tasks requiring strong visual perception declined sharply. For example, on the grounding task (RefCOCO/+/g), the average score plummeted from 75.3 to 2.4. This underscores a trade-off: while GRPO can strengthen specific capabilities, it may inadvertently impair other critical aspects of the model.
  • Diverse Data Achieves Better Performance: When trained with visual data, the model exhibited substantial gains across a wide range of tasks, including counting, geometric reasoning, grounding, mathematical problem-solving, and general-purpose tasks. This demonstrates the importance of diverse multimodal training data for achieving balanced and comprehensive improvements across domains.

📦 Setup

conda create -n Ocean_R1 python=3.11 
conda activate Ocean_R1

bash setup.sh

Note

If you encounter a bug when running the script, first try aligning your environment with ./src/requirements.txt.

🔄 Training

Data Preparation

You can download our training data from Ocean_R1_collected_visual_data and Ocean_R1_collected_text_data.
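If you prefer to load the data programmatically, a minimal sketch is shown below. The visual dataset id (minglingfeng/Ocean_R1_collected_visual_data) matches the training script; the text dataset id is assumed to live under the same namespace, and the split name may differ from what is released.

from datasets import load_dataset

# The visual id comes from the training script; the text id is assumed analogous.
visual_data = load_dataset("minglingfeng/Ocean_R1_collected_visual_data", split="train")
text_data = load_dataset("minglingfeng/Ocean_R1_collected_text_data", split="train")

# Inspect the columns before wiring the data into training.
print(visual_data)
print(text_data)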

GRPO

  • ./src/scripts/run_grpo_qwen2d5vl.sh
  • ./src/scripts/run_grpo_vllm_qwen2d5vl_Ocean_R1_visual_data.sh
cd src/r1-v

HF_DATASET="minglingfeng/Ocean_R1_collected_visual_data" 

export FORMAT_REWARD_FACTOR=1.0
export IS_LOCAL=False ## load_from_disk or load_dataset from huggingface: minglingfeng/Ocean_R1_collected_visual_data
export DEBUG_MODE="true"
export LOG_PATH=./src/logs/debug_qwen2p5_vl_3b_${HF_DATASET}.log
# export WANDB_API_KEY="xxxxx"
export WANDB_PROJECT="Ocean-R1"

QWEN_PATH=/global_data/mllm/minglingfeng/models/Qwen2.5-VL-3B-Instruct  # change to your local model path
OUTPUT_DIR=./src/r1-v/src/outputs/exp-Qwen2.5-VL-3B/${HF_DATASET}
if [ ! -d "$OUTPUT_DIR" ]; then
 mkdir -p "$OUTPUT_DIR"
fi
RUN_NAME=3B-$HF_DATASET
DS_CONFIG="./src/r1-v/local_scripts/zero1_no_optimizer.json"  # Note: other ZeRO settings currently hit bugs related to vLLM.

# vLLM NOTE: use X + 1 GPUs in total: X for training processes and 1 for the vLLM process.
# e.g., with 5 GPUs the visible devices should be 0,1,2,3,4 and --nproc_per_node="4"

CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" torchrun \
    --nproc_per_node="7" \
    --nnodes="1" \
    --node_rank="0" \
    --master_addr="127.0.0.1" \
    --master_port="12345" \
    ./src/r1-v/src/open_r1/grpo.py \
    --use_vllm true \
    --output_dir ${OUTPUT_DIR} \
    --model_name_or_path ${QWEN_PATH} \
    --dataset_name ${HF_DATASET} \
    --max_prompt_length 1024 \
    --max_completion_length 2048 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --learning_rate 1e-6 \
    --lr_scheduler_type "constant" \
    --logging_steps 1 \
    --bf16 true \
    --gradient_checkpointing true \
    --attn_implementation flash_attention_2 \
    --min_pixels 3136 \
    --max_pixels 501760 \
    --num_train_epochs 2 \
    --run_name ${RUN_NAME} \
    --save_steps 50 \
    --save_total_limit 3 \
    --save_only_model true \
    --report_to wandb \
    --temperature 1.0 \
    --vllm_device "cuda:7" \
    --vllm_gpu_memory_utilization 0.8 \
    --deepspeed ${DS_CONFIG} \
    --num_generations 7 
    # number of sampled outputs G in GRPO; reducing it speeds up training and lowers memory cost but increases variance

Note

  1. To reproduce our results, keep per_device_train_batch_size at 1 for now, as there is a known bug with batched training.
  2. If you hit an OOM error, try reducing --num_generations or setting --gradient_checkpointing to true.
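For context, num_generations is the group size G in GRPO: the rewards of the G completions sampled for the same prompt are normalized against each other, so a smaller G gives a noisier advantage estimate. A minimal sketch of this group-relative normalization (not the repo's exact code):

import numpy as np

def group_relative_advantages(rewards, eps=1e-8):
    # rewards: shape (G,), rewards for the G sampled completions of one prompt
    rewards = np.asarray(rewards, dtype=np.float32)
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Example with G = 7, as in the script above: each completion's advantage is its
# reward standardized within its group.
print(group_relative_advantages([1.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0]))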

SFT

We also provide SFT code. Follow the script below and edit the config to customize the SFT task.

accelerate launch --config_file src/r1-v/configs/zero2.yaml src/r1-v/src/open_r1/sft.py --config src/r1-v/configs/qwen2vl_sft_config.yaml 

🧪 Evaluation

Note

The models are evaluated in the zero-shot setting with an extraction-and-matching approach, which corresponds to the rule-based reward used in the training stage. We provide the following evaluation scripts for reproduction.

| Model | SuperCLEVR | GEOQA | RefCOCO/+/g AVG | MathVision | MathVerse | OlympiadBench | MMMU |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Qwen2.5-VL-3B-Instruct | 64.1 | 37.0 | 75.3 | 14.4 | 27.6 | 14.6 | 40.5 |
| Qwen2.5-VL-3B-Instruct-GRPO-text | 66.1 | 38.7 | 2.4 | 17.4 | 31.5 | 14.8 | 43.4 |
| Qwen2.5-VL-3B-Instruct-GRPO-vis | 93.4 | 54.2 | 86.1 | 19.1 | 40.0 | 15.5 | 47.9 |
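As a hedged illustration of the extraction-and-matching evaluation described in the note above (the released scripts may normalize answers differently), the final answer is pulled from the model output and compared to the reference:

import re

def extract_answer(output):
    # Take the content of the <answer> tag if present, otherwise the raw output.
    match = re.search(r"<answer>(.*?)</answer>", output, flags=re.DOTALL)
    return (match.group(1) if match else output).strip()

def is_match(output, reference):
    predicted, reference = extract_answer(output), reference.strip()
    try:
        # Numeric answers: compare with a small tolerance.
        return abs(float(predicted) - float(reference)) < 1e-3
    except ValueError:
        # Otherwise fall back to a case-insensitive string match.
        return predicted.lower() == reference.lower()

def accuracy(outputs, references):
    return sum(is_match(o, r) for o, r in zip(outputs, references)) / len(references)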

Counting: SuperCLEVR

cd ./src/eval/data
wget https://www.cs.jhu.edu/~zhuowan/zhuowan/SuperCLEVR/to_be_released/images.zip
unzip images.zip

# change image dir and the model path in the scripts
python ./src/eval/test_qwen2d5vl_counting_superclevr_5k.py

Geo Reasoning: GEOQA

We provide an example script to evaluate on the GEOQA test set (direct-answer form).

# prepare images for testing
cd ./src/eval/data
git lfs install
git clone https://huggingface.co/datasets/Luckyjhg/Geo170K
cd Geo170K
unzip images.zip


# change image dir and the model path in the scripts
python ./src/eval/test_qwen2d5vl_geoqa.py

# For faster inference with multiple GPUs, you can also use the multi-GPU script:
python ./src/eval/test_qwen2d5vl_geoqa_multigpu.py

Referring Expression Comprehension (REC): RefCOCO/+/g

  1. Download the COCO Train2014 images and unzip them; we refer to the image directory as <your_image_root>.
  2. Download the RefCOCO/+/g annotation files and unzip them.
# Remember to change the model path, image root, and annotation path in the script
python ./src/eval/test_qwen2d5vl_rec.py
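REC is conventionally scored by intersection-over-union: a prediction counts as correct when the IoU between the predicted and ground-truth boxes is at least 0.5. Below is a minimal sketch of that metric; the [x1, y1, x2, y2] box format is an assumption, and the repo's script may differ in details.

def box_iou(box_a, box_b):
    # Boxes are [x1, y1, x2, y2].
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def rec_accuracy(pred_boxes, gt_boxes, thresh=0.5):
    hits = sum(box_iou(p, g) >= thresh for p, g in zip(pred_boxes, gt_boxes))
    return hits / len(gt_boxes)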

Math: MathVision, MathVerse, and OlympiadBench

# Remember to change the model path, image root, and annotation path in the script
python ./src/eval/test_qwen2d5vl_mathvision_multigpu.py
python ./src/eval/test_qwen2d5vl_mathverse_multigpu.py
python ./src/eval/test_qwen2d5vl_olympiadbench_multigpu.py

General: MMMU

python ./src/eval/test_qwen2d5vl_mmmu.py

📋 TODO

  • Training with the combined data
  • Synthesize more high-quality and diverse multimodal data
  • Scale up to larger models and more general tasks

🤝 Acknowledgements

We sincerely thank DeepSeek, Open-R1, QwenVL, Open-R1-Multimodal, R1-V (our initial codebase), VLM-R1, CLEVR, SuperCLEVR, G-LLAVA, and RefCOCO for providing the open-source resources on which this project is built.

📚 Contributors and Citation

Contributors: Lingfeng Ming, Youwei Zhang, Yadong Li, Song Chen, Jianhua Xu, Zenan Zhou, Weipeng Chen.

If you find this work useful, please cite it as follows:

@misc{ming2025openvr1,
  author       = {Lingfeng Ming and Youwei Zhang and Yadong Li and Song Chen and Jianhua Xu and Zenan Zhou and Weipeng Chen},
  title        = {Ocean-R1: An Open and Generalizable Large Vision-Language Model enhanced by Reinforcement Learning},
  howpublished = {\url{https://github.com/fengzi258/Ocean-R1}},
  note         = {Accessed: 2025-03-10},
  year         = {2025}
}
