Description
Describe the issue
Performance seems to be lacking quite a bit compared to torch.matmul and to Triton on Nvidia, for both fp16 and 8-bit matmul, even with max autotuning. int8 x int8 is barely 1.2x faster than torch.matmul in fp16 on large matrices, while it should be close to 2x.
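For reference, one way to sanity-check the "close to 2x" expectation on a given machine is to time the vendor int8 GEMM path directly against fp16 torch.matmul. The sketch below relies on torch._int_mm, a private PyTorch API that may be missing or unsupported on some ROCm builds, so treat it as an optional cross-check rather than part of the repro:

import torch
from triton.testing import do_bench

M, K, N = 128, 8192, 8192
a8 = torch.randint(-4, 4, (M, K), device='cuda', dtype=torch.int8)
b8 = torch.randint(-4, 4, (K, N), device='cuda', dtype=torch.int8)
a16, b16 = a8.half(), b8.half()

t_fp16 = do_bench(lambda: torch.matmul(a16, b16), rep=200)
t_int8 = do_bench(lambda: torch._int_mm(a8, b8), rep=200)  # int8 x int8 -> int32
print(f'vendor int8 vs. fp16 speed-up: {t_fp16 / t_int8:.2f}x')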
Performance numbers
shape: M, K, N = 128, 8192, 8192
MI300X:
Speed-up fp16 vs. fp16 | : 0.66x
Speed-up int8 vs. fp16 | : 1.59x
4090 RTX:
Speed-up fp16 vs. fp16 | : 1.10x
Speed-up int8 vs. fp16 | : 2.18x
shape: M, K, N = 256, 16384, 16384
MI300X:
Speed-up fp16 vs. fp16 | : 0.73x
Speed-up int8 vs. fp16 | : 1.23x
4090 RTX:
Speed-up fp16 vs. fp16 | : 1.00x
Speed-up int8 vs. fp16 | : 2.28x
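To make the gap easier to interpret than ratios alone, the measured times can also be converted to effective TFLOPS / TOPS; the helper below is only illustrative (the 0.1 ms input is a made-up example, not a measurement):

def effective_tflops(M, K, N, time_ms):
    # 2 * M * N * K flops for a GEMM, time given in milliseconds
    return 2 * M * N * K / (time_ms * 1e-3) / 1e12

# Hypothetical example: M, K, N = 128, 8192, 8192 at 0.1 ms -> ~172 TFLOPS
print(effective_tflops(128, 8192, 8192, 0.1))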
Code
import triton
import triton.language as tl
import torch
from triton.testing import do_bench
def eval_time(fct, params):
    return do_bench(lambda: fct(**params), rep=200)
@triton.jit
def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    width = GROUP_SIZE_M * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n
def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"
def get_configs_nvidia():
    configs = []
    for num_warps in [4, 8]:
        for num_stages in [4]:
            for M in [64, 128]:
                for K in [32, 64, 128, 256]:
                    for N in [32, 64, 128, 256]:
                        configs.append(
                            triton.Config(
                                {"BLOCK_SIZE_M": M, "BLOCK_SIZE_N": N, "BLOCK_SIZE_K": K},
                                num_stages=num_stages, num_warps=num_warps,
                            )
                        )
    return configs
def get_configs_amd():
    configs = []
    for num_warps in [4, 8]:
        for num_stages in [2]:
            for M in [64, 128]:
                for K in [32, 64, 128, 256]:
                    for N in [32, 64, 128, 256]:
                        for waves in [0, 2, 4, 8]:
                            configs.append(
                                triton.Config(
                                    {"BLOCK_SIZE_M": M, "BLOCK_SIZE_N": N, "BLOCK_SIZE_K": K,
                                     "waves_per_eu": waves, "matrix_instr_nonkdim": 16,
                                     },
                                    num_stages=num_stages, num_warps=num_warps,
                                )
                            )
    return configs
@triton.autotune(
    configs=get_configs_nvidia() if is_cuda() else get_configs_amd(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(A_ptr, B_ptr, C_ptr,
                  M, N, K,
                  stride_am, stride_ak,
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  acc_dtype: tl.constexpr,
                  BLOCK_SIZE_M: tl.constexpr,
                  BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr,
                  ):
    pid = tl.program_id(0)
    pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, 8)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
        b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
        acc = tl.dot(a, b, acc=acc)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Output
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)
def matmul_triton(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    acc_dtype = tl.int32 if a.dtype in [torch.int8] else tl.float32
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        acc_dtype=acc_dtype,
    )
    return c
############################################################################################
M, K, N = 128, 8192, 8192
#M, K, N = 256, 16384, 16384
device = 'cuda:0'
dtype = torch.int8
#dtype = torch.float8_e4m3fnuz
a = torch.randn(M, K, device=device).to(dtype)
b = torch.randn(N, K, device=device).to(dtype).T
# Test correctness
out_ref = torch.matmul(a.to(torch.float16), b.to(torch.float16))
out = matmul_triton(a, b)
assert (out_ref.float() - out.float()).abs().mean().item() < 1e-3, "output mismatch"
# Test speed
params_8bit = {'a': a.to(dtype), 'b': b.to(dtype)}
params_16bit = {'a': a.to(torch.float16), 'b': b.to(torch.float16)}
for tag, p_torch, p_triton in [('fp16 vs. fp16', params_16bit, params_16bit), ('int8 vs. fp16', params_16bit, params_8bit)]:
    ref_time = eval_time(lambda a, b: torch.matmul(a, b), p_torch)
    out_time = eval_time(matmul_triton, p_triton)
    print(f'Speed-up {tag} | : {ref_time / out_time}')
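In case it helps triage, the config the autotuner settles on can be dumped after a run; best_config is an attribute of the triton.autotune wrapper, and TRITON_PRINT_AUTOTUNING=1 should print the same information while tuning (both assumed from the Triton 3.x autotuner, not verified on every version):

out = matmul_triton(a, b)          # triggers autotuning for this shape/dtype
print(matmul_kernel.best_config)   # BLOCK_SIZE_M/N/K, num_warps, num_stages that were picked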
Environment details
Triton: 3.3.1
GPU: MI300X (AMD) / 4090 RTX (Nvidia)