[FFPA] Refactor FFPA-L1 Part-1✔️ by DefTruth · Pull Request #10 · xlite-dev/ffpa-attn · GitHub

[FFPA] Refactor FFPA-L1 Part-1✔️ #10


Merged
merged 13 commits on Jan 8, 2025
Changes from all commits
Empty file modified .dev/clear.sh
100644 → 100755
Empty file.
5 changes: 4 additions & 1 deletion .dev/commit-prepare.sh
100644 → 100755
@@ -1,8 +1,11 @@
path=$(cd `dirname $0`; pwd)
cd $path

# cpp & python format lint
sudo apt-get update
sudo apt-get install clang-format -y
pip install pre-commit
pip install yapf
pip install cpplint
pre-commit install -c ./.dev/.pre-commit-config.yaml
pre-commit install -c ./.dev/.pre-commit-config.yaml # only lint for python
# pre-commit install -c ./.dev/.pre-commit-config-cpp.yaml # both python + cpp
4 changes: 4 additions & 0 deletions .dev/init_dev.sh
@@ -0,0 +1,4 @@
export ENABLE_FFPA_ALL_STAGES=0
export ENABLE_FFPA_ALL_HEADDIM=0
export ENABLE_FFPA_AMPERE=0
export ENABLE_FFPA_HOPPER=0
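
For reference, these new toggles are the renamed ENABLE_* flags consumed by env.py later in this diff; a minimal sketch of the read pattern, using the same defaults as the refactored file:

```python
# Sketch of how env.py (updated below in this PR) reads the renamed toggles:
# each ENABLE_FFPA_* variable is a "0"/"1" string coerced to bool.
import os

ENABLE_FFPA_ALL_STAGES = bool(int(os.environ.get("ENABLE_FFPA_ALL_STAGES", 0)))
ENABLE_FFPA_ALL_HEADDIM = bool(int(os.environ.get("ENABLE_FFPA_ALL_HEADDIM", 0)))
ENABLE_FFPA_AMPERE = bool(int(os.environ.get("ENABLE_FFPA_AMPERE", 1)))
ENABLE_FFPA_HOPPER = bool(int(os.environ.get("ENABLE_FFPA_HOPPER", 0)))
```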
7 changes: 4 additions & 3 deletions .dev/install.sh
100644 → 100755
@@ -1,4 +1,5 @@
rm -rf $(find . -name __pycache__)
python3 setup.py bdist_wheel && cd dist # build pyffpa from sources
python3 -m pip install pyffpa-*-linux_x86_64.whl # pip uninstall pyffpa -y
cd .. && rm -rf build *.egg-info
python3 setup.py bdist_wheel && cd dist # build cuffpa-py from sources
python3 -m pip install cuffpa_py-*-linux_x86_64.whl # pip uninstall cuffpa-py -y
cd .. && rm -rf build *.egg-info
rm -rf $(find . -name __pycache__)
2 changes: 1 addition & 1 deletion .dev/uninstall.sh
100644 → 100755
@@ -1 +1 @@
python3 -m pip uninstall pyffpa -y
python3 -m pip uninstall cuffpa-py -y
26 changes: 13 additions & 13 deletions README.md
@@ -11,7 +11,7 @@
<img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
</div>

🤖 [WIP] **FFPA**: Yet antother **Faster Flash Prefill Attention** with **O(1) SRAM complexity** & **O(d/4) or O(1) register complexity** for large headdim (D > 256), almost **>1.5x** 🎉 faster than SDPA EA with or without MMA Accumulation F32 on many devices, such as NVIDIA L20, 4090, 3080 Laptop (Experimental 👀~). The FFPA kernels are modified from my repo 📖[CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes/tree/main/kernels/flash-attn) ![](https://img.shields.io/github/stars/DefTruth/CUDA-Learn-Notes.svg?style=social).
🤖 [WIP] **FFPA**: Yet antother **Faster Flash Prefill Attention** with **O(1) SRAM complexity** & **O(d/4) or O(1) register complexity** for large headdim (D > 256), almost **>1.5x** 🎉 faster than SDPA EA with or without MMA Accumulation F32 on many devices, such as NVIDIA L20, 4090, 3080 Laptop (Experimental 👀~). The FFPA kernels are modified from my repo 📖[CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes/tree/main/kernels/flash-attn) ![](https://img.shields.io/github/stars/DefTruth/CUDA-Learn-Notes.svg?style=social).

<!--
|Tensor Cores|Loop over N/D |Tile Block (Br, Bc) |MMA (m16n8k16)|
@@ -25,15 +25,15 @@
|✔️|✔️|✔️|?|
-->

NOTE: This project is still in its early development stages and currently provides a few experimental kernels and benchmarks for reference. More benchmarks data and features (FFPA **L2/L3** & more devices) will be added over time as the project continues to develop.
NOTE: This project is still in its early dev stages and now provides a few experimental kernels and benchmarks for reference. More features will be added in the future. Welcome to 🌟👆🏻star this repo to support me ~ 🎉🎉

## ©️Citations🎉🎉

```BibTeX
@misc{faster-prefill-attention@2025,
title={FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity for large headdim.},
url={https://github.com/DefTruth/faster-prefill-attention},
note={Open-source software available at https://github.com/DefTruth/faster-prefill-attention},
@misc{cuffpa-py@2025,
title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
url={https://github.com/DefTruth/cuffpa-py},
note={Open-source software available at https://github.com/DefTruth/cuffpa-py},
author={DefTruth etc},
year={2025}
}
@@ -80,19 +80,19 @@ By leveraging this approach, we can achieve better performance for large headdim

<div id="install"></div>

The FFPA implemented in this repo can be install as a python library, namely, `pyffpa` library (optional).
The FFPA implemented in this repo can be install as a python library, namely, `cuffpa-py` library (optional).
```bash
# clone, then, run .dev/install.sh directly or run commands as belows
git clone https://github.com/DefTruth/faster-prefill-attention.git
python3 setup.py bdist_wheel && rm -rf *.egg-info # build 'pyffpa' from sources
cd dist && python3 -m pip install pyffpa-*-linux_x86_64.whl # pip uninstall pyffpa -y
git clone https://github.com/DefTruth/cuffpa-py.git
# clone, then, run bash .dev/install.sh directly or run commands:
python3 setup.py bdist_wheel && rm -rf *.egg-info # build 'cuffpa-py' from sources
cd dist && python3 -m pip install cuffpa_py-*-linux_x86_64.whl # pip uninstall cuffpa-py -y
```
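
A quick post-install sanity check can query the pip distribution name declared in setup.py (`cuffpa-py`); the sketch below deliberately avoids guessing the importable module name, which this diff does not show:

```python
# Minimal post-install check (assumes the wheel built by .dev/install.sh was
# installed into the current environment; uses only the pip distribution
# name "cuffpa-py" from setup.py, not an import name).
from importlib.metadata import PackageNotFoundError, version

try:
    print("cuffpa-py:", version("cuffpa-py"))
except PackageNotFoundError:
    print("cuffpa-py is not installed; run `bash .dev/install.sh` first")
```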

## 📖 FFPA L1 (Level 1): Benchmark 🎉🎉

<div id="L1-bench"></div>

L1: level 1, O(Brx16)~O(1) SRAM complexity, O(d/4) register complexity, the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, **D=320-1024(FA2 not supported 👀)**. (Notes, `*`=MMA Acc F32, `^`=MMA Acc F16, Softmax Acc dtype is always be F32, T=TFLOPS, 👇Benchmark)
L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, **D=320-1024(FA2 not supported 👀)**. (Notes, `*`=MMA Acc F32, `^`=MMA Acc F16, Softmax Acc dtype is always be F32, T=TFLOPS, 👇Benchmark)
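
To make the O(2xBrx16) figure concrete, a rough back-of-the-envelope comparison is sketched below; Br = Bc = 64, fp16 tiles, and the assumption that two Br x 16 headdim slices are staged per iteration are illustrative guesses, not values taken from the kernels:

```python
# Rough shared-memory footprint per thread block, illustrative numbers only.
# FFPA L1 is assumed to keep only 16-wide headdim slices of the Q/K tiles
# resident (hence O(2*Br*16) elements, independent of d); a FlashAttention-style
# tile that holds full-headdim rows grows linearly with d.
BYTES_FP16 = 2
Br = Bc = 64  # assumed tile sizes

for d in (256, 512, 1024):
    ffpa_l1 = 2 * Br * 16 * BYTES_FP16      # roughly constant in d
    fa2_like = (Br + Bc) * d * BYTES_FP16   # grows with headdim
    print(f"d={d:4d}: ffpa_l1 ~ {ffpa_l1 / 1024:.1f} KiB, "
          f"fa2-like ~ {fa2_like / 1024:.1f} KiB")
```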

- 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS)

@@ -144,7 +144,7 @@ export TORCH_CUDA_ARCH_LIST=Ada # for Ada only
export TORCH_CUDA_ARCH_LIST=Ampere # for Ampere only
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
```
- 📚 case: B=1, H=48, N=8192, D=320(FA2 not supported), Device=NVIDIA RTX 4090.
- 📚 case: B=1, H=48, N=8192, D=320(`FA2 not supported`), Device=NVIDIA RTX 4090.
```bash
python3 tests/test.py --B 1 --H 48 --N 8192 --show-all --D 320
-------------------------------------------------------------------------------------------------
24 changes: 12 additions & 12 deletions env.py
@@ -8,54 +8,54 @@ class ENV(object):
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))

# Enable all multi stages kernels or not (1~4), default False (1~2).
ENBALE_FFPA_ALL_STAGES = bool(int(os.environ.get("ENBALE_FFPA_ALL_STAGES", 0)))
ENABLE_FFPA_ALL_STAGES = bool(int(os.environ.get("ENABLE_FFPA_ALL_STAGES", 0)))

# Enable all headdims for FFPA kernels or not, default False.
# True, headdim will range from 32 to 1024 with step = 32, range(32, 1024, 32)
# False, headdim will range from 256 to 1024 with step = 64, range(256, 1024, 64)
ENBALE_FFPA_ALL_HEADDIM = bool(int(os.environ.get("ENBALE_FFPA_ALL_HEADDIM", 0)))
ENABLE_FFPA_ALL_HEADDIM = bool(int(os.environ.get("ENABLE_FFPA_ALL_HEADDIM", 0)))

# Enable build FFPA kernels for Ada devices (sm89, L2O, 4090, etc),
# default True.
ENBALE_FFPA_ADA = bool(int(os.environ.get("ENBALE_FFPA_ADA", 1)))
ENABLE_FFPA_ADA = bool(int(os.environ.get("ENABLE_FFPA_ADA", 1)))

# Enable build FFPA kernels for Ampere devices (sm80, A30, A100, etc),
# default True.
ENBALE_FFPA_AMPERE = bool(int(os.environ.get("ENBALE_FFPA_HOPPER", 1)))
ENABLE_FFPA_AMPERE = bool(int(os.environ.get("ENABLE_FFPA_AMPERE", 1)))

# Enable build FFPA kernels for Hopper devices (sm90, H100, H20, etc),
# default False.
ENBALE_FFPA_HOPPER = bool(int(os.environ.get("ENBALE_FFPA_HOPPER", 0)))
ENABLE_FFPA_HOPPER = bool(int(os.environ.get("ENABLE_FFPA_HOPPER", 0)))

@classmethod
def project_dir(cls):
return cls.PROJECT_DIR

@classmethod
def enable_hopper(cls):
return cls.ENBALE_FFPA_HOPPER
return cls.ENABLE_FFPA_HOPPER

@classmethod
def enable_ampere(cls):
return cls.ENBALE_FFPA_AMPERE
return cls.ENABLE_FFPA_AMPERE

@classmethod
def enable_ada(cls):
return cls.ENBALE_FFPA_ADA
return cls.ENABLE_FFPA_ADA

@classmethod
def enable_all_mutistages(cls):
return cls.ENBALE_FFPA_ALL_STAGES
return cls.ENABLE_FFPA_ALL_STAGES

@classmethod
def enable_all_headdim(cls):
return cls.ENBALE_FFPA_ALL_HEADDIM
return cls.ENABLE_FFPA_ALL_HEADDIM

@classmethod
def env_cuda_cflags(cls):
extra_env_cflags = []
if cls.enable_all_mutistages():
extra_env_cflags.append("-DENBALE_FFPA_ALL_STAGES")
extra_env_cflags.append("-DENABLE_FFPA_ALL_STAGES")
if cls.enable_all_headdim():
extra_env_cflags.append("-DENBALE_FFPA_ALL_HEADDIM")
extra_env_cflags.append("-DENABLE_FFPA_ALL_HEADDIM")
return extra_env_cflags
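
For context, these -D flags presumably end up in the nvcc arguments that setup.py passes to the extension build; a hypothetical sketch following the usual torch.utils.cpp_extension pattern (the extension name and source list are placeholders, not taken from this PR):

```python
# Hypothetical wiring of ENV.env_cuda_cflags() into the CUDA extension build
# (standard torch.utils.cpp_extension usage; names and sources are placeholders).
from torch.utils.cpp_extension import CUDAExtension
from env import ENV

ext_modules = [
    CUDAExtension(
        name="pyffpa_cuda",                           # placeholder extension name
        sources=["csrc/faster_prefill_attn_api.cc"],  # placeholder source list
        extra_compile_args={
            "cxx": ["-O3", "-std=c++17"],
            "nvcc": ["-O3", "--use_fast_math"] + ENV.env_cuda_cflags(),
        },
    )
]
```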
8 changes: 5 additions & 3 deletions setup.py
Expand Up @@ -69,8 +69,8 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_version


# package name managed by pip, which can be remove by `pip uninstall pyffpa -y`
PACKAGE_NAME = "pyffpa"
# package name managed by pip, which can be remove by `pip uninstall cuffpa-py -y`
PACKAGE_NAME = "cuffpa-py"

ext_modules = []
generator_flag = []
@@ -133,7 +133,9 @@ def fetch_requirements():
"tests",
"bench",
"tmp",
"pyffpa.egg-info",
"cuffpa_py.egg-info",
"__pycache__",
"third_party",
)
),
description="FFPA: Yet another Faster Flash Prefill Attention for large headdim, ~1.5x faster than SDPA EA.",