A flexible answer tree search library featuring AB-MCTS, useful for (but not limited to) LLM inference-time scaling.
```python
import random

import treequest as tq

# Each node is associated with a user-definable `state`.
State = str


# 1. Define a function to be used for node generation.
def generate(parent_state: State | None) -> tuple[State, float]:
    """Generates new states and scores based on the parent state."""
    if parent_state is None:  # None represents the expansion from root.
        new_state = "Initial state"
    else:
        new_state = f"State after {parent_state}"
    score = random.random()  # A score for the new state; it should be normalized to the [0, 1] range.
    return new_state, score


# 2. Instantiate the algorithm and a search tree object.
algo = tq.ABMCTSA()
search_tree = algo.init_tree()

# 3. Run the search with a generation budget (10 in this case).
for _ in range(10):
    search_tree = algo.step(search_tree, {'Action A': generate})

# 4. Extract the best score and state.
best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best state: {best_state}, Score: {best_node_score}")
```
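The quick start expects every score to lie in the [0, 1] range. If your raw signal lives on another scale (for example, a count of passed unit tests), you have to map it into that range yourself. A minimal sketch, assuming min-max rescaling with known bounds is appropriate for your task; `normalize_score` is a hypothetical helper, not part of TreeQuest:

```python
def normalize_score(raw: float, low: float, high: float) -> float:
    """Min-max rescale `raw` from [low, high] into [0, 1], clamping outliers.

    Hypothetical helper: TreeQuest only needs the returned score to be in
    [0, 1]; how you get there is up to you.
    """
    if high <= low:
        raise ValueError("high must be greater than low")
    scaled = (raw - low) / (high - low)
    # Clamp so values outside the expected bounds still land in [0, 1].
    return min(1.0, max(0.0, scaled))


# Example: 7 of 10 unit tests passed.
print(normalize_score(7, 0, 10))  # 0.7
```

Clamping keeps a miscalibrated bound from producing an out-of-range score; if raw values regularly exceed your bounds, revisit the bounds rather than relying on the clamp.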
- Easy-to-use API with customizable node generation and node scoring logic.
- Support for AB-MCTS-A and AB-MCTS-M, as well as Multi-LLM AB-MCTS (see our paper for algorithm details).
- Checkpointing and resuming searches.
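The exact checkpointing interface is not shown here. Because each `algo.step` call returns the whole search tree as a value, one generic approach is to pickle that value between steps. This sketch assumes the tree object is picklable and holds all of the search state (neither is confirmed here); `ToyTree`, `save_checkpoint`, and `load_checkpoint` are hypothetical names, not TreeQuest APIs:

```python
import pickle
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class ToyTree:
    """Hypothetical stand-in for the object returned by `algo.step(...)`."""

    states: list[str] = field(default_factory=list)


def save_checkpoint(tree: object, path: Path) -> None:
    """Pickle `tree` to `path` via a temporary file, then swap it into place."""
    tmp = path.parent / (path.name + ".tmp")
    with tmp.open("wb") as f:
        pickle.dump(tree, f)
    # Rename only after the write completes, so a crash mid-write never
    # leaves a truncated checkpoint at `path`.
    tmp.replace(path)


def load_checkpoint(path: Path) -> object:
    """Load a checkpoint written by `save_checkpoint`.

    Only load files you wrote yourself: unpickling can execute arbitrary code.
    """
    with path.open("rb") as f:
        return pickle.load(f)
```

In a search loop you might call `save_checkpoint(search_tree, path)` every few iterations and later resume with `search_tree = load_checkpoint(path)`, continuing to call `algo.step` on the restored tree.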
First, install uv. Then you can install TreeQuest with the following command:
```shell
uv add "treequest[abmcts-m]" "https://github.com/SakanaAI/treequest.git@main"
```
Alternatively, you can use pip to install TreeQuest:
```shell
pip install "treequest[abmcts-m] @ git+https://github.com/SakanaAI/treequest.git"
```
You can use any object as a node state. You only need to define a generation function that takes the parent state as its argument and returns a `(state, score)` tuple:
```python
import dataclasses

import treequest as tq


@dataclasses.dataclass
class State:
    llm_answer: str
    score: float


def generate(parent_state: State | None) -> tuple[State, float]:
    """Generate a new node by calling an LLM."""
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)
    return state, state.score


def initial_generation() -> State:
    """
    Call LLM API to generate an initial answer.
    """
    ...


def refine_answer(llm_answer: str, score: float) -> State:
    """
    Call LLM API to refine an answer.
    """
    ...


algo = tq.ABMCTSM()
search_tree = algo.init_tree()
for i in range(20):
    search_tree = algo.step(search_tree, {'Action Label': generate})

    # Logging best node during the search.
    if (i + 1) % 5 == 0:
        best_interim_state, _ = tq.top_k(search_tree, algo, k=1)[0]
        print(f"Iteration {i+1}: Best state so far = {best_interim_state}")

best_state, _ = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best Answer: {best_state.llm_answer}, Best Score: {best_state.score}")
```
TreeQuest supports multiple action types. For example, you can provide several generation functions, each backed by a different LLM, as separate actions:
```python
from functools import partial

import treequest as tq


def generate(parent_state, llm_name: str):
    """
    Call an LLM API (via litellm, vLLM, etc.) using the model `llm_name`
    to generate a new node from `parent_state`.
    """
    ...  # Placeholder: set `new_state` and `new_score` from the LLM call.
    return new_state, new_score


llm_names = ["o4-mini", "gemini-2.5-pro"]
# Create a dict of actions, each backed by a different LLM.
# `partial` binds `llm_name`, leaving `parent_state` as the only argument.
generate_fns = {llm_name: partial(generate, llm_name=llm_name) for llm_name in llm_names}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)
```
The variation is not limited to LLM types: you can also vary prompts, actions, scoring logic, and so on across the functions in `generate_fns`.
AB-MCTS-A uses node aggregation for adaptive branching:
```python
import treequest as tq

# Instantiate the ABMCTS-A algorithm.
# `generate_fns` is the dict of actions from the multi-LLM example above.
ab_mcts_a = tq.ABMCTSA()
search_tree = ab_mcts_a.init_tree()
for _ in range(50):
    search_tree = ab_mcts_a.step(search_tree, generate_fns)
```
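Roughly speaking, AB-MCTS decides at each node whether to go wider (generate a new child, "GEN") or deeper (continue into the existing children, which AB-MCTS-A aggregates into a single "CONT" option), and makes that choice by Thompson sampling. The toy below is not TreeQuest's implementation and does not use the paper's priors. It only illustrates Thompson sampling between those two options, using Beta posteriors that treat each [0, 1] score as fractional success evidence (a common heuristic for bounded rewards). `beta_posterior` and `choose_gen_or_cont` are hypothetical names:

```python
import random


def beta_posterior(scores: list[float]) -> tuple[float, float]:
    # Uniform Beta(1, 1) prior; each score adds `s` to the success count and
    # `1 - s` to the failure count.
    alpha = 1.0 + sum(scores)
    beta = 1.0 + sum(1.0 - s for s in scores)
    return alpha, beta


def choose_gen_or_cont(
    gen_scores: list[float], cont_scores: list[float], rng: random.Random
) -> str:
    """Thompson sampling: draw once from each posterior and pick the larger draw."""
    gen_draw = rng.betavariate(*beta_posterior(gen_scores))
    cont_draw = rng.betavariate(*beta_posterior(cont_scores))
    return "GEN" if gen_draw > cont_draw else "CONT"


rng = random.Random(0)
# Freshly generated children have scored poorly while existing ones score well,
# so the sampler should almost always choose to go deeper.
picks = [choose_gen_or_cont([0.1] * 20, [0.9] * 20, rng) for _ in range(1000)]
print(picks.count("CONT"))
```

With no evidence on either side, both posteriors are Beta(1, 1) and the two options are chosen about equally often. Exploration falls off naturally as the evidence accumulates.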
AB-MCTS-M leverages PyMC's mixed modeling capabilities:
```python
import treequest as tq

# Instantiate the ABMCTS-M algorithm.
ab_mcts_m = tq.ABMCTSM()
search_tree = ab_mcts_m.init_tree()
for _ in range(30):
    search_tree = ab_mcts_m.step(search_tree, generate_fns)
```
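Intuitively, a mixed model lets sibling nodes share statistical strength: a child with only one lucky score is pulled toward what its siblings suggest instead of being trusted outright. The sketch below is not what AB-MCTS-M runs. It replaces the PyMC model with a closed-form normal-normal estimate under fixed, made-up variances, only to show the partial-pooling effect. `partially_pooled_means` and its default variances are hypothetical:

```python
from statistics import fmean


def partially_pooled_means(
    child_scores: dict[str, list[float]],
    noise_var: float = 0.04,  # Assumed variance of scores within one child (sigma^2).
    between_var: float = 0.01,  # Assumed variance of true means across children (tau^2).
) -> dict[str, float]:
    """Shrink each child's mean score toward the grand mean of all observed scores.

    Normal-normal model with known variances: a child's estimate is the
    precision-weighted average of its sample mean and the grand mean.
    Children with no scores yet fall back to the grand mean.
    Assumes at least one child has at least one score.
    """
    grand_mean = fmean(s for scores in child_scores.values() for s in scores)
    prior_precision = 1.0 / between_var
    pooled = {}
    for child, scores in child_scores.items():
        if not scores:
            pooled[child] = grand_mean
            continue
        data_precision = len(scores) / noise_var
        pooled[child] = (data_precision * fmean(scores) + prior_precision * grand_mean) / (
            data_precision + prior_precision
        )
    return pooled


# One lucky high score is shrunk hard; a well-sampled child barely moves.
print(partially_pooled_means({"a": [0.9], "b": [0.5, 0.4, 0.6, 0.5]}))  # roughly {'a': 0.644, 'b': 0.54}
```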
NOTE: To run AB-MCTS-M, you need to install the extra dependencies with the `treequest[abmcts-m]` option.
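If your code should degrade gracefully when that extra is missing, you can probe for PyMC before choosing an algorithm. This assumes the `abmcts-m` extra is what installs the `pymc` package (AB-MCTS-M builds on PyMC). `pick_algorithm_name` is a hypothetical helper that returns the name of the TreeQuest class to instantiate:

```python
import importlib.util


def pick_algorithm_name() -> str:
    """Return "ABMCTSM" if PyMC is importable, else fall back to "ABMCTSA"."""
    # find_spec locates a top-level module without importing it.
    has_pymc = importlib.util.find_spec("pymc") is not None
    return "ABMCTSM" if has_pymc else "ABMCTSA"
```

Usage would look like `algo = getattr(tq, pick_algorithm_name())()`, after which the search loop is unchanged.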
Requirements:

- Python 3.11+
Contributions are welcome! Please see CONTRIBUTING.md for development tips.
```bibtex
@article{inoue2025wider,
  title={Wider or Deeper? Scaling LLM Inference-Time Compute with Adaptive Branching Tree Search},
  author={Inoue, Yuichi and Misaki, Kou and Imajuku, Yuki and Kuroki, So and Nakamura, Taishi and Akiba, Takuya},
  journal={arXiv preprint arXiv:2503.04412},
  year={2025}
}
```