Support indirect loads by jansel · Pull Request #19 · pytorch-labs/helion · GitHub

Support indirect loads #19

Merged
merged 1 commit on May 3, 2025
42 changes: 42 additions & 0 deletions examples/embedding.py
@@ -0,0 +1,42 @@
from __future__ import annotations

import torch

import helion
import helion.language as hl


@helion.kernel(
    config=helion.Config(
        block_size=[512, 32], loop_order=[0, 1], num_warps=8, indexing="block_ptr"
    )
)
def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    x_flat = x.reshape(-1)  # collapse x into a single dimension
    _, embedding_dim = weight.size()
    out = torch.empty(
        [x_flat.size(0), embedding_dim], dtype=weight.dtype, device=weight.device
    )
    for tile_b, tile_e in hl.tile([x_flat.size(0), embedding_dim]):
        out[tile_b, tile_e] = weight[x_flat[tile_b], tile_e]
    # restore the original shape
    return out.view(*x.size(), embedding_dim)


def check() -> None:
    from triton.testing import do_bench

    num_embeddings, embedding_dim = 16, 64
    x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32)
    weight = torch.randn([num_embeddings, embedding_dim], device="cuda")
    result = embedding(x, weight)
    torch.testing.assert_close(result, torch.nn.functional.embedding(x, weight))
    sec = do_bench(lambda: embedding(x, weight))
    baseline_sec = do_bench(lambda: torch.nn.functional.embedding(x, weight))
    print(
        f"Helion time: {sec:.4f}s, torch time: {baseline_sec:.4f}s, speedup: {baseline_sec / sec:.2f}x"
    )


if __name__ == "__main__":
    check()
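
The line out[tile_b, tile_e] = weight[x_flat[tile_b], tile_e] is the indirect load this PR enables: the row index into weight is itself loaded from another tensor inside the device loop. A rough host-side PyTorch equivalent of the kernel, for illustration only (not part of the diff):

# Illustration only (not part of this PR): the same gather expressed with
# plain PyTorch advanced indexing on the host.
import torch

num_embeddings, embedding_dim = 16, 64
weight = torch.randn(num_embeddings, embedding_dim)
x = torch.randint(0, num_embeddings, (256, 32))
x_flat = x.reshape(-1)
out = weight[x_flat]                      # rows of weight selected by x_flat
out = out.view(*x.size(), embedding_dim)  # restore the original shape
assert torch.equal(out, torch.nn.functional.embedding(x, weight))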
32 changes: 24 additions & 8 deletions helion/_compiler/device_ir.py
@@ -428,15 +428,31 @@ def visit_Name(self, node: ast.Name) -> object:

     def _subscript_slice_proxy(self, slice_node: ast.AST) -> list[object]:
         assert isinstance(slice_node, ExtendedAST)
-        key = slice_node._type_info
-        if isinstance(key, SequenceType):
-            keys = key.unpack()
+        result = self.visit(slice_node)
+        if isinstance(result, (list, tuple)):
+            return [*result]
+        return [result]
+
+    def visit_Tuple(self, node: ast.Tuple) -> tuple[object, ...]:
+        return tuple([self.visit(x) for x in node.elts])
+
+    def visit_List(self, node: ast.List) -> list[object]:
+        return [self.visit(x) for x in node.elts]
+
+    def visit_Slice(self, node: ast.Slice) -> slice:
+        if node.lower is None:
+            lower = None
         else:
-            keys = [key]
-        try:
-            return [x.proxy() for x in keys]
-        except TypeError:
-            raise exc.InvalidSliceType(slice_node._type_info) from None
+            lower = self.visit(node.lower)
+        if node.upper is None:
+            upper = None
+        else:
+            upper = self.visit(node.upper)
+        if node.step is None:
+            step = None
+        else:
+            step = self.visit(node.step)
+        return slice(lower, upper, step)

     def visit_Assign(self, node: ast.Assign) -> None:
         if len(node.targets) != 1:
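
Aside: the new visit_Slice simply rebuilds a Python slice object from the three optional fields of an ast.Slice node. A tiny standalone illustration of that mapping (not part of the diff):

# Standalone illustration of the ast.Slice -> slice() mapping that
# visit_Slice above performs; not part of the diff.
import ast

node = ast.parse("x[1:10:2]", mode="eval").body.slice  # an ast.Slice node
lower = ast.literal_eval(node.lower) if node.lower else None
upper = ast.literal_eval(node.upper) if node.upper else None
step = ast.literal_eval(node.step) if node.step else None
print(slice(lower, upper, step))  # slice(1, 10, 2)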
21 changes: 20 additions & 1 deletion helion/_compiler/indexing_strategy.py
@@ -182,6 +182,9 @@ def compute_shape(
                     output_size.append(rdim.var)
                 else:
                     output_size.append(1)
+            elif isinstance(k, torch.Tensor) and k.ndim == 1:
+                input_size.popleft()
+                output_size.append(k.size(0))
             else:
                 raise exc.InvalidIndexingType(k)
         assert len(input_size) == 0, "invalid subscript"
@@ -197,7 +200,7 @@ def create(
         mask_values = {}
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         dtype = CompileEnvironment.current().triton_index_type()
-        for k in index:
+        for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
             elif isinstance(k, int):
@@ -232,6 +235,19 @@
                 else:
                     index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
+            elif isinstance(k, torch.Tensor) and k.ndim == 1:
+                expand = tile_strategy.expand_str(output_size, output_idx)
+                ast_index = state.ast_args[1]
+                assert isinstance(ast_index, (list, tuple))
+                assert len(ast_index) == len(index)
+                index_var = state.codegen.lift(ast_index[n]).id
+                index_values.append(f"({index_var}){expand}")
+                if (
+                    block_idx := TileStrategy.get_block_index(output_size[output_idx])
+                ) is not None:
+                    if mask := tile_strategy.mask_var(block_idx):
+                        mask_values.setdefault(f"({mask}){expand}")
+                output_idx += 1
             else:
                 raise exc.InvalidIndexingType(k)
         assert len(output_size) == output_idx
@@ -345,6 +361,9 @@ def is_supported(state: CodegenState, index: list[object]) -> bool:
                     tile_strategy.offset_var(origin.origin.block_size_idx)
                 except NotImplementedError:
                     return False
+            if isinstance(k, torch.Tensor):
+                # indirect loads don't work with block_ptr
+                return False
         return True

     def validate(self) -> None:
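
For context, the pointer-indexing path added above lowers a 1D tensor index to ordinary pointer arithmetic, i.e. the kind of hand-written Triton gather sketched below for comparison with the generated code in the tests further down. The sketch is illustrative only; the names, shapes, and launch parameters are assumptions and do not come from Helion's codegen.

# A minimal hand-written Triton gather ("indirect load"); assumes a CUDA
# device and a 2D contiguous src tensor. Names are illustrative.
import torch
import triton
import triton.language as tl


@triton.jit
def gather_rows_kernel(idx_ptr, src_ptr, out_ptr, dim, src_stride, out_stride,
                       BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < dim
    src_row = tl.load(idx_ptr + row)  # the indirect index, loaded from memory
    vals = tl.load(src_ptr + src_row * src_stride + cols, mask=mask)
    tl.store(out_ptr + row * out_stride + cols, vals, mask=mask)


def gather_rows(idx: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    out = torch.empty(idx.numel(), src.size(1), device=src.device, dtype=src.dtype)
    BLOCK = triton.next_power_of_2(src.size(1))
    gather_rows_kernel[(idx.numel(),)](
        idx, src, out, src.size(1), src.stride(0), out.stride(0), BLOCK=BLOCK
    )
    return out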
11 changes: 9 additions & 2 deletions helion/_compiler/type_propagation.py
@@ -128,7 +128,7 @@ def merge(self, other: LocalScope | dict[str, TypeInfo]) -> LocalScope:
                     # Improve error message
                     merged = UnknownType(
                         merged.origin,
-                        "Variable {k!r} has different types in control flow: {existing!s} and {v!s}",
+                        f"Variable {k!r} has different types in control flow: {existing!s} and {v!s}",
                     )
                 self.variables[k] = merged
             else:
@@ -162,7 +162,11 @@ def extract_locals(self) -> dict[str, TypeInfo]:
         return {**self.variables}


-regexp_allowed_host_ops: re.Pattern[str] = re.compile("like|new|broadcast|promote")
+# Ops not matching this emit a warning if they are used in a host function
+regexp_allowed_host_ops: re.Pattern[str] = re.compile(
+    r"like|new|broadcast|promote|view|reshape|expand|permute|strided|"
+    r"transpose|contiguous|unsqueeze|squeeze|zero|rand|full|fill"
+)


 class TypeInfo:
@@ -433,6 +437,9 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
                 output_sizes.append(env.block_sizes[k.block_size_idx].var)
             elif isinstance(k, TypeNotAllowedOnDevice):
                 raise exc.TypePropagationError(k)
+            elif isinstance(k, TensorType) and k.fake_value.ndim == 1:
+                inputs_consumed += 1
+                output_sizes.append(k.fake_value.size(0))
             else:
                 raise exc.InvalidIndexingType(k)
         if inputs_consumed != self.fake_value.ndim:
2 changes: 1 addition & 1 deletion helion/exc.py
@@ -78,7 +78,7 @@ class RankMismatch(BaseError):


 class InvalidIndexingType(BaseError):
-    message = "Expected tile/int/None/etc in tensor[...], got {0!s}."
+    message = "Expected tile/int/None/1D-tensor/etc in tensor[...], got {0!s}."


 class RequiresTensorInAssignment(BaseError):
105 changes: 105 additions & 0 deletions test/test_examples.py
@@ -551,3 +551,108 @@ def _softmax_decomposed_make_precompiler(x: torch.Tensor):
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_softmax_decomposed_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, num_warps=4, num_stages=1)""",
        )

    def test_embedding_pointers(self):
        args = (
            torch.randint(0, 1024, [8, 128], device=DEVICE, dtype=torch.int32),
            torch.randn([1024, 256], device=DEVICE, dtype=torch.float16),
        )
        self.assertExpectedInline(
            run_example(
                "embedding",
                args,
                torch.nn.functional.embedding(*args),
                block_size=[1, 256],
                indexing="pointer",
            ),
            """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _embedding_kernel(x_flat, weight, out, x_size_0, x_size_1, out_stride_0, out_stride_1, weight_stride_0, weight_stride_1, x_flat_stride_0, embedding_dim, _BLOCK_SIZE_1: tl.constexpr):
    num_blocks_0 = x_size_0 * x_size_1
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    offset_1 = pid_1 * _BLOCK_SIZE_1
    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < embedding_dim
    load = tl.load(x_flat + indices_0 * x_flat_stride_0, None)
    load_1 = tl.load(weight + (load[:, None] * weight_stride_0 + indices_1[None, :] * weight_stride_1), mask_1[None, :], other=0)
    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), load_1, mask_1[None, :])

def embedding(x: torch.Tensor, weight: torch.Tensor):
    x_flat = x.reshape(-1)
    _, embedding_dim = weight.size()
    out = torch.empty([x_flat.size(0), embedding_dim], dtype=weight.dtype, device=weight.device)
    _BLOCK_SIZE_1 = 256
    _embedding_kernel[x.size(0) * x.size(1) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),](x_flat, weight, out, x.size(0), x.size(1), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out.view(*x.size(), embedding_dim)

def _embedding_make_precompiler(x: torch.Tensor, weight: torch.Tensor):
    x_flat = x.reshape(-1)
    _, embedding_dim = weight.size()
    out = torch.empty([x_flat.size(0), embedding_dim], dtype=weight.dtype, device=weight.device)
    _BLOCK_SIZE_1 = 256
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_embedding_kernel)(x_flat, weight, out, x.size(0), x.size(1), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
        )

    def test_embedding_block_ptr(self):
        args = (
            torch.randint(0, 1024, [8, 128], device=DEVICE, dtype=torch.int32),
            torch.randn([1024, 256], device=DEVICE, dtype=torch.float16),
        )
        self.assertExpectedInline(
            run_example(
                "embedding",
                args,
                torch.nn.functional.embedding(*args),
                block_size=[8, 64],
                indexing="block_ptr",
                use_yz_grid=True,
            ),
            """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _embedding_kernel(x_flat, weight, out, out_size_0, out_size_1, x_size_0, x_size_1, x_flat_size_0, out_stride_0, out_stride_1, weight_stride_0, weight_stride_1, x_flat_stride_0, embedding_dim, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
    mask_0 = indices_0 < x_size_0 * x_size_1
    offset_1 = pid_1 * _BLOCK_SIZE_1
    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < embedding_dim
    load = tl.load(tl.make_block_ptr(x_flat, [x_flat_size_0], [x_flat_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
    load_1 = tl.load(weight + (load[:, None] * weight_stride_0 + indices_1[None, :] * weight_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), load_1, boundary_check=[0, 1])

def embedding(x: torch.Tensor, weight: torch.Tensor):
    x_flat = x.reshape(-1)
    _, embedding_dim = weight.size()
    out = torch.empty([x_flat.size(0), embedding_dim], dtype=weight.dtype, device=weight.device)
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 64
    _embedding_kernel[triton.cdiv(x.size(0) * x.size(1), _BLOCK_SIZE_0), triton.cdiv(embedding_dim, _BLOCK_SIZE_1)](x_flat, weight, out, out.size(0), out.size(1), x.size(0), x.size(1), x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out.view(*x.size(), embedding_dim)

def _embedding_make_precompiler(x: torch.Tensor, weight: torch.Tensor):
    x_flat = x.reshape(-1)
    _, embedding_dim = weight.size()
    out = torch.empty([x_flat.size(0), embedding_dim], dtype=weight.dtype, device=weight.device)
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 64
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_embedding_kernel)(x_flat, weight, out, out.size(0), out.size(1), x.size(0), x.size(1), x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
        )