Prototyping an hl.atomic opp by drisspg · Pull Request #63 · pytorch-labs/helion

Prototyping an hl.atomic opp #63


Merged

1 commit merged into main from drisspg/stack/5 on May 23, 2025

Conversation

@drisspg (Contributor) commented on May 20, 2025

Stacked PRs:

- Prototyping an hl.atomic opp

I have been playing around with a few different variants:

        # dy.scatter_add_(0, tile_i, local_dy_grad)
        hl.store(dy, tile_i, local_dy_grad, reduction_op="add")
        # hl.atomic_add(dy_tile, local_dy_grad)

I kind of think that having the reduction op be part of the store semantics makes the most sense, since we can map it to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor

If we don't add the reduction directly to store, I think we need this atomic op to back scatter_add with the function mode.
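
For context, here is a minimal eager-mode sketch of the semantics that either variant (reducing store or hl.atomic_add) needs to reproduce; the reference function name is illustrative and not part of this PR:

import torch

def mul_relu_block_back_reference(x, y, dz):
    # Backward of z = relu(x * y[:, None]) given the upstream gradient dz.
    relu_grad = ((x * y[:, None]) > 0).to(dz.dtype)
    dx = dz * relu_grad * y[:, None]
    # dy is a row-wise reduction of dz * relu_grad * x; the tiled kernel below
    # builds it up from per-tile partial sums, which is where the atomic
    # (or reducing-store) accumulation comes in.
    dy = (dz * relu_grad * x).sum(dim=1)
    return dx, dy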

Code it enables

from typing import Tuple

import torch
from torch import Tensor

import helion
import helion.language as hl


@helion.kernel(config=helion.Config(block_sizes=[32, 32]))
def mul_relu_block_back_kernel(
    x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor
) -> Tuple[Tensor, Tensor]:
    # Get tensor sizes
    m, n = x.shape
    # Create output tensor for gradients
    dx = torch.empty_like(x)
    dy = torch.empty_like(y)

    # Use Helion to tile the computation
    for tile_i, tile_j in hl.tile([m, n]):
        # Get input tiles
        x_tile = x[tile_i, tile_j]
        y_tile = y[tile_i]
        dz_tile = dz[tile_i, tile_j]

        # For ReLU, gradient is 1 where input > 0, 0 otherwise
        relu_mask = (x_tile * y_tile[:, None]) > 0
        # Chain rule: dx = dz * relu_grad * y
        relu_grad = torch.where(relu_mask, 1, 0)
        dx[tile_i, tile_j] = dz_tile * relu_grad * y_tile[:, None]

        # Chain rule: dy = dz * relu_grad * x -> backwards of broadcast(sum)
        local_dy_grad = torch.sum(dz_tile * relu_grad * x_tile, dim=1)

        hl.atomic_add(dy, [tile_i,], local_dy_grad)
    return dx, dy
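
A minimal, hedged usage sketch (shapes, seed, and device are illustrative; assumes helion, triton, and a CUDA GPU are available):

torch.manual_seed(0)
x = torch.randn(256, 256, device="cuda")
y = torch.randn(256, device="cuda")
dz = torch.randn(256, 256, device="cuda")
dx, dy = mul_relu_block_back_kernel(x, y, dz)
# dx matches the eager sketch above; dy is accumulated with atomic adds
# across the tile_j blocks, so comparing it against the eager dy also
# assumes the dy buffer starts at zero.
torch.testing.assert_close(dx, dz * ((x * y[:, None]) > 0).to(dz.dtype) * y[:, None])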

Output (generated Triton code):

import torch
import triton
import triton.language as tl

@triton.jit
def _mul_relu_block_back_kernel_kernel(x, y, dz, dx, dy, dx_stride_0, dx_stride_1, dy_stride_0, dz_stride_0, dz_stride_1, x_stride_0, x_stride_1, y_stride_0, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
 num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
 pid_0 = tl.program_id(0) % num_blocks_0
 pid_1 = tl.program_id(0) // num_blocks_0
 offset_0 = pid_0 * _BLOCK_SIZE_0
 indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
 mask_0 = indices_0 < m
 offset_1 = pid_1 * _BLOCK_SIZE_1
 indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
 mask_1 = indices_1 < n
 x_tile = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
 y_tile = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
 dz_tile = tl.load(dz + (indices_0[:, None] * dz_stride_0 + indices_1[None, :] * dz_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
 subscript = y_tile[:, None]
 v_0 = x_tile * subscript
 v_1 = 0.0
 v_2 = v_0 > v_1
 v_3 = tl.full([], 0, tl.int64)
 v_4 = tl.full([], 1, tl.int64)
 v_5 = v_4[None, None]
 v_6 = v_3[None, None]
 v_7 = tl.where(v_2, v_5, v_6)
 v_8 = v_7.to(tl.float32)
 v_9 = dz_tile * v_8
 subscript_1 = y_tile[:, None]
 v_10 = v_9 * subscript_1
 tl.store(dx + (indices_0[:, None] * dx_stride_0 + indices_1[None, :] * dx_stride_1), v_10, mask_0[:, None] & mask_1[None, :])
 v_11 = v_7.to(tl.float32)
 v_12 = dz_tile * v_11
 v_13 = v_12 * x_tile
 v_14 = tl.where(tl.broadcast_to(mask_1[None, :], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]), v_13, 0)
 local_dy_grad = tl.sum(v_14, 1)
 tl.atomic_add(dy + indices_0 * dy_stride_0, local_dy_grad, mask=mask_0, sem='relaxed')

def mul_relu_block_back_kernel(x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor):
 m, n = x.shape
 dx = torch.empty_like(x)
 dy = torch.empty_like(y)
 _BLOCK_SIZE_0 = 32
 _BLOCK_SIZE_1 = 32
 _mul_relu_block_back_kernel_kernel(x, y, dz, dx, dy, dx.stride(0), dx.stride(1), dy.stride(0), dz.stride(0), dz.stride(1), x.stride(0), x.stride(1), y.stride(0), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
 return (dx, dy)

def _mul_relu_block_back_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor):
 m, n = x.shape
 dx = torch.empty_like(x)
 dy = torch.empty_like(y)
 _BLOCK_SIZE_0 = 32
 _BLOCK_SIZE_1 = 32
 from helion.runtime.precompile_shim import make_precompiler
 return make_precompiler(_mul_relu_block_back_kernel_kernel)(x, y, dz, dx, dy, dx.stride(0), dx.stride(1), dy.stride(0), dz.stride(0), dz.stride(1), x.stride(0), x.stride(1), y.stride(0), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

drisspg added a commit that referenced this pull request on May 20, 2025
stack-info: PR: #63, branch: drisspg/stack/5
@drisspg force-pushed the drisspg/stack/5 branch from 1bb1ec3 to 804cd24 on May 20, 2025 at 18:52
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on May 20, 2025
drisspg added a commit that referenced this pull request on May 20, 2025
stack-info: PR: #63, branch: drisspg/stack/5
@drisspg force-pushed the drisspg/stack/5 branch from 804cd24 to 0fb0b67 on May 20, 2025 at 19:36
@drisspg changed the title from "Prototyping an hl.atomic opp" to "Prototyping a hl.atomic opp" on May 21, 2025
drisspg added a commit that referenced this pull request on May 23, 2025
stack-info: PR: #63, branch: drisspg/stack/5
@drisspg force-pushed the drisspg/stack/5 branch from 0fb0b67 to 945cb52 on May 23, 2025 at 01:27
@drisspg changed the title from "Prototyping a hl.atomic opp" to "Prototyping an hl.atomic opp" on May 23, 2025
drisspg added a commit that referenced this pull request on May 23, 2025
stack-info: PR: #63, branch: drisspg/stack/5
@drisspg force-pushed the drisspg/stack/5 branch from 945cb52 to bf8c898 on May 23, 2025 at 02:31
drisspg added a commit that referenced this pull request on May 23, 2025
stack-info: PR: #63, branch: drisspg/stack/5
@drisspg force-pushed the drisspg/stack/5 branch from bf8c898 to 70d6a45 on May 23, 2025 at 02:38
stack-info: PR: #63, branch: drisspg/stack/5
@drisspg force-pushed the drisspg/stack/5 branch from 70d6a45 to 0650950 on May 23, 2025 at 02:47
@drisspg marked this pull request as ready for review on May 23, 2025 at 02:47
@drisspg merged commit 2206f3b into main on May 23, 2025
6 checks passed