Describe the bug
I'm seeing higher shared memory usage in flash-attn beginning with 929142b.
If you fix flash_attn.flash_attn_triton's usage of trans_b/trans_a and remove the assertion at
https://github.com/Dao-AILab/flash-attention/blob/f1a73d074002226c42ce65a1df170ecff9f022c0/flash_attn/flash_attn_triton.py#L1136
then the following test fails on an L4/A10G.
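For context, by "fix the usage of trans_b/trans_a" I mean rewriting the tl.dot calls that still pass those keyword arguments, which current Triton no longer accepts, to use tl.trans instead. A minimal sketch of that kind of edit (the qk/q/k names are just illustrative; the actual call sites in flash_attn_triton.py may look different):

# old form, rejected by current Triton since tl.dot no longer takes trans_a/trans_b
#     qk += tl.dot(q, k, trans_b=True)
# equivalent form with an explicit transpose
qk += tl.dot(q, tl.trans(k))

With both of those changes applied, here is the failing test: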
from flash_attn.flash_attn_triton import flash_attn_func
import torch
# set seed
torch.random.manual_seed(0)
batch_size = 1
nheads = 4
d = 64
seqlen = 16
dtype = torch.bfloat16
q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
k, v = [
    torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
    for _ in range(2)
]
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
out = flash_attn_func(q, k, v)
g = torch.randn_like(out)
out.backward(g)
File "/envs/default/lib/python3.12/site-packages/flash_attn/flash_attn_triton.py", line 959, in _flash_attn_backward
_bwd_kernel[grid](
File "/envs/default/lib/python3.12/site-packages/triton/runtime/jit.py", line 347, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/envs/default/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 209, in run
ret = self.fn.run(
^^^^^^^^^^^^
File "/envs/default/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 395, in run
return self.fn.run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/envs/default/lib/python3.12/site-packages/triton/runtime/jit.py", line 591, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
^^^^^^^^^^
File "/envs/default/lib/python3.12/site-packages/triton/compiler/compiler.py", line 416, in __getattribute__
self._init_handles()
File "/envs/default/lib/python3.12/site-packages/triton/compiler/compiler.py", line 404, in _init_handles
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
Previously, this kernel used 100352 bytes of shared memory.
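As a sanity check on the numbers, the 101376-byte "Hardware limit" in the error matches the per-block shared memory Triton reports for this GPU. A quick way to confirm it, using Triton's driver utils (just a convenience snippet, not part of the repro):

import torch
import triton

# ask Triton's active driver for the device properties of the current CUDA device
props = triton.runtime.driver.active.utils.get_device_properties(torch.cuda.current_device())
print(props["max_shared_mem"])  # 101376 on the A10G/L4 used here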
The kernel metadata before:
KernelMetadata(allowed_dot_input_precisions=['tf32', 'tf32x3', 'ieee'], backend_name='cuda', cluster_dims=(1, 1, 1), debug=False, default_dot_input_precision='tf32', deprecated_fp8_dtypes=[], enable_fp_fusion=True, extern_libs=[['libdevice', '/lib/python3.12/site-packages/triton/backends/nvidia/lib/libdevice.10.bc']], hash='dfc8d7459aacd5b9120aebab557291eb55b4ce0bc11e1579b6ea3feae68fdf60', max_num_imprecise_acc_default=0, maxnreg=None, name='_bwd_kernel', num_buffers_warp_spec=0, num_consumer_groups=0, num_ctas=1, num_stages=1, num_warps=8, ptx_version=None, reg_dec_producer=0, reg_inc_consumer=0, sanitize_overflow=True, shared=100352, supported_fp8_dtypes=['fp8e4b15', 'fp8e4nv', 'fp8e5'], target=GPUTarget(backend='cuda', arch=89, warp_size=32))
and after (note shared goes from 100352 to 116736, with num_warps and num_stages unchanged):
KernelMetadata(allowed_dot_input_precisions=['tf32', 'tf32x3', 'ieee'], arch='sm89', backend_name='cuda', cluster_dims=(1, 1, 1), debug=False, default_dot_input_precision='tf32', deprecated_fp8_dtypes=[], enable_fp_fusion=True, extern_libs=[['libdevice', '/lib/python3.12/site-packages/triton/backends/nvidia/lib/libdevice.10.bc']], global_scratch_align=1, global_scratch_size=0, hash='9fbc8ad38aacb30366d0ca7c8bc8a85678f4b371bf9df8b37d9c0cd01d96343c', launch_cooperative_grid=False, max_num_imprecise_acc_default=0, maxnreg=None, name='_bwd_kernel', num_buffers_warp_spec=0, num_consumer_groups=0, num_ctas=1, num_stages=1, num_warps=8, ptx_version=None, reg_dec_producer=0, reg_inc_consumer=0, sanitize_overflow=True, shared=116736, supported_fp8_dtypes=['fp8e4b15', 'fp8e4nv', 'fp8e5'], target=GPUTarget(backend='cuda', arch=89, warp_size=32), tmem_size=0, triton_version='3.3.0')
Environment details
Triton: 3.3.0
GPU: A10G/L4