8000 GitHub - Parallel-Reasoning/APR: Code for Paper: Learning Adaptive Parallel Reasoning with Language Models
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Parallel-Reasoning/APR

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

12 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Learning Adaptive Parallel Reasoning
with Language Models

Jiayi Pan*, Xiuyu Li*, Long Lian*, Charlie Victor Snell, Yifei Zhou,
Adam Yala, Trevor Darrell, Kurt Keutzer, Alane Suhr

UC Berkeley and UCSF    * Equal Contribution

📃 Paper💻 Code🤗 Data & Models

APR

TL;DR: We present Adaptive Parallel Reasoning (APR), a novel framework that enables language models to learn to orchestrate both serialized and parallel computations. APR trains language models to use spawn() and join() operations through end-to-end supervised training and reinforcement learning, allowing models to dynamically orchestrate their own computational workflows. APR efficiently distributes compute, reduces latency, overcomes context window limits, and achieves state‑of‑the‑art performance on complex reasoning tasks (e.g., 83.4% vs. 60.0% accuracy at 4K context on Countdown).

The full code will be released soon!

Data Preparation

Supervised Training

We use TPU-v3-128 for supervised training with a codebase building upon JAX_llama.

Please refer to the instructions for more details.

Reinforcement Learning

We present TinyRL, a simple implementation of the GRPO training framework for our experiments. TinyRL is a lightweight yet performant reinforcement learning library designed to be both easy to use and extend. It integrates with SGLang for efficient rollout. Given the small size of the model we’re training, we haven’t implemented model parallelism, so it runs on two GPUs—one for training and one for rollout

It supports asynchronous, multi-turn, multi-agent rollouts through a general rollout_fun interface, with the minimal assumption that your rollout mechanism relies on calling an OpenAI-compatible API endpoint.

def rollout_fun(server_url, prefix_list, bos_token, temperature=0.5, max_workers=32):
  pass

Please refer to the instructions for more details.

Evaluation

Important

For evaluation, SGLang needs to be patched. Remove this check in python/sglang/srt/managers/tokenizer_manager.py in your local SGLang repo:

# if (
#     obj.sampling_params.get("max_new_tokens") is not None
#     and obj.sampling_params.get("max_new_tokens") + input_token_num
#     >= self.context_len
# ):
#     raise ValueError(
#         f"Requested token count exceeds the model's maximum context length "
#         f"of {self.context_len} tokens. You requested a total of "
#         f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
#         f"tokens: {input_token_num} tokens from the input messages and "
#         f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
#         f"completion. Please reduce the number of tokens in the input "
#         f"messages or the completion to fit within the limit."
#     )

This file is located at tokenizer_manager.py

Note

sgl-project/sglang#3721 introduces an --allow-auto-truncate option that makes this patch unnecessary. Once merged, you can use that directly.

SoS+

The following command evaluates the SoS+ model on the validation set.

python -m src.eval.eval_sosp --ckpt <ckpt>  --temperature <temperature> --batch_size 256 --gens 1 --output_dir <output_dir> --num_gpus 8 --n_samples <n_samples> --budget <budget>

Where <n_samples> is the number of Best-of-N samples in inference, and <budget> is the budget for conditional generation (leave it empty if not using conditioned models). For instance, the following command evaluates the SoS+ model with 8 samples using a unconditioned checkpoint.

python -m src.eval.eval_sosp --ckpt Parallel-Reasoning/llama-sosp --temperature 1.0 --batch_size 256 --gens 1 --output_dir results/llama-sosp/ --num_gpus 8 --n_samples 8

APR

First, you need to start the SGLang server for the target model. For instance:

python -m sglang.launch_server  --served-model-name model --model-path Parallel-Reasoning/llama-apr_cond10_grpo --port 2346 --dp-size 8

Then the following command evaluates the APR model on the validation set.

python -m src.eval.eval_apr --model_name llama-apr_cond10_grpo --ckpt Parallel-Reasoning/llama-apr_cond10_grpo --temperature 1.0 --budget 10 --use_subcall_cond

which evaluates the APR model with a budget of 10 child threads and uses child thread count conditioning. Do not use --budget and --use_subcall_cond for unconditioned models.

Citation

If you find this work useful in your research, please consider citing:

@article{pan2025learning,
  title   = {Learning Adaptive Parallel Reasoning with Language Models},
  author  = {Jiayi Pan and Xiuyu Li and Long Lian and Charlie Snell and Yifei Zhou and Adam Yala and Trevor Darrell and Kurt Keutzer and Alane Suhr},
  year    = {2025},
  journal = {arXiv preprint arXiv: 2504.15466}
}

About

Code for Paper: Learning Adaptive Parallel Reasoning with Language Models

Resources

Stars

Watchers

Forks

Contributors 3

  •  
  •  
  •  
0