AMD/MI300X performance is lacking compared to torch.matmul #7199
Open
@mobicham

Description

Describe the issue

Performance lags quite a bit behind torch.matmul and behind Triton on Nvidia, for both fp16 and 8-bit matmul, even with max autotuning. On large matrices, int8 x int8 is barely 1.2x faster than torch.matmul in fp16, while it should be close to 2x, since the int8 matrix peak is roughly double the fp16 peak on both GPUs per their datasheets (~2615 TOPS int8 vs. ~1307 TFLOPS fp16 on MI300X, for example).

Performance numbers

shape: M, K, N = 128, 8192, 8192
MI300X: 
Speed-up fp16 vs. fp16 | : 0.66x
Speed-up int8 vs. fp16 | : 1.59x

RTX 4090:
Speed-up fp16 vs. fp16 | : 1.10x
Speed-up int8 vs. fp16 | : 2.18x

shape: M, K, N = 256, 16384, 16384
MI300X: 
Speed-up fp16 vs. fp16 | : 0.73x
Speed-up int8 vs. fp16 | : 1.23x

RTX 4090:
Speed-up fp16 vs. fp16 | : 1.00x
Speed-up int8 vs. fp16 | : 2.28x

Code

import torch
import triton
import triton.language as tl
from triton.testing import do_bench

def eval_time(fct, params):
    return do_bench(lambda: fct(**params), rep=200)
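
# To compare these timings with the datasheet peaks, it helps to convert a
# do_bench result into achieved TFLOPS. Minimal sketch; the helper name is mine
# and `ms` follows do_bench's default milliseconds convention.
def achieved_tflops(M, K, N, ms):
    return 2 * M * K * N / (ms * 1e-3) / 1e12  # 2*M*K*N flops per matmul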

@triton.jit
def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    grid_m     = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n     = tl.cdiv(N, BLOCK_SIZE_N)
    width      = GROUP_SIZE_M * grid_n
    group_id   = pid // width
    group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m      = group_id * GROUP_SIZE_M + (pid % group_size)
    pid_n      = (pid % width) // group_size
    return pid_m, pid_n

def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

def get_configs_nvidia():
    configs = []
    for num_warps in [4, 8]:
        for num_stages in [4]:
            for M in [64, 128]:
                for K in [32, 64, 128, 256]:
                    for N in [32, 64, 128, 256]:
                        configs.append(
                            triton.Config(
                                {"BLOCK_SIZE_M": M, "BLOCK_SIZE_N": N, "BLOCK_SIZE_K": K},
                                num_stages=num_stages, num_warps=num_warps,
                            )
                        )

    return configs

def get_configs_amd():
    configs = []
    for num_warps in [4, 8]:
        for num_stages in [2]:
            for M in [64, 128]:
                for K in [32, 64, 128, 256]:
                    for N in [32, 64, 128, 256]:
                        for waves in [0, 2, 4, 8]:
                            configs.append(
                                triton.Config(
                                    {"BLOCK_SIZE_M": M, "BLOCK_SIZE_N": N, "BLOCK_SIZE_K": K, 
                                    "waves_per_eu": waves, "matrix_instr_nonkdim": 16,
                                    },
                                    num_stages=num_stages, num_warps=num_warps,
                                )
                            )

    return configs
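
# Untested sketch: the ROCm backend also exposes `kpack` and a 32-wide MFMA shape
# (`matrix_instr_nonkdim=32`); both appear in the upstream HIP matmul tutorial.
# Assuming Triton 3.3.1 accepts them, the sweep could be extended like this:
def get_configs_amd_extended():
    configs = []
    for num_warps in [4, 8]:
        for M in [64, 128]:
            for K in [64, 128, 256]:
                for N in [64, 128, 256]:
                    for waves in [0, 2, 4]:
                        for nonkdim in [16, 32]:
                            for kpack in [1, 2]:
                                configs.append(
                                    triton.Config(
                                        {"BLOCK_SIZE_M": M, "BLOCK_SIZE_N": N, "BLOCK_SIZE_K": K,
                                         "waves_per_eu": waves, "matrix_instr_nonkdim": nonkdim, "kpack": kpack},
                                        num_stages=2, num_warps=num_warps,
                                    )
                                )
    return configs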

@triton.autotune(
    configs=get_configs_nvidia() if is_cuda() else get_configs_amd(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(A_ptr, B_ptr, C_ptr,
                  M, N, K,
                  stride_am, stride_ak,
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  acc_dtype: tl.constexpr,
                  BLOCK_SIZE_M: tl.constexpr, 
                  BLOCK_SIZE_N: tl.constexpr, 
                  BLOCK_SIZE_K: tl.constexpr,
                  ):
    
    pid = tl.program_id(0)
    pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, 8)
    
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # K is assumed to be a multiple of BLOCK_SIZE_K, so only the M/N edges are masked.
        a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
        b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
        acc = tl.dot(a, b, acc=acc)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Output
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc.to(C_ptr.dtype.element_ty), mask=c_mask)


def matmul_triton(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    K, N = b.shape

    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    acc_dtype = tl.int32 if a.dtype == torch.int8 else tl.float32
    
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        acc_dtype=acc_dtype,
    )
    return c

############################################################################################
M, K, N = 128, 8192, 8192
#M, K, N = 256, 16384, 16384

device = 'cuda:0'
dtype = torch.int8
#dtype =  torch.float8_e4m3fnuz

a = torch.randn(M, K, device=device).to(dtype)
b = torch.randn(N, K, device=device).to(dtype).T

# Test correctness
out_ref = torch.matmul(a.to(torch.float16), b.to(torch.float16))
out = matmul_triton(a, b)
assert (out_ref.float() - out.float()).abs().mean().item() < 1e-3, "output mismatch"

# Test speed
params_8bit  = {'a':a.to(dtype), 'b':b.to(dtype)}
params_16bit = {'a':a.to(torch.float16), 'b':b.to(torch.float16)}

for tag, p_torch, p_triton in [('fp16 vs. fp16', params_16bit, params_16bit), ('int8 vs. fp16', params_16bit, params_8bit)]:
    ref_time = eval_time(lambda a, b: torch.matmul(a, b), p_torch)
    out_time = eval_time(matmul_triton, p_triton)
    print(f'Speed-up {tag} | : {ref_time / out_time:.2f}x')
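
# It can also help to report which config the autotuner picked on each platform;
# the Autotuner object exposes this as `best_config` (assuming that attribute is
# unchanged in 3.3.1).
print(matmul_kernel.best_config)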

Environment details

Triton: 3.3.1
GPU: MI300X (AMD) / RTX 4090 (Nvidia)
