prime-rl - decentralized RL training at scale

prime-rl is a codebase for decentralized RL training at scale.

Install

Quick install

curl -sSL https://raw.githubusercontent.com/PrimeIntellect-ai/prime-rl/main/install.sh | bash

Dev

  1. Clone:
git clone git@github.com:PrimeIntellect-ai/prime-rl.git
cd prime-rl
  2. Install uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
  3. Set up the environment (will default to Python 3.10):
uv sync && uv sync --extra fa

You can check that flash_attn is installed correctly by running uv run python -c "import flash_attn" and ensuring that no error is thrown.
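
For example:

# Should print nothing and exit with status 0 if flash_attn is installed correctly
uv run python -c "import flash_attn"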

  4. Install the pre-commit hooks:
uv run pre-commit install
  5. Run the tests:
uv run pytest
  6. Debug run:

Training

uv run torchrun --nproc_per_node=2 src/zeroband/train.py @ configs/training/debug.toml

Inference

uv run python src/zeroband/infer.py @ configs/inference/debug.toml

Simple Math Run

This debug run trains deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B on the justus27/math-hendrycks-genesys-format dataset using separate inference and training processes. Depending on the number of available GPUs, you have to adjust the number of samples generated by the inference workers to match the batch size of the training process.
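
The invariant, implied by the commands below, is that the inference data-parallel degree (--dp) times the per-worker batch size (--batch-size) must equal the training batch size, here 64:

dp × batch-size = 1 × 64 = 2 × 32 = 64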

If you have 2 GPUs, run the following commands:

# Start inference worker
export CUDA_VISIBLE_DEVICES=0
export VLLM_WORKER_MULTIPROC_METHOD=spawn
uv run python src/zeroband/infer.py @ configs/inference/simple_math.toml --dp 1 --batch-size 64
# Start trainer
ulimit -n 4096
export CUDA_VISIBLE_DEVICES=1
uv run torchrun src/zeroband/train.py @ configs/training/simple_math.toml

If you have 4 GPUs, run the following commands:

# Start inference workers
export CUDA_VISIBLE_DEVICES=0,1
export VLLM_WORKER_MULTIPROC_METHOD=spawn
uv run python src/zeroband/infer.py @ configs/inference/simple_math.toml --dp 2 --batch-size 32
# Start trainer
ulimit -n 4096
export CUDA_VISIBLE_DEVICES=2
uv run torchrun src/zeroband/train.py @ configs/training/simple_math.toml

If you have 8 GPUs, run the following commands:

# Start inference workers
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
export VLLM_WORKER_MULTIPROC_METHOD=spawn
uv run python src/zeroband/infer.py @ configs/inference/simple_math.toml
# Start trainer
ulimit -n 4096
export CUDA_VISIBLE_DEVICES=6,7
uv run torchrun --nproc_per_node=2 src/zeroband/train.py @ configs/training/simple_math.toml --data.num_workers 2

2K sequence length run

In two separate terminals, run the following commands. First, start the inference workers:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
export VLLM_WORKER_MULTIPROC_METHOD=spawn
uv run python src/zeroband/infer.py @ configs/inference/deepscaler.toml

Then start the trainer:

ulimit -n 4096
export CUDA_VISIBLE_DEVICES=6,7
uv run torchrun --nproc_per_node=2 src/zeroband/train.py @ configs/training/deepscaler.toml

If you are running on an H100 node instead of an H200, you should add --train.micro_bs 4.
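
For example, the trainer command above becomes:

uv run torchrun --nproc_per_node=2 src/zeroband/train.py @ configs/training/deepscaler.toml --train.micro_bs 4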

Distributed inference

Inference can run in multi-node, multi-GPU setups with DP, TP, and PP, as well as sensible combinations of these. Below are examples of how to run inference with different parallelization strategies.
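
As a rule of thumb, consistent with the examples below, the number of GPUs a configuration requires is the product of the parallel degrees:

GPUs required = DP × TP × PP (e.g. DP=2, TP=2, PP=1 requires 2 × 2 × 1 = 4 GPUs)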

Single Node (DP=1, TP=1, PP=1, requires 1 GPU)

PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=0 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name Qwen/Qwen3-14B

Only TP (TP=2, PP=1, DP=1, requires 2 GPUs)

PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name Qwen/Qwen3-14B \
	--tp 2

Only DP (DP=2, TP=1, PP=1, requires 2 GPUs)

PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name Qwen/Qwen3-14B \
	--dp 2

Only PP (DP=1, TP=1, PP=2, requires 2 GPUs)

# Node 1
PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=0 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name mikasenghaas/Qwen3-14B-0.2 \
	--pp.rank 0 \
	--pp.world-size 2 \
	--pp.iroh-seed 0 \
	--pp.iroh-peer-id ff87a0b0a3c7c0ce827e9cada5ff79e75a44a0633bfcb5b50f99307ddb26b337 \
	--seed 69
# Node 2
PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name mikasenghaas/Qwen3-14B-1.2 \
	--pp.rank 1 \
	--pp.world-size 2 \
	--pp.iroh-seed 1 \
	--pp.iroh-peer-id ee1aa49a4459dfe813a3cf6eb882041230c7b2558469de81f87c9bf23bf10a03 \
	--seed 69

Note: Setting the seed here is important to ensure model shards work on the same data shards.

DP+TP (DP=2, TP=2, PP=1, requires 4 GPUs)

PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=0,1,2,3 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name Qwen/Qwen3-14B \
	--dp 2 \
	--tp auto

PP+TP (DP=1, TP=2, PP=2, requires 4 GPUs)

# Node 1
PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name mikasenghaas/Qwen3-14B-0.2 \
	--tp auto \
	--pp.rank 0 \
	--pp.world-size 2 \
	--pp.iroh-seed 0 \
	--pp.iroh-peer-id ff87a0b0a3c7c0ce827e9cada5ff79e75a44a0633bfcb5b50f99307ddb26b337 \
	--seed 69
# Node 2
PRIME_LOG_LEVEL=DEBUG VLLM_CONFIGURE_LOGGING=0 CUDA_VISIBLE_DEVICES=2,3 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model-name mikasenghaas/Qwen3-14B-1.2 \
	--tp auto \
	--pp.rank 1 \
	--pp.world-size 2 \
	--pp.iroh-seed 1 \
	--pp.iroh-peer-id ee1aa49a4459dfe813a3cf6eb882041230c7b2558469de81f87c9bf23bf10a03 \
	--seed 69

We don't support DP+PP, so that configuration will raise an exception.
