
📋 We are still in the process of cleaning up this repo. Use it with caution, and please report errors and questions to us.

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

This repository is the official implementation of our paper:

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

Ekin Akyürek, Mehul Damani, Linlu Qiu, Han Guo, Yoon Kim, Jacob Andreas

Requirements

To install the requirements, start from a fresh conda environment and install the following with pip:

git clone --recursive https://github.com/ekinakyurek/marc
cd marc/
# The TTT pipeline uses a fork of the torchtune library,
# which needs to be installed first.
conda create -n arc python=3.10
conda activate arc
# Install torchtune from our fork as an editable package.
# The editable install is required because we use some files under
# third_party/torchtune/recipes/ that a plain pip install does not ship.
cd third_party/torchtune
pip install -e .
# Install the other libraries required by torchtune.
pip install torch torchao --pre --upgrade --index-url https://download.pytorch.org/whl/nightly/cu121

# The remaining requirements can then be installed from the repo root:
cd ../..
pip install -r requirements.txt
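
As a quick check that the environment is set up and the editable torchtune install is the one being picked up (plain imports, nothing repo-specific):

# Verify that torch and the forked torchtune import cleanly.
python -c "import torch, torchtune; print(torch.__version__)"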

📋 You need to download the ARC dataset from Kaggle: https://www.kaggle.com/competitions/arc-prize-2024/data
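If you have the Kaggle CLI installed and configured (and have accepted the competition rules on the website), one way to fetch the dataset is sketched below; the data/ target directory is just an example, place the files wherever you point data_file at later:

# Assumes the kaggle CLI is set up with ~/.kaggle/kaggle.json credentials.
kaggle competitions download -c arc-prize-2024 -p data/
unzip data/arc-prize-2024.zip -d data/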

📋 You can find our finetuned models and TTT checkpoints at the following links:

Test Time Training

To train the model(s) in the paper, run this command:

# Specify data path
data_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json
# Specify finetuned path
base_checkpoint_dir=/path/to/finetuned/model/folder/
# Specify where TTT adapters should be saved
ttt_folder=/path/to/ttt/folder
mkdir -p $ttt_folder


# You need to provide an initial config file that is compatible with
# torchtune configs; one is provided in this repo.
lora_config_file=configs/ttt/8B_lora_single_device.yaml
# lora_config_file=configs/ttt/8.1B_lora_single_device.yaml # for barc
# You can override some of its variables:
batch_size=2
epochs=2
learning_rate=5e-5
lora_rank=128
lora_alpha=16.0
lora_to_output=False # doesn't apply to Llama3.2 models for now
# You can specify how many tasks to train on.
num_tasks=100

# Run the main script
python test_time_train.py --lora_config=$lora_config_file \
    --base_checkpoint_dir $base_checkpoint_dir \
    --experiment_folder $ttt_folder \
    --data_file $data_file \
    --batch_size $batch_size \
    --epochs $epochs \
    --num_tasks=${num_tasks} \
    --lora_rank=$lora_rank \
    --lora_alpha=$lora_alpha \
    --lora_to_output=$lora_to_output \
    --new_format # use --barc_format for barc

📋 If you are using BARC checkpoints with unmask_outputs=True in the program arguments, you need to uncomment the corresponding lines in our torchtune fork.

📋 TTT training will save adapter checkpoints under the ttt_folder you specified above.
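As a quick sanity check, you can list what was written. The layout assumed below (per-task subfolders of adapter weight files) is an assumption based on torchtune's LoRA recipes, so adjust to what you actually see:

# Per-task adapter layout is an assumption, not a guarantee.
ls $ttt_folder
find $ttt_folder -name '*.pt' | head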

Inference

To run inference with TTT, use predict.py:

# Tell predict.py where predictions and submissions should be saved
tti_folder=/path/to/tti/folder
mkdir -p $tti_folder
# Point to your finetuned (referred to as base) and TTT checkpoints
base_checkpoint_dir=/path/to/finetuned/model/folder/
ttt_folder=/path/to/ttt/folder

# data_file and lora_rank are the same as in the TTT step above
# Sampling parameters (example values, adjust as needed)
temperature=0
n_sample=1

# If a solution file is given, predict.py will also evaluate the model
solution_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions_selected.json

# --include_n=1 means we use leave-1-out prompts
python predict.py --experiment_folder=$tti_folder \
    --pretrained_checkpoint $base_checkpoint_dir \
    --lora_checkpoints_folder $ttt_folder \
    --temperature $temperature \
    --n_sample $n_sample \
    --data_file $data_file \
    --solution_file $solution_file \
    --max_lora_rank=$lora_rank \
    --include_n=1 \
    --new_format
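
When it finishes, predictions and the submission file are written under $tti_folder. A quick way to eyeball the output is shown below; the submission.json filename is an assumption, adjust it to whatever predict.py actually writes:

ls $tti_folder
# submission.json is an assumed filename; check the folder listing above.
python -m json.tool $tti_folder/submission.json | head -n 20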

📋 For Llama-3 and Llama-3.2 we used different versions of vLLM, and the latter is not compatible with the torchtune version we use. So we give separate vLLM setup instructions for Llama-3 and for Llama-3.2 for reproducibility, using separate conda environments for the inference pipeline.

# For Llama3 and 3.1 models
conda create -n vllm python=3.10
conda activate vllm
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/torchtunecompat
# For Llama3.2 models
conda create -n vllmnew python=3.10
conda activate vllmnew
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/newvllm
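
Before running predict.py, activate the environment that matches your model family, for example:

conda activate vllm      # for Llama-3 / 3.1 checkpoints
# conda activate vllmnew # for Llama-3.2 checkpoints
# then run the predict.py command above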
