HeadInfer: Memory-Efficient LLM Inference by Head-wise Offloading


[paper]

Overview

HeadInfer is a memory-efficient inference framework for large language models (LLMs) that significantly reduces GPU memory consumption by leveraging a head-wise offloading strategy. Unlike traditional layer-wise KV cache offloading, HeadInfer dynamically manages attention heads, maintaining only a subset of the KV cache on the GPU while offloading the rest to CPU memory.

With HeadInfer, an 8B model can process up to 4 million tokens on a single consumer-grade GPU (e.g., RTX 4090 with 24GB VRAM), reducing GPU KV cache memory from 128GB to just 1GB without approximation.
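
For rough intuition about these figures, the snippet below sketches KV cache sizing per token. It assumes Llama-3-8B's grouped-query attention layout (32 layers, 8 KV heads, head dimension 128) and bf16 storage; these are illustrative assumptions and may not match the exact settings behind the numbers above.

# Back-of-the-envelope KV cache sizing (illustrative assumptions, not taken from the paper)
layers, kv_heads, head_dim, dtype_bytes = 32, 8, 128, 2        # Llama-3-8B GQA layout, bf16

bytes_per_token = 2 * layers * kv_heads * head_dim * dtype_bytes   # K and V for every layer
print(f"{bytes_per_token / 1024:.0f} KiB per token")               # 128 KiB

context_len = 1_000_000
print(f"~{bytes_per_token * context_len / 1e9:.0f} GB full cache") # ~131 GB, the scale cited above

# Head-wise offloading keeps only selected heads' KV resident on the GPU at any time;
# the rest stays in CPU memory, shrinking the on-GPU slice to roughly the gigabyte scale.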

Features

  • Head-wise KV cache offloading: Fine-grained memory optimization for long-context inference.
  • Supports million-token inference: Achieves up to 4M context length on consumer GPUs.
  • Asynchronous data transfer: Overlaps computation with offloading to minimize bottlenecks (see the sketch after this list).
  • Compatible with major LLMs: Works with LLaMA, Mistral, Qwen, and more.
  • Minimal changes to existing inference frameworks: Easy integration with Hugging Face models.
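
The toy example below illustrates the first and third bullets: per-head KV tensors live in pinned CPU memory, and a side CUDA stream prefetches the next head's KV into a double buffer while attention for the current head runs on the GPU. It is a simplified sketch with invented shapes and names, not HeadInfer's implementation, and it requires a CUDA-capable GPU.

import torch

# Illustrative head-wise offloading with asynchronous prefetch (not HeadInfer's code).
# The full per-head KV cache sits in pinned CPU memory; only two head-sized buffers live on the GPU.
num_heads, seq_len, head_dim = 8, 4096, 128
device = "cuda"

k_cpu = [torch.randn(seq_len, head_dim, pin_memory=True) for _ in range(num_heads)]
v_cpu = [torch.randn(seq_len, head_dim, pin_memory=True) for _ in range(num_heads)]
q = torch.randn(num_heads, seq_len, head_dim, device=device)

# Double buffer: while head h is computed from one slot, head h+1 is copied into the other.
buf = [(torch.empty(seq_len, head_dim, device=device),
        torch.empty(seq_len, head_dim, device=device)) for _ in range(2)]
copy_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()

def prefetch(h):
    k_dst, v_dst = buf[h % 2]
    # Do not overwrite this slot until the compute that last read it has finished.
    copy_stream.wait_stream(compute_stream)
    with torch.cuda.stream(copy_stream):
        k_dst.copy_(k_cpu[h], non_blocking=True)
        v_dst.copy_(v_cpu[h], non_blocking=True)

outputs = []
prefetch(0)
for h in range(num_heads):
    # Wait for head h's copy to land, then start prefetching head h+1 in the background.
    compute_stream.wait_stream(copy_stream)
    if h + 1 < num_heads:
        prefetch(h + 1)
    k_gpu, v_gpu = buf[h % 2]
    attn = torch.softmax(q[h] @ k_gpu.T / head_dim ** 0.5, dim=-1) @ v_gpu
    outputs.append(attn)

torch.cuda.synchronize()
print(torch.stack(outputs).shape)  # (num_heads, seq_len, head_dim)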

Installation

Training and Evaluation Environment

conda create -yn headinfer python=3.10
conda activate headinfer

conda install -y git
conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia

pip install transformers==4.45.2 accelerate 
pip install flash-attn --no-build-isolation
pip install git+https://github.com/wdlctc/headinfer.git@main
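
Optionally, verify that PyTorch sees the GPU and that the package imports cleanly (a quick sanity check; it uses the same modules as the usage example below):

import torch
from headinfer.cache import OffloadedCache   # same imports as the usage example
from headinfer.mp import mp_headinfer

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("headinfer imports OK")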

Usage

One-click-run with HeadInfer

python main.py

Running Inference with HeadInfer

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


from headinfer.cache import OffloadedCache
from headinfer.mp import mp_headinfer, mp_simulate_decode

model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to("cuda")

# Generate text with long context
input_text = "Once upon a time in a galaxy far, far away..."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")


with torch.inference_mode():

    # Patch the model so attention uses head-wise KV cache offloading
    mp_headinfer(model)
    past_key_values = OffloadedCache()

    # Prefill: a single forward pass that builds the offloaded KV cache
    outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
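
The example above only runs the prefill pass. A minimal greedy decoding loop can reuse the same offloaded cache; the continuation below is a sketch assuming the patched model keeps the standard Hugging Face forward signature (it is not taken from the repository), and the 64-token limit is arbitrary.

# Continuation (sketch): greedy decoding that reuses the offloaded cache from above
with torch.inference_mode():
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = input_ids

    for _ in range(64):                                  # generate up to 64 new tokens
        generated = torch.cat([generated, next_token], dim=-1)
        out = model(input_ids=next_token,                # feed only the newest token
                    past_key_values=past_key_values,
                    use_cache=True)
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if next_token.item() == tokenizer.eos_token_id:
            break

    print(tokenizer.decode(generated[0], skip_special_tokens=True))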

Citation

If you find HeadInfer useful for your research, please cite:

@article{luo2025headinfer,
  title={HeadInfer: Memory-Efficient LLM Inference by Head-wise Offloading},
  author={Luo, Cheng and Cai, Zefan and Sun, Hanshi and Xiao, Jinqi and Yuan, Bo and Xiao, Wen and Hu, Junjie and Zhao, Jiawei and Chen, Beidi and Anandkumar, Anima},
  journal={arXiv preprint arXiv:2502.12574},
  year={2025}
}
