
# 🚀 TransMLA: Migrating GQA Models to MLA with Full DeepSeek Compatibility and Speedup

Modern large language models often face communication bottlenecks on current hardware rather than computational limitations. Multi-head Latent Attention (MLA) addresses this by compressing the key-value (KV) cache with low-rank matrices, while the Absorb operation prevents the KV cache from reverting to its original size, significantly boosting both training and inference speed.
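To make the compression concrete, here is a minimal PyTorch sketch of the low-rank KV-cache idea. The dimensions follow the DeepSeek-style ranks used later in this README, and the names `W_dkv`, `W_uk`, `W_uv` are illustrative assumptions, not TransMLA's actual code:

```python
import torch

# Illustrative dimensions (assumptions for this sketch, not a specific model).
d_model, n_heads, head_dim, kv_lora_rank = 4096, 32, 128, 512

W_dkv = torch.randn(d_model, kv_lora_rank)             # shared down-projection
W_uk  = torch.randn(kv_lora_rank, n_heads * head_dim)  # key up-projection
W_uv  = torch.randn(kv_lora_rank, n_heads * head_dim)  # value up-projection

x = torch.randn(1, 16, d_model)  # (batch, seq, hidden)

# Only this low-rank latent is cached: 512 dims per token instead of the
# 2 * n_heads * head_dim = 8192 dims a full MHA/GQA KV cache would store.
c_kv = x @ W_dkv

# Keys/values can be reconstructed on the fly; with the Absorb operation,
# W_uk and W_uv are instead folded into the query/output projections, so the
# cache never expands back to its original size.
k = (c_kv @ W_uk).view(1, 16, n_heads, head_dim)
v = (c_kv @ W_uv).view(1, 16, n_heads, head_dim)
```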

Despite the success of DeepSeek V2/V3/R1, most model vendors have heavily invested in optimizing GQA-based models and therefore lack strong incentives to retrain MLA-based models from scratch. In this paper, we introduce TransMLA, a framework that seamlessly converts any GQA-based pre-trained model (e.g., LLaMA, Qwen, Mixtral) into an MLA-based model.

## 📰 News

- [2025.05.29] A new version of the technical report is released: https://arxiv.org/abs/2502.07864.
- [2025.04.28] Released TransMLA v3, which successfully applies PCA across RoPE and reduces the KV cache.
- [2025.02.16] Released the second version of the TransMLA model and usage code, compatible with RoPE and supporting the Absorb operation.
- [2025.02.13] The TransMLA technical report is publicly available: https://huggingface.co/papers/2502.07864.
- [2025.01.02] Released the first version of the TransMLA model code, with usage code for equivalently converting Qwen2.5's and LLaMA-3's GQA to MLA.

## 🛠 Installation

```bash
git clone https://github.com/fxmeng/TransMLA.git
cd TransMLA
conda create -n transmla python=3.12.8
conda activate transmla
pip install -r requirements.txt
```

## ⚡ Quick Start

1. Convert MHA / GQA models (e.g., Qwen2.5-7B-Instruct) into DeepSeek-MLA:

   ```bash
   bash scripts/convert/qwen2.5-7B-Instruct.sh
   ```

2. Have fun playing with the converted models!

   ```python
   # Load the converted checkpoint with `transformers.AutoModelForCausalLM`.
   import torch
   from transformers import AutoModelForCausalLM

   model = AutoModelForCausalLM.from_pretrained(
       "outputs/qwen2_5-7B-Instruct-deepseek",
       torch_dtype=torch.bfloat16,  # match the converter's --dtype bf16
       trust_remote_code=True,
   )
   ```
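Once loaded, the converted model behaves like any Transformers causal LM. A minimal smoke test, assuming the converter saved the tokenizer alongside the weights (the prompt is arbitrary):

```python
from transformers import AutoTokenizer

# Assumption: the tokenizer is saved next to the converted model weights.
tokenizer = AutoTokenizer.from_pretrained("outputs/qwen2_5-7B-Instruct-deepseek")

inputs = tokenizer("The key idea behind multi-head latent attention is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```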

## 🔧 Advanced Usage (converter.py)

The `converter.py` script gives you fine-grained control over RoPE removal and the low-rank QKV projection toward DeepSeek-MLA. It supports:

- Auto-search for the optimal `freqfold` that minimizes perplexity (PPL); see the sketch after this list.
- Automatic computation of `collapse` based on `head_dim / qk_mqa_dim`.
- Evaluation of the original, RoPE-removed, and final MLA models.
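Conceptually, the auto-search evaluates each candidate `freqfold` on the calibration data and keeps the one with the lowest perplexity. A hypothetical sketch of that loop (the candidate set and the `convert_fn` / `eval_ppl_fn` callables are placeholders, not the actual converter.py internals):

```python
def search_freqfold(convert_fn, eval_ppl_fn, candidates=(1, 2, 4, 8)):
    """Hypothetical sketch of `--freqfold auto`: try each candidate factor,
    keep the one whose converted model has the lowest calibration PPL."""
    best_freqfold, best_ppl = None, float("inf")
    for freqfold in candidates:
        ppl = eval_ppl_fn(convert_fn(freqfold))  # PPL of the model converted with this factor
        if ppl < best_ppl:
            best_freqfold, best_ppl = freqfold, ppl
    return best_freqfold, best_ppl
```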

### ✅ Example Command

```bash
python transmla/converter.py \
    --model-path meta-llama/Llama-2-7b-hf \
    --save-path ./outputs/llama2-7b-deepseek \
    --dtype bf16 \
    --device auto \
    --cal-dataset wikitext2 \
    --cal-nsamples 128 \
    --cal-max-seqlen 256 \
    --cal-batch-size 8 \
    --ppl-eval-batch-size 4 \
    --freqfold auto \
    --collapse auto \
    --qk-mqa-dim 64 \
    --q-lora-rank 512 \
    --kv-lora-rank 512
```

## 📘 Argument Details

| Argument | Description |
|---|---|
| `--model-path` | Path to the base model (e.g., from the Hugging Face Hub). |
| `--save-path` | Output path for the converted model and tokenizer. |
| `--cal-dataset` | Calibration dataset: `wikitext2`, `ptb`, `c4`, or `alpaca`. |
| `--cal-nsamples`, `--cal-max-seqlen`, `--cal-batch-size` | Number, maximum sequence length, and batch size of the calibration samples. |
| `--freqfold` | RoPE frequency-folding factor, or `auto` to search for the best value. Note: the automatic `freqfold` search currently supports single-GPU setups only; set the device explicitly, e.g. `cuda:0`. |
| `--collapse` | Collapse factor for RoPE. Use `auto` to compute it as `head_dim // qk_mqa_dim`. The collapse factor reduces the per-head dimension of the RoPE'd KV cache from `head_dim` to `head_dim // collapse`. |
| `--qk-mqa-dim` | Target dimension for the decoupled RoPE. |
| `--q-lora-rank` | Inner dimension of the query low-rank decomposition, or `None` to disable it. |
| `--kv-lora-rank` | Inner dimension of the joint key/value low-rank decomposition. |
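As a worked example of the `auto` settings above, with a Llama-2-style `head_dim = 128` (an assumption used purely for illustration) and the recommended flag values:

```python
# Worked example of `--collapse auto` with the recommended flag values.
head_dim = 128        # per-head dim of the source model (assumed for illustration)
qk_mqa_dim = 64       # --qk-mqa-dim 64
kv_lora_rank = 512    # --kv-lora-rank 512

collapse = head_dim // qk_mqa_dim    # auto => 128 // 64 = 2
roped_dim = head_dim // collapse     # RoPE'd part of the KV cache: 128 -> 64

# Per-token KV cache after conversion: latent part + decoupled-RoPE part.
print(collapse, roped_dim, kv_lora_rank + roped_dim)  # -> 2 64 576
```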

## 🧠 Tips

- Set `--freqfold auto` and `--collapse auto` to simplify configuration. The script will automatically search for the best `freqfold` factor based on PPL results.
- We recommend setting `--qk-mqa-dim` to 64 and `--kv-lora-rank` to 512 to satisfy FlashMLA's requirements on H100.

## 🐒 Model Zoo

| Model Family | Model | kv-lora-rank + qk-mqa-dim | freqfold | Original PPL | Partial-RoPE PPL | MLA PPL |
|---|---|---|---|---|---|---|
| Llama2 | Llama-2-7B | 512 + 64 | 8 | 5.4735 | 18.6373 | 41.6135 |
| | | 448 + 128 | 8 | | 8.9903 | 25.7731 |
| Llama3 | Llama-3-8B | 512 + 64 | 4 | 6.1371 | 12.0550 | 25.8047 |
| | | 448 + 128 | 4 | | 8.3997 | 18.3500 |
| | Llama-3.2-1B | 512 + 64 | 4 | 9.7531 | 16.3391 | 16.1404 |
| Qwen2 | Qwen2.5-7B-Instruct | 512 + 64 | 4 | 8.3541 | 8.8902 | 10.0082 |
| | | 448 + 128 | 4 | | 8.0734 | 9.1957 |
| | Qwen2.5-72B-Instruct | 512 + 64 | 4 | 4.2687 | 4.9650 | 7.3850 |
| | | 448 + 128 | 4 | | 4.6931 | 7.7172 |
| Gemma2 | gemma-2-9b-it | 512 + 64 | 8 | 10.1612 | 11.4207 | 21.6260 |
| | | 448 + 128 | 8 | | 10.9948 | 22.0038 |
| | | 320 + 256 | 4 | | 10.7075 | 32.4387 |
| Mistral | Mistral-7B-v0.3 | 512 + 64 | 8 | 5.3178 | 7.4697 | 9.5830 |
| | | 448 + 128 | 8 | | 5.5915 | 7.0251 |
| Mixtral | Mixtral-8x7B-v0.1 | 512 + 64 | 8 | 3.8422 | 5.6310 | 7.5179 |
| | | 448 + 128 | 4 | | 4.1407 | 5.8374 |

## 📋 To-Do

- Publish the technical report for the new version, detailing how TransMLA stays compatible with RoPE and supports the Absorb operation.
- Compress the dimensions of the KV cache to improve inference speed.
- Add support for vLLM to improve inference speed.
- Support FlashMLA.
- Extend support to additional models (e.g., LLaMA, Mistral, Gemma2).
- Support GTA & GLA.
- Release checkpoints.
- Fine-tune on R1 distillation datasets.

## 📚 Citation

```bibtex
@article{meng2025transmla,
  title={TransMLA: Multi-head Latent Attention Is All You Need},
  author={Meng, Fanxu and Tang, Pingzhi and Yao, Zengwei and Zhang, Muhan},
  journal={arXiv preprint arXiv:2502.07864},
  year={2025}
}
```

## ⭐ Star History

(Star History chart)
