From f9a8952efa6dc70600e63ee8b3d979606e27ac4a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 24 Apr 2025 13:30:42 -0700 Subject: [PATCH] Support persistent reductions --- examples/template_via_closure.py | 2 +- helion/_compiler/compile_environment.py | 56 ++++--- helion/_compiler/device_function.py | 2 +- helion/_compiler/host_function.py | 1 - helion/_compiler/indexing_strategy.py | 22 +++ helion/_compiler/inductor_lowering.py | 109 ++++++++++++-- helion/_compiler/reduction_strategy.py | 125 ++++++++++++++++ helion/_compiler/tile_dispatch.py | 134 +++++++++++++++++ helion/_compiler/tile_strategy.py | 152 +++++-------------- helion/_compiler/variable_origin.py | 11 +- helion/runtime/kernel.py | 1 + test/test_loops.py | 2 + test/test_matmul.py | 2 + test/test_reductions.py | 185 ++++++++++++++++++++++++ test/test_type_propagation.py | 160 ++++++++++---------- 15 files changed, 732 insertions(+), 232 deletions(-) create mode 100644 helion/_compiler/reduction_strategy.py create mode 100644 helion/_compiler/tile_dispatch.py create mode 100644 test/test_reductions.py diff --git a/examples/template_via_closure.py b/examples/template_via_closure.py index 0c7b05d5..4e64c8c2 100644 --- a/examples/template_via_closure.py +++ b/examples/template_via_closure.py @@ -52,7 +52,7 @@ def check(n: int, k: int, m: int) -> None: x = torch.randn([n, k], device="cuda", dtype=torch.float16) y = torch.randn([k, m], device="cuda", dtype=torch.float16) bias = torch.randn([1, m], device="cuda", dtype=torch.float16) - # The epilogue can use the captured bias tensor that is implicitly lifted to an arg + # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg result = matmul_with_epilogue(x, y, lambda acc, tile: torch.relu(acc + bias[tile])) torch.testing.assert_close( result, diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 9375b044..26801702 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -76,25 +76,20 @@ def finalize_config_spec(self) -> None: self.config_spec.block_size_specs, shape ) - def allocate_block_size(self, numel: int | torch.SymInt) -> int: + def allocate_block_size( + self, size: int | torch.SymInt, *, reduction: bool = False + ) -> int: idx = len(self.block_sizes) - if isinstance(numel, torch.SymInt): - numel_expr = numel._sympy_() - else: - numel_expr = sympy.sympify(numel) - with self.shape_env.ignore_fresh_unbacked_symbols(): - sym = self.shape_env.create_unbacked_symint() - assert isinstance(sym._sympy_(), sympy.Symbol) self.block_sizes.append( info := BlockSizeInfo( block_size_idx=idx, - numel=numel_expr, - var=sym, + size=size, + var=self.create_block_var( + f"block_size_{idx}" if not reduction else f"rdim_{idx}" + ), + reduction=reduction, ) ) - self.debug_shape_renames[sym._sympy_()] = sympy.Symbol( - info.name(), integer=True - ) from .host_function import HostFunction from .host_function import SymbolOrigin @@ -104,6 +99,31 @@ def allocate_block_size(self, numel: int | torch.SymInt) -> int: ) return idx + def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo: + for rdim in self.block_sizes: + if rdim.reduction and rdim.size == size: + return rdim + rdim_idx = self.allocate_block_size(size, reduction=True) + return self.block_sizes[rdim_idx] + + def create_block_var(self, debug_name: str) -> torch.SymInt: + with self.shape_env.ignore_fresh_unbacked_symbols(): + sym = self.shape_env.create_unbacked_symint() + # 
self.shape_env.guards.append( + # ShapeGuard( + # sympy.Ne(sym._sympy_(), 0), + # SLoc("create_block_var", current_location().format()), + # True, + # ) + # ) + # TODO(jansel): I was hoping the above would work, seems like some decomps require concrete values + # to determine zeroness. Figure out a better way to do this. + # pyre-ignore[29] + self.shape_env.var_to_val[sym._sympy_()] = sympy.Integer(64) + assert isinstance(sym._sympy_(), sympy.Symbol) + self.debug_shape_renames[sym._sympy_()] = sympy.Symbol(debug_name, integer=True) + return sym + def to_fake(self, obj: object, origin: Origin) -> object: if isinstance(obj, torch.Tensor): return self._to_fake_tensor(obj, origin.to_source()) @@ -233,15 +253,17 @@ class BlockSizeInfo(typing.NamedTuple): """ block_size_idx: int - numel: sympy.Expr + size: torch.SymInt | int var: torch.SymInt + reduction: bool + + @property + def numel(self) -> sympy.Expr: + return _to_sympy(self.size) def symbol(self) -> sympy.Symbol: return self.var._sympy_() - def name(self) -> str: - return f"block_size{self.block_size_idx}" - def warning(warning: exc.BaseWarning | type[exc.BaseWarning]) -> None: CompileEnvironment.current().errors.add(warning) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 88d06895..947eb43c 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -148,7 +148,7 @@ def __init__(self, name: str, config: Config) -> None: self.dce_vars: list[str] = [] from .indexing_strategy import IndexingStrategy - from .tile_strategy import TileStrategyDispatch + from .tile_dispatch import TileStrategyDispatch self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config) self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config) diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py index 3db13a0a..0596650c 100644 --- a/helion/_compiler/host_function.py +++ b/helion/_compiler/host_function.py @@ -97,7 +97,6 @@ def __init__(self, fn: types.FunctionType, fake_args: list[object]) -> None: env.errors.raise_if_errors() env.finalize_config_spec() self.device_ir = lower_to_device_ir(self) - # TODO(jansel): assert we don't have any extra decorators # TODO(jansel): check type annotations for hl.constexpr/hl.specialize diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 0fa7b23d..7524c4b9 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -13,6 +13,7 @@ from .ast_extension import expr_from_string from .compile_environment import CompileEnvironment from .host_function import HostFunction +from .tile_strategy import TileStrategy from .variable_origin import BlockSizeOrigin if TYPE_CHECKING: @@ -172,6 +173,15 @@ def compute_shape( output_size.append(k) else: output_size.append(1) + elif isinstance(k, slice) and str(k) == "slice(None, None, None)": + size = input_size.popleft() + if size != 1: + rdim = CompileEnvironment.current().allocate_reduction_dimension( + size + ) + output_size.append(rdim.var) + else: + output_size.append(1) else: raise exc.InvalidIndexingType(k) assert len(input_size) == 0, "invalid subscript" @@ -211,6 +221,18 @@ def create( else: val = state.device_function.literal_expr(k) index_values.append(f"tl.full([1], {val}, {dtype}){expand}") + elif isinstance(k, slice) and str(k) == "slice(None, None, None)": + expand = tile_strategy.expand_str(output_size, output_idx) + if fake_value.size(len(index_values)) != 1: + 
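+                # A full slice (`:`) over a non-unit dim maps to the reduction
+                # dimension allocated in compute_shape above: reuse its index
+                # variable and, when the dim size is not a power of two, its
+                # mask. A minimal sketch of a kernel that reaches this path
+                # (taken from this patch's own tests):
+                #   for tile_n in hl.tile(n):
+                #       out[tile_n] = x[tile_n, :].sum(-1)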
block_idx = TileStrategy.get_block_index(output_size[output_idx]) + assert block_idx is not None + index_var = tile_strategy.index_var(block_idx) + index_values.append(f"({index_var}){expand}") + if mask := tile_strategy.mask_var(block_idx): + mask_values.setdefault(f"({mask}){expand}") + else: + index_values.append(f"tl.zeros([1], {dtype}){expand}") + output_idx += 1 else: raise exc.InvalidIndexingType(k) assert len(output_size) == output_idx diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 7ada040d..f2b8c213 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import contextlib import dataclasses import functools from operator import getitem @@ -39,12 +40,14 @@ from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment from .tile_strategy import TileStrategy -from .tile_strategy import TileStrategyDispatch if TYPE_CHECKING: + from collections.abc import Iterator + from .. import Config from .device_function import DeviceFunction from .generate_ast import GenerateAST + from .tile_dispatch import TileStrategyDispatch CodegenHandler = Callable[["GraphInterpreter", torch.fx.Node], object] @@ -63,12 +66,7 @@ def prepare_graph_lowerings(gm: torch.fx.GraphModule) -> None: "output", }, node.op if node.op == "call_function": - prior_buffers = len(graph_lowering.buffers) node.meta["lowering"] = prepare_node_lowering(graph_lowering, node) - if len(graph_lowering.buffers) > prior_buffers + 1: - raise InductorLoweringError( - f"Lowering {node.op} resulted in {len(graph_lowering.buffers) - prior_buffers} buffers, expected 1." - ) def prepare_node_lowering( @@ -117,6 +115,7 @@ def convert_arg(arg: Node) -> TensorBox: ) ) + prior_buffers = len(graph_lowering.buffers) input_names: list[str] = [] result = graph_lowering.call_function( # pyre-ignore[6] @@ -133,6 +132,16 @@ def convert_arg(arg: Node) -> TensorBox: raise InductorLoweringError( f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer: {buffer}" ) + + new_buffers = graph_lowering.buffers[prior_buffers:] + if len(new_buffers) != 1: + # TODO(jansel): handle multiple buffers more generally + raise InductorLoweringError( + f"Lowering {node.op} resulted in {len(new_buffers)} buffers, expected 1." 
+ ) + + assert new_buffers[0] is buffer + assert isinstance(buffer, ComputedBuffer) if isinstance(buffer.data, Pointwise): return PointwiseLowering(buffer, input_names) if isinstance(buffer.data, Reduction): @@ -169,21 +178,33 @@ def input_asts(self, ctx: GraphInterpreter, node: torch.fx.Node) -> list[ast.AST assert len(input_asts) == len(self.input_names) return input_asts + @staticmethod + def input_fake_tensors(node: torch.fx.Node) -> list[torch.Tensor]: + def visit(n: torch.fx.Node) -> torch.fx.Node: + if isinstance(val := n.meta["val"], torch.Tensor): + result.append(val) + return n + + result: list[torch.Tensor] = [] + map_arg((node.args, node.kwargs), visit) + return result + def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: raise NotImplementedError( f"codegen not implemented for {type(self).__name__}: {self.buffer}" ) - -@functools.cache -def dummy_gm() -> torch.fx.GraphModule: - return torch.fx.symbolic_trace(lambda: None) - - -class PointwiseLowering(InductorLowering): - def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: + @contextlib.contextmanager + def install_kernel_handlers( + self, ctx: GraphInterpreter, node: torch.fx.Node + ) -> Iterator[None]: with ( - inductor_config.patch("triton.codegen_upcast_to_fp32", False), + inductor_config.patch( + { + "triton.codegen_upcast_to_fp32": False, + "split_reductions": False, + } + ), # pyre-ignore[19] V.set_graph_handler( GraphLowering( @@ -201,6 +222,17 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: TritonKernel({}, features=SIMDKernelFeatures([], sympy.S.One)) ), ): + yield + + +@functools.cache +def dummy_gm() -> torch.fx.GraphModule: + return torch.fx.symbolic_trace(lambda: None) + + +class PointwiseLowering(InductorLowering): + def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: + with self.install_kernel_handlers(ctx, node): indices = [ sympy.Symbol(f"i{n}") for n in range(len(self.buffer.data.ranges)) ] @@ -209,7 +241,52 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: class ReductionLowering(InductorLowering): - pass + def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: + reduction = self.buffer.data + assert isinstance(reduction, Reduction) + indices = [sympy.Symbol(f"i{n}") for n in range(len(reduction.ranges))] + reduction_indices = [ + sympy.Symbol(f"i{n}") + for n in range(len(indices), len(indices) + len(reduction.reduction_ranges)) + ] + with self.install_kernel_handlers(ctx, node): + # codegen the pointwise part before reduction + output_name = _unpack_opsvalue( + self.buffer.data.inner_fn(indices, reduction_indices) + ) + + reduction_ranges = reduction.reduction_ranges + if len(reduction_ranges) != 1: + # TODO(jansel): can this happen? 
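+            # (Assumption, not verified here: Inductor may emit more than one
+            # entry in reduction_ranges for reductions over several dims at
+            # once, e.g. x.sum(dim=(0, 1)), so such cases stay loudly
+            # unsupported for now.)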
+ raise NotImplementedError("multiple reduction dimensions") + reduction_var = reduction_ranges[0] + assert isinstance(reduction_var, sympy.Symbol) + + block_idx = TileStrategy.get_block_index(reduction_var) + assert block_idx is not None + strategy = ctx.cg.device_function.tile_strategy.get_reduction_strategy( + block_idx + ) + + inputs = self.input_fake_tensors(node) + if len(inputs) != 1: + # TODO(jansel): combine multiple inputs into a single fake value + raise NotImplementedError("reductions with >1 input") + + # TODO(jansel): find a better way to get dim + (dim,) = [ + i + for i, v in enumerate(inputs[0].shape) + if TileStrategy.get_block_index(v) == block_idx + ] + + return strategy.codegen_reduction( + output_name, + reduction.reduction_type, + dim, + inputs[0], + node.meta["val"], + ) @dataclasses.dataclass diff --git a/helion/_compiler/reduction_strategy.py b/helion/_compiler/reduction_strategy.py new file mode 100644 index 00000000..4ac6a10b --- /dev/null +++ b/helion/_compiler/reduction_strategy.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import sympy + +from ..autotuner.config_fragment import integer_power_of_two +from .ast_extension import expr_from_string +from .ast_extension import statement_from_string +from .compile_environment import CompileEnvironment +from .device_function import DeviceFunction +from .host_function import HostFunction +from .tile_strategy import CompactedShape +from .tile_strategy import TileStrategy + +if TYPE_CHECKING: + import ast + + import torch + + from .inductor_lowering import CodegenState + + +class ReductionStrategy(TileStrategy): + def __init__( + self, + fn: DeviceFunction, + block_index: int, + ) -> None: + super().__init__( + fn=fn, + block_indices=[block_index], + ) + + @property + def block_index(self) -> int: + return self.block_indices[0] + + def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]: + return shapes + + def codegen_reduction( + self, + input_name: str, + reduction_type: str, + dim: int, + fake_input: torch.Tensor, + fake_output: torch.Tensor, + ) -> ast.AST: + raise NotImplementedError + + def call_reduction_function( + self, input_name: str, reduction_type: str, dim: int + ) -> str: + if reduction_type in {"sum", "max", "min", "argmax", "argmin"}: + # TODO(jansel): some of the above have different NaN handling than torch, we may want to take the triton_helpers version + return f"tl.{reduction_type}({input_name}, {dim})" + if reduction_type == "prod": + return f"triton_helpers.prod({input_name}, {dim})" + raise NotImplementedError(f"Unsupported reduction type: {reduction_type}") + + +class PersistentReductionStrategy(ReductionStrategy): + def __init__( + self, + fn: DeviceFunction, + block_index: int, + ) -> None: + super().__init__( + fn=fn, + block_index=block_index, + ) + env = CompileEnvironment.current() + numel = env.block_sizes[block_index].numel + if isinstance(numel, (int, sympy.Integer)) and integer_power_of_two(int(numel)): + self._mask_var: str | None = None + else: + self._mask_var = self.fn.new_var(f"mask_{block_index}", dce=True) + self._block_size_var: str = self.fn.new_var(f"_RDIM_SIZE_{block_index}") + self.offset_vars[block_index] = "0" + + def mask_var(self, block_idx: int) -> str | None: + assert block_idx == self.block_index + return self._mask_var + + def block_size_var(self, block_idx: int) -> str | None: + assert block_idx == self.block_index + return self._block_size_var + + def codegen_preamble(self, state: CodegenState) 
-> None: + env = CompileEnvironment.current() + block_idx = self.block_index + numel = env.block_sizes[block_idx].numel + index_var = self.index_var(block_idx) + mask_var = self._mask_var + block_size_var = self._block_size_var + state.codegen.host_statements.append( + statement_from_string( + f"{block_size_var} = triton.next_power_of_2({HostFunction.current().sympy_expr(numel)})" + ) + ) + state.device_function.constexpr_arg(block_size_var) + state.add_statement( + f"{index_var} = tl.arange(0, {block_size_var}).to({env.triton_index_type()})" + ) + if mask_var is not None: + state.add_statement( + f"{mask_var} = {index_var} < {self.fn.sympy_expr(numel)}" + ) + + def codegen_reduction( + self, + input_name: str, + reduction_type: str, + dim: int, + fake_input: torch.Tensor, + fake_output: torch.Tensor, + ) -> ast.AST: + expr = self.call_reduction_function(input_name, reduction_type, dim) + size = [*fake_input.size()] + size.pop(dim) + if [*fake_output.size()] == size: + return expr_from_string(expr) + shape = DeviceFunction.current().tile_strategy.shape_str([*fake_output.size()]) + return expr_from_string(f"tl.reshape({expr}, {shape})") diff --git a/helion/_compiler/tile_dispatch.py b/helion/_compiler/tile_dispatch.py new file mode 100644 index 00000000..f9175603 --- /dev/null +++ b/helion/_compiler/tile_dispatch.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import collections +from typing import TYPE_CHECKING + +from helion._compiler.compile_environment import CompileEnvironment +from helion._compiler.reduction_strategy import PersistentReductionStrategy +from helion._compiler.reduction_strategy import ReductionStrategy +from helion._compiler.tile_strategy import CompactedShape +from helion._compiler.tile_strategy import FlattenedTileStrategy +from helion._compiler.tile_strategy import NDTileStrategy +from helion._compiler.tile_strategy import TileStrategy + +if TYPE_CHECKING: + import ast + from collections.abc import Sequence + + import torch + + from helion import Config + from helion._compiler.device_function import DeviceFunction + from helion._compiler.inductor_lowering import CodegenState + + SymIntLike = torch.SymInt | int + ShapeLike = Sequence[SymIntLike] + + +class TileStrategyDispatch: + def __init__( + self, + fn: DeviceFunction, + config: Config, + ) -> None: + super().__init__() + env = CompileEnvironment.current() + specs = env.config_spec.block_size_specs + block_size_idx = iter( + [bs.block_size_idx for bs in env.block_sizes if not bs.reduction] + ) + block_sizes = config.block_sizes + loop_orders = collections.deque(config.loop_orders) + assert len(block_sizes) == len(specs) + self.block_index_to_strategy: dict[int, TileStrategy] = {} + self.strategies: list[TileStrategy] = [] + for spec, block_size in zip(specs, block_sizes): + block_indices = [next(block_size_idx) for _ in range(len(spec))] + if spec.allow_reorder: + loop_order = loop_orders.popleft() + else: + loop_order = [*range(len(spec))] + strategy_cls = ( + FlattenedTileStrategy if isinstance(block_size, int) else NDTileStrategy + ) + strategy = strategy_cls(fn, block_indices, spec, block_size, loop_order) + self.strategies.append(strategy) + for idx in block_indices: + self.block_index_to_strategy[idx] = strategy + assert not loop_orders + rdims = [bs.block_size_idx for bs in env.block_sizes if bs.reduction] + for rdim_index in rdims: + # TODO(jansel): add looped reduction config choices + strategy = PersistentReductionStrategy(fn, rdim_index) + self.strategies.append(strategy) + 
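+            # Route this rdim's index_var/mask_var/block_size_var lookups to
+            # the persistent strategy, mirroring the block-size mapping above.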
self.block_index_to_strategy[rdim_index] = strategy + + def offset_var(self, block_idx: int) -> str: + return self.block_index_to_strategy[block_idx].offset_var(block_idx) + + def index_var(self, block_idx: int) -> str: + return self.block_index_to_strategy[block_idx].index_var(block_idx) + + def mask_var(self, block_idx: int) -> str | None: + return self.block_index_to_strategy[block_idx].mask_var(block_idx) + + def need_mask(self, block_idx: int) -> bool: + return self.block_index_to_strategy[block_idx].mask_var(block_idx) is not None + + def block_size_var(self, block_idx: int) -> str | None: + return self.block_index_to_strategy[block_idx].block_size_var(block_idx) + + def codegen_grid(self, state: CodegenState, block_indices: list[int]) -> None: + strategy = self.block_index_to_strategy[block_indices[0]] + assert strategy.block_indices == block_indices + strategy.codegen_grid(state) + for other_strategy in self.strategies: + if other_strategy is not strategy: + other_strategy.codegen_preamble(state) + + def codegen_device_loop( + self, state: CodegenState, block_indices: list[int] + ) -> tuple[ast.For, list[ast.AST]]: + strategy = self.block_index_to_strategy[block_indices[0]] + assert strategy.block_indices == block_indices + return strategy.codegen_device_loop(state) + + def _compact_shape(self, shapes: ShapeLike) -> list[CompactedShape]: + compacted_shapes = [] + for idx, shape in enumerate(shapes): + block_idx = TileStrategy.get_block_index(shape) + if block_idx is None: + compacted_shapes.append( + CompactedShape(self.strategies[0].fn.literal_expr(shape), [idx], []) + ) + else: + block_size = self.block_size_var(block_idx) + if block_size is None: + block_size = "1" + compacted_shapes.append(CompactedShape(block_size, [idx], [block_idx])) + for strategy in self.strategies: + compacted_shapes = strategy.compact_shape(compacted_shapes) + return compacted_shapes + + def shape_str(self, shape: ShapeLike) -> str: + compacted_shapes = self._compact_shape(shape) + result = [s.size_str for s in compacted_shapes] + return f"[{', '.join(result)}]" + + def expand_str(self, shape: ShapeLike, i: int) -> str: + assert 0 <= i < len(shape) + compacted_shapes = self._compact_shape(shape) + result = [] + for dim in compacted_shapes: + if i in dim.user_indices: + result.append(":") + else: + result.append("None") + if result == [":"]: + return "" + return f"[{', '.join(result)}]" + + def get_reduction_strategy(self, block_idx: int) -> ReductionStrategy: + strategy = self.block_index_to_strategy[block_idx] + assert isinstance(strategy, ReductionStrategy) + return strategy diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index cbf37af8..9ebf4076 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -2,7 +2,6 @@ import ast import collections -import dataclasses import functools import itertools import operator @@ -31,7 +30,6 @@ from collections.abc import Sequence from ..autotuner.config_spec import BlockSizeSpec - from ..runtime.config import Config from .device_function import DeviceFunction from .inductor_lowering import CodegenState @@ -40,27 +38,17 @@ ShapeLike = Sequence[SymIntLike] -@dataclasses.dataclass class TileStrategy: _fn: weakref.ReferenceType[DeviceFunction] block_indices: list[int] - spec: BlockSizeSpec - block_size: list[int] | int - loop_order: list[int] def __init__( self, fn: DeviceFunction, block_indices: list[int], - spec: BlockSizeSpec, - block_size: list[int] | int, - loop_order: list[int], ) -> None: 
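+        # The base class now tracks only the device function plus per-block
+        # index/offset variables; spec/block_size/loop_order moved down into
+        # BlockSizeTileStrategy so ReductionStrategy can share this __init__
+        # without carrying tiling config it does not use.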
self._fn = weakref.ref(fn) self.block_indices = block_indices - self.spec = spec - self.block_size = block_size - self.loop_order = loop_order self.index_vars: dict[int, str] = { block_idx: self.fn.new_var(f"indices_{block_idx}", dce=True) for block_idx in block_indices @@ -76,16 +64,6 @@ def fn(self) -> DeviceFunction: assert fn is not None return fn - def _reorder(self, block_indices: list[_T]) -> list[_T]: - if len(block_indices) <= 1: - return block_indices - order = self.loop_order - assert len(order) == len(block_indices), ( - f"Invalid order length: {len(order)} != {len(block_indices)}" - ) - assert {*order} == {*range(len(order))}, f"Invalid permutation: {order}" - return [block_indices[i] for i in reversed(order)] - def offset_var(self, block_idx: int) -> str: return self.offset_vars[block_idx] @@ -104,6 +82,9 @@ def codegen_grid(self, state: CodegenState) -> None: def codegen_device_loop(self, state: CodegenState) -> tuple[ast.For, list[ast.AST]]: raise NotImplementedError + def codegen_preamble(self, state: CodegenState) -> None: + """Called after a *different* strategy has been used to generate the grid.""" + def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]: raise NotImplementedError @@ -120,7 +101,39 @@ def get_block_index(cls, size: int | torch.SymInt | sympy.Expr) -> int | None: return None -class FlattenedTileStrategy(TileStrategy): +class BlockSizeTileStrategy(TileStrategy): + spec: BlockSizeSpec + block_size: list[int] | int + loop_order: list[int] + + def __init__( + self, + fn: DeviceFunction, + block_indices: list[int], + spec: BlockSizeSpec, + block_size: list[int] | int, + loop_order: list[int], + ) -> None: + super().__init__( + fn=fn, + block_indices=block_indices, + ) + self.spec = spec + self.block_size = block_size + self.loop_order = loop_order + + def _reorder(self, block_indices: list[_T]) -> list[_T]: + if len(block_indices) <= 1: + return block_indices + order = self.loop_order + assert len(order) == len(block_indices), ( + f"Invalid order length: {len(order)} != {len(block_indices)}" + ) + assert {*order} == {*range(len(order))}, f"Invalid permutation: {order}" + return [block_indices[i] for i in reversed(order)] + + +class FlattenedTileStrategy(BlockSizeTileStrategy): """Collapse all dimensions into single flat iteration space.""" block_size: int @@ -291,7 +304,7 @@ def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]: return output -class NDTileStrategy(TileStrategy): +class NDTileStrategy(BlockSizeTileStrategy): """Do up to 3D tiling using the kernel grid.""" block_size: list[int] @@ -435,97 +448,6 @@ def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]: return shapes -class TileStrategyDispatch: - def __init__( - self, - fn: DeviceFunction, - config: Config, - ) -> None: - specs = CompileEnvironment.current().config_spec.block_size_specs - block_size_idx = itertools.count() - block_sizes = config.block_sizes - loop_orders = collections.deque(config.loop_orders) - assert len(block_sizes) == len(specs) - self.block_index_to_strategy: dict[int, TileStrategy] = {} - self.strategies: list[TileStrategy] = [] - for spec, block_size in zip(specs, block_sizes): - block_indices = [next(block_size_idx) for _ in range(len(spec))] - if spec.allow_reorder: - loop_order = loop_orders.popleft() - else: - loop_order = [*range(len(spec))] - strategy_cls = ( - FlattenedTileStrategy if isinstance(block_size, int) else NDTileStrategy - ) - strategy = strategy_cls(fn, block_indices, spec, 
block_size, loop_order) - self.strategies.append(strategy) - for idx in block_indices: - self.block_index_to_strategy[idx] = strategy - assert not loop_orders - - def offset_var(self, block_idx: int) -> str: - return self.block_index_to_strategy[block_idx].offset_var(block_idx) - - def index_var(self, block_idx: int) -> str: - return self.block_index_to_strategy[block_idx].index_var(block_idx) - - def mask_var(self, block_idx: int) -> str | None: - return self.block_index_to_strategy[block_idx].mask_var(block_idx) - - def need_mask(self, block_idx: int) -> bool: - return self.block_index_to_strategy[block_idx].mask_var(block_idx) is not None - - def block_size_var(self, block_idx: int) -> str | None: - return self.block_index_to_strategy[block_idx].block_size_var(block_idx) - - def codegen_grid(self, state: CodegenState, block_indices: list[int]) -> None: - strategy = self.block_index_to_strategy[block_indices[0]] - assert strategy.block_indices == block_indices - return strategy.codegen_grid(state) - - def codegen_device_loop( - self, state: CodegenState, block_indices: list[int] - ) -> tuple[ast.For, list[ast.AST]]: - strategy = self.block_index_to_strategy[block_indices[0]] - assert strategy.block_indices == block_indices - return strategy.codegen_device_loop(state) - - def _compact_shape(self, shapes: ShapeLike) -> list[CompactedShape]: - compacted_shapes = [] - for idx, shape in enumerate(shapes): - block_idx = TileStrategy.get_block_index(shape) - if block_idx is None: - compacted_shapes.append( - CompactedShape(self.strategies[0].fn.literal_expr(shape), [idx], []) - ) - else: - block_size = self.block_size_var(block_idx) - if block_size is None: - block_size = "1" - compacted_shapes.append(CompactedShape(block_size, [idx], [block_idx])) - for strategy in self.strategies: - compacted_shapes = strategy.compact_shape(compacted_shapes) - return compacted_shapes - - def shape_str(self, shape: ShapeLike) -> str: - compacted_shapes = self._compact_shape(shape) - result = [s.size_str for s in compacted_shapes] - return f"[{', '.join(result)}]" - - def expand_str(self, shape: ShapeLike, i: int) -> str: - assert 0 <= i < len(shape) - compacted_shapes = self._compact_shape(shape) - result = [] - for dim in compacted_shapes: - if i in dim.user_indices: - result.append(":") - else: - result.append("None") - if result == [":"]: - return "" - return f"[{', '.join(result)}]" - - class CompactedShape(NamedTuple): size_str: str user_indices: list[int] diff --git a/helion/_compiler/variable_origin.py b/helion/_compiler/variable_origin.py index 3aea1cd5..a24a0f08 100644 --- a/helion/_compiler/variable_origin.py +++ b/helion/_compiler/variable_origin.py @@ -136,7 +136,6 @@ class WrappedOrigin(Origin): """Keeps track of where a variable came from.""" value: Origin - key: int | str def base_type(self) -> type[Origin]: return self.value.base_type() @@ -164,6 +163,8 @@ def to_source(self) -> Source: @dataclasses.dataclass class GetItemOrigin(WrappedOrigin): + key: int | str + def host_str(self) -> str: return f"{self.value.host_str()}[{self.key!r}]" @@ -230,3 +231,11 @@ def host_str(self) -> str: ) assert host_str is not None return host_str + + +@dataclasses.dataclass +class ReductionDimensionOrigin(Origin): + rdim_idx: int + + def host_str(self) -> str: + raise NotImplementedError diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 37cfe086..db05227f 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -430,6 +430,7 @@ def _function_key(fn: Kernel, obj: 
types.FunctionType) -> object: sorted((k, fn._specialization_key(v)) for k, v in x.items()) ), types.FunctionType: _function_key, + types.BuiltinFunctionType: lambda fn, x: x, } diff --git a/test/test_loops.py b/test/test_loops.py index 515777bd..198bcc27 100644 --- a/test/test_loops.py +++ b/test/test_loops.py @@ -28,6 +28,8 @@ def device_loop_3d(x: torch.Tensor) -> torch.Tensor: class TestLoops(TestCase): + maxDiff = 16384 + def test_pointwise_device_loop(self): args = (torch.randn([512, 512], device=DEVICE),) code, result = code_and_output( diff --git a/test/test_matmul.py b/test/test_matmul.py index ccd7d4e9..1d0e55d4 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -66,6 +66,8 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: class TestMatmul(TestCase): + maxDiff = 16384 + def test_matmul0(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), diff --git a/test/test_reductions.py b/test/test_reductions.py new file mode 100644 index 00000000..eb778854 --- /dev/null +++ b/test/test_reductions.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import Callable +import unittest + +from expecttest import TestCase +import torch + +import helion +from helion._testing import code_and_output +import helion.language as hl + + +@helion.kernel() +def sum_kernel(x: torch.Tensor) -> torch.Tensor: + n, _m = x.size() + out = torch.empty( + [n], + dtype=x.dtype, + device=x.device, + ) + for tile_n in hl.tile(n): + out[tile_n] = x[tile_n, :].sum(-1) + return out + + +@helion.kernel() +def sum_kernel_keepdims(x: torch.Tensor) -> torch.Tensor: + _n, m = x.size() + out = torch.empty( + [1, m], + dtype=x.dtype, + device=x.device, + ) + for tile_m in hl.tile(m): + out[:, tile_m] = x[:, tile_m].sum(0, keepdim=True) + return out + + +@helion.kernel(config={"block_sizes": [1]}) +def reduce_kernel( + x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32 +) -> torch.Tensor: + n, _m = x.size() + out = torch.empty( + [n], + dtype=out_dtype, + device=x.device, + ) + for tile_n in hl.tile(n): + out[tile_n] = fn(x[tile_n, :], dim=-1) + return out + + +class TestReductions(TestCase): + maxDiff = 16384 + + def test_sum(self): + args = (torch.randn([512, 512], device="cuda"),) + code, output = code_and_output(sum_kernel, args, block_size=1) + torch.testing.assert_close(output, args[0].sum(-1)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _sum_kernel_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, _m, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < _m + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) + sum_1 = tl.sum(load, 1) + tl.store(out + indices_0 * out_stride_0, sum_1, None) + +def sum_kernel(x: torch.Tensor): + n, _m = x.size() + out = torch.empty([n], dtype=x.dtype, device=x.device) + _RDIM_SIZE_1 = triton.next_power_of_2(_m) + _sum_kernel_kernel[n,](x, out, out.stride(0), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return out""", + ) + + def test_sum_keepdims(self): + args = (torch.randn([512, 512], device="cuda"),) + code, output = code_and_output( + sum_kernel_keepdims, args, block_size=16, index="block_ptr" + ) + 
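+        # Persistent reduction over dim 0 (the `:` slice) with dim 1 tiled 16
+        # wide; _RDIM_SIZE_1 in the expected code below is computed on the
+        # host as triton.next_power_of_2(_n).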
torch.testing.assert_close(output, args[0].sum(0, keepdim=True)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _sum_kernel_keepdims_kernel(x, out, out_stride_1, x_stride_0, x_stride_1, m, _n, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) + mask_0 = indices_0 < m + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < _n + load = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_0[None, :] * x_stride_1), mask_1[:, None] & mask_0[None, :], other=0) + sum_1 = tl.reshape(tl.sum(load, 0), [1, _BLOCK_SIZE_0]) + tl.store(out + indices_0[None, :] * out_stride_1, sum_1, mask_0[None, :]) + +def sum_kernel_keepdims(x: torch.Tensor): + _n, m = x.size() + out = torch.empty([1, m], dtype=x.dtype, device=x.device) + _BLOCK_SIZE_0 = 16 + _RDIM_SIZE_1 = triton.next_power_of_2(_n) + _sum_kernel_keepdims_kernel[triton.cdiv(m, _BLOCK_SIZE_0),](x, out, out.stride(1), x.stride(0), x.stride(1), m, _n, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return out""", + ) + + def test_argmin_argmax(self): + for fn in (torch.argmin, torch.argmax): + args = (torch.randn([512, 512], device="cuda"), fn, torch.int64) + code, output = code_and_output( + reduce_kernel, args, block_size=16, index="block_ptr" + ) + torch.testing.assert_close(output, args[1](args[0], dim=-1)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _reduce_kernel_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) + mask_0 = indices_0 < n + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < _m + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + argmax = tl.argmax(load, 1) + tl.store(out + indices_0 * out_stride_0, argmax, mask_0) + +def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32): + n, _m = x.size() + out = torch.empty([n], dtype=out_dtype, device=x.device) + _BLOCK_SIZE_0 = 16 + _RDIM_SIZE_1 = triton.next_power_of_2(_m) + _reduce_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](x, out, out.stride(0), x.stride(0), x.stride(1), n, _m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return out""", + ) + + def test_reduction_functions(self): + for fn in ( + torch.amax, + torch.amin, + torch.prod, + torch.sum, + ): + args = (torch.randn([512, 512], device="cuda"), fn) + output = reduce_kernel(*args) + torch.testing.assert_close(output, fn(args[0], dim=-1)) + + @unittest.skip("need to fix handling of multiple inductor buffers") + def test_mean(self): + args = (torch.randn([512, 512], device="cuda"), torch.mean, torch.int64) + code, output = code_and_output( + reduce_kernel, args, block_size=8, index="block_ptr" + ) + torch.testing.assert_close(output, args[1](args[0], dim=-1)) + self.assertExpectedInline( + code, + """ + """, + ) diff --git a/test/test_type_propagation.py b/test/test_type_propagation.py index ece58eee..04e22e9b 100644 --- a/test/test_type_propagation.py +++ 
b/test/test_type_propagation.py @@ -55,14 +55,14 @@ def add(x, y): # Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=), key='size') # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=) for tile in hl.tile(out.size()): - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=) # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) - # BinOp: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # BinOp: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([y_size0, x_size1], torch.int32) GetItemOrigin(value=SourceOrigin(location=), key=0) # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([y_size0, x_size1], torch.int32) GetItemOrigin(value=SourceOrigin(location=), key=1) # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) out[tile] = x[tile] + y[tile] @@ -71,20 +71,20 @@ def add(x, y): def device_ir(): # File: .../basic_kernels.py:13 in add, code: out[tile] = x[tile] + y[tile] x: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('x') - block_size0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size0') - block_size1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size1') - load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size0, block_size1]); x = None + block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') + block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1') + load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1]); x = None # File: .../basic_kernels.py:13 in add, code: out[tile] = x[tile] + y[tile] y: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('y') - load_1: "i32[u0, u1]" = helion_language_memory_ops_load(y, [block_size0, block_size1]); y = None + load_1: "i32[u0, u1]" = helion_language_memory_ops_load(y, [block_size_0, block_size_1]); y = None # File: .../basic_kernels.py:13 in add, code: out[tile] = x[tile] + y[tile] add: "i32[u0, u1]" = torch.ops.aten.add.Tensor(load, load_1); load = load_1 = None # File: .../basic_kernels.py:13 in add, code: out[tile] = x[tile] + y[tile] out: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('out') - store = helion_language_memory_ops_store(out, [block_size0, block_size1], add); out = block_size0 = block_size1 = add = store = None + store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], add); out = block_size_0 = block_size_1 = add = store = None return None""", ) @@ -111,25 +111,25 @@ def torch_ops_pointwise(x, y): # Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=), key='size') # Name: TensorType([x_size0], torch.int32) SourceOrigin(location=) for tile in hl.tile(out.size()): - # Subscript: TensorType([block_size0], torch.int32) 
DeviceOrigin(location=) + # Subscript: TensorType([block_size_0], torch.int32) DeviceOrigin(location=) # Name: TensorType([x_size0], torch.int32) SourceOrigin(location=) # Name: SequenceType([TileIndexType(0)]) SourceOrigin(location=) - # Call: TensorType([block_size0], torch.float32) DeviceOrigin(location=) + # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.sigmoid) AttributeOrigin(value=GlobalOrigin(name='torch'), key='sigmoid') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') - # Call: TensorType([block_size0], torch.float32) DeviceOrigin(location=) + # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.add) AttributeOrigin(value=GlobalOrigin(name='torch'), key='add') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') - # Call: TensorType([block_size0], torch.float32) DeviceOrigin(location=) + # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.sin) AttributeOrigin(value=GlobalOrigin(name='torch'), key='sin') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') - # Subscript: TensorType([block_size0], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0], torch.int32) DeviceOrigin(location=) # Name: TensorType([x_size0], torch.int32) ArgumentOrigin(name='x') # Name: SequenceType([TileIndexType(0)]) SourceOrigin(location=) - # Call: TensorType([block_size0], torch.float32) DeviceOrigin(location=) + # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.cos) AttributeOrigin(value=GlobalOrigin(name='torch'), key='cos') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') - # Subscript: TensorType([block_size0], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0], torch.int32) DeviceOrigin(location=) # Name: TensorType([y_size0], torch.int32) ArgumentOrigin(name='y') # Name: SequenceType([TileIndexType(0)]) SourceOrigin(location=) out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile]))) @@ -138,15 +138,15 @@ def torch_ops_pointwise(x, y): def device_ir(): # File: .../basic_kernels.py:21 in torch_ops_pointwise, code: out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile]))) x: "i32[s77]" = helion_language__tracing_ops__host_tensor('x') - block_size0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size0') - load: "i32[u0]" = helion_language_memory_ops_load(x, [block_size0]); x = None + block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') + load: "i32[u0]" = helion_language_memory_ops_load(x, [block_size_0]); x = None # File: .../basic_kernels.py:21 in torch_ops_pointwise, code: out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile]))) sin: "f32[u0]" = torch.ops.aten.sin.default(load); load = None # File: .../basic_kernels.py:21 in torch_ops_pointwise, code: out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile]))) y: "i32[s17]" = helion_language__tracing_ops__host_tensor('y') - load_1: "i32[u0]" = helion_language_memory_ops_load(y, [block_size0]); y = None + load_1: "i32[u0]" = helion_language_memory_ops_load(y, [block_size_0]); y = None # File: .../basic_kernels.py:21 in torch_ops_pointwise, code: out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile]))) cos: "f32[u0]" = 
torch.ops.aten.cos.default(load_1); load_1 = None @@ -160,7 +160,7 @@ def device_ir(): # File: .../basic_kernels.py:21 in torch_ops_pointwise, code: out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile]))) convert_element_type: "i32[u0]" = torch.ops.prims.convert_element_type.default(sigmoid, torch.int32); sigmoid = None out: "i32[s77]" = helion_language__tracing_ops__host_tensor('out') - store = helion_language_memory_ops_store(out, [block_size0], convert_element_type); out = block_size0 = convert_element_type = store = None + store = helion_language_memory_ops_store(out, [block_size_0], convert_element_type); out = block_size_0 = convert_element_type = store = None return None""", ) @@ -574,14 +574,14 @@ def all_ast_nodes(x, y): # Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=), key='size') # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=) for tile in hl.tile(out.size()): - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([y_size0, x_size1], torch.int32) SourceOrigin(location=) # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) - # BinOp: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # BinOp: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='x') # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([y_size0, x_size1], torch.int32) ArgumentOrigin(name='y') # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) out[tile] = x[tile] + y[tile] @@ -590,20 +590,20 @@ def all_ast_nodes(x, y): def device_ir(): # File: .../all_ast_nodes.py:146 in all_ast_nodes, code: out[tile] = x[tile] + y[tile] x: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('x') - block_size0: "Sym(u21)" = helion_language__tracing_ops__get_symnode('block_size0') - block_size1: "Sym(u22)" = helion_language__tracing_ops__get_symnode('block_size1') - load: "i32[u21, u22]" = helion_language_memory_ops_load(x, [block_size0, block_size1]); x = None + block_size_0: "Sym(u21)" = helion_language__tracing_ops__get_symnode('block_size_0') + block_size_1: "Sym(u22)" = helion_language__tracing_ops__get_symnode('block_size_1') + load: "i32[u21, u22]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1]); x = None # File: .../all_ast_nodes.py:146 in all_ast_nodes, code: out[tile] = x[tile] + y[tile] y: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('y') - load_1: "i32[u21, u22]" = helion_language_memory_ops_load(y, [block_size0, block_size1]); y = None + load_1: "i32[u21, u22]" = helion_language_memory_ops_load(y, [block_size_0, block_size_1]); y = None # File: .../all_ast_nodes.py:146 in all_ast_nodes, code: out[tile] = x[tile] + y[tile] add: "i32[u21, u22]" = torch.ops.aten.add.Tensor(load, load_1); load = load_1 = None # File: .../all_ast_nodes.py:146 in all_ast_nodes, code: 
out[tile] = x[tile] + y[tile] out: "i32[s17, s27]" = helion_language__tracing_ops__host_tensor('out') - store = helion_language_memory_ops_store(out, [block_size0, block_size1], add); out = block_size0 = block_size1 = add = store = None + store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], add); out = block_size_0 = block_size_1 = add = store = None return None""", ) @@ -629,50 +629,50 @@ def hl_zeros_usage(x: torch.Tensor): # Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=), key='size') # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) for tile in hl.tile(out.size()): - # Call: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Call: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Attribute: CallableType(zeros) AttributeOrigin(value=GlobalOrigin(name='hl'), key='zeros') # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) # Attribute: LiteralType(torch.int32) AttributeOrigin(value=ArgumentOrigin(name='x'), key='dtype') # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') tmp = hl.zeros(tile, dtype=x.dtype) - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) tmp += x[tile] - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) tmp += x[tile] - # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) - # Name: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=) + # Name: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) out[tile] = tmp return out def device_ir(): # File: .../basic_kernels.py:29 in hl_zeros_usage, code: tmp = hl.zeros(tile, dtype=x.dtype) - block_size0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size0') - block_size1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size1') - tmp: "i32[u0, u1]" = helion_language_creation_ops_full([block_size0, block_size1], 0, torch.int32) + block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') + block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1') + tmp: "i32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0, torch.int32) # File: .../basic_kernels.py:30 in hl_zeros_usage, code: tmp += x[tile] x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x') - load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size0, block_size1]) + load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1]) # File: .../basic_kernels.py:30 in 
hl_zeros_usage, code: tmp += x[tile]
     tmp_1: "i32[u0, u1]" = torch.ops.aten.add.Tensor(tmp, load); tmp = load = None

     # File: .../basic_kernels.py:31 in hl_zeros_usage, code: tmp += x[tile]
-    load_1: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size0, block_size1]); x = None
+    load_1: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1]); x = None

     # File: .../basic_kernels.py:31 in hl_zeros_usage, code: tmp += x[tile]
     tmp_2: "i32[u0, u1]" = torch.ops.aten.add.Tensor(tmp_1, load_1); tmp_1 = load_1 = None

     # File: .../basic_kernels.py:32 in hl_zeros_usage, code: out[tile] = tmp
     out: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('out')
-    store = helion_language_memory_ops_store(out, [block_size0, block_size1], tmp_2); out = block_size0 = block_size1 = tmp_2 = store = None
+    store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], tmp_2); out = block_size_0 = block_size_1 = tmp_2 = store = None
     return None""",
         )
@@ -698,7 +698,7 @@ def hl_full_usage(x: torch.Tensor):
     # Attribute: TensorAttributeType AttributeOrigin(value=SourceOrigin(location=), key='size')
     # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=)
     for tile in hl.tile(out.size()):
-        # Call: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Call: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Attribute: CallableType(full) AttributeOrigin(value=GlobalOrigin(name='hl'), key='full')
         # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=)
@@ -706,43 +706,43 @@ def hl_full_usage(x: torch.Tensor):
         # Attribute: LiteralType(torch.int32) AttributeOrigin(value=ArgumentOrigin(name='x'), key='dtype')
         # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
         tmp = hl.full(tile, 1, dtype=x.dtype)
-        # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=)
         tmp += x[tile]
-        # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=)
         tmp += x[tile]
-        # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=)
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=)
-        # Name: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Name: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         out[tile] = tmp
     return out

 def device_ir():
     # File: .../basic_kernels.py:40 in hl_full_usage, code: tmp = hl.full(tile, 1, dtype=x.dtype)
-    block_size0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size0')
-    block_size1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size1')
-    tmp: "i32[u0, u1]" = helion_language_creation_ops_full([block_size0, block_size1], 1, torch.int32)
+    block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
+    block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
+    tmp: "i32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 1, torch.int32)

     # File: .../basic_kernels.py:41 in hl_full_usage, code: tmp += x[tile]
     x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
-    load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size0, block_size1])
+    load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1])

     # File: .../basic_kernels.py:41 in hl_full_usage, code: tmp += x[tile]
     tmp_1: "i32[u0, u1]" = torch.ops.aten.add.Tensor(tmp, load); tmp = load = None

     # File: .../basic_kernels.py:42 in hl_full_usage, code: tmp += x[tile]
-    load_1: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size0, block_size1]); x = None
+    load_1: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1]); x = None

     # File: .../basic_kernels.py:42 in hl_full_usage, code: tmp += x[tile]
     tmp_2: "i32[u0, u1]" = torch.ops.aten.add.Tensor(tmp_1, load_1); tmp_1 = load_1 = None

     # File: .../basic_kernels.py:43 in hl_full_usage, code: out[tile] = tmp
     out: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('out')
-    store = helion_language_memory_ops_store(out, [block_size0, block_size1], tmp_2); out = block_size0 = block_size1 = tmp_2 = store = None
+    store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], tmp_2); out = block_size_0 = block_size_1 = tmp_2 = store = None
     return None""",
         )

@@ -775,15 +775,15 @@ def pointwise_device_loop(x: torch.Tensor):
     # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
     # Name: SymIntType(s27) GetItemOrigin(value=AttributeOrigin(value=ArgumentOrigin(name='x'), key='shape'), key=1)
     for tile_m in hl.tile(m):
-        # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=)
         # Name: TileIndexType(0) SourceOrigin(location=)
         # Name: TileIndexType(1) DeviceOrigin(location=)
-        # Call: TensorType([block_size0, block_size1], torch.float32) DeviceOrigin(location=)
+        # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
         # Attribute: CallableType(_VariableFunctionsClass.sigmoid) AttributeOrigin(value=GlobalOrigin(name='torch'), key='sigmoid')
         # Name: PythonModuleType(torch) GlobalOrigin(name='torch')
-        # BinOp: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
-        # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # BinOp: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
         # Name: TileIndexType(0) SourceOrigin(location=)
         # Name: TileIndexType(1) DeviceOrigin(location=)
@@ -794,9 +794,9 @@ def pointwise_device_loop(x: torch.Tensor):
 def subgraph_0():
     # File: .../basic_kernels.py:53 in pointwise_device_loop, code: out[tile_n, tile_m] = torch.sigmoid(x[tile_n, tile_m] + 1)
     x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
-    block_size0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size0')
-    block_size1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size1')
-    load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size0, block_size1]); x = None
+    block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
+    block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
+    load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1]); x = None

     # File: .../basic_kernels.py:53 in pointwise_device_loop, code: out[tile_n, tile_m] = torch.sigmoid(x[tile_n, tile_m] + 1)
     add: "i32[u0, u1]" = torch.ops.aten.add.Tensor(load, 1); load = None
@@ -807,7 +807,7 @@ def subgraph_0():
     # File: .../basic_kernels.py:53 in pointwise_device_loop, code: out[tile_n, tile_m] = torch.sigmoid(x[tile_n, tile_m] + 1)
     convert_element_type: "i32[u0, u1]" = torch.ops.prims.convert_element_type.default(sigmoid, torch.int32); sigmoid = None
     out: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('out')
-    store = helion_language_memory_ops_store(out, [block_size0, block_size1], convert_element_type); out = block_size0 = block_size1 = convert_element_type = store = None
+    store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], convert_element_type); out = block_size_0 = block_size_1 = convert_element_type = store = None
     return []

 def device_ir():
@@ -845,12 +845,12 @@ def fn(x):
     # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
     # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
     for tile in hl.tile(x.size()):
-        # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=)
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=)
-        # Call: TensorType([block_size0, block_size1], torch.float32) DeviceOrigin(location=)
+        # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
         # Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=), key='sin')
-        # Subscript: TensorType([block_size0, block_size1], torch.int32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x')
         # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=)
         out[tile] = x[tile].sin()
@@ -859,9 +859,9 @@ def fn(x):
 def device_ir():
     # File: .../test_type_propagation.py:824 in fn, code: out[tile] = x[tile].sin()
     x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
-    block_size0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size0')
-    block_size1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size1')
-    load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size0, block_size1]); x = None
+    block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
+    block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
+    load: "i32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, block_size_1]); x = None

     # File: .../test_type_propagation.py:824 in fn, code: out[tile] = x[tile].sin()
     sin: "f32[u0, u1]" = torch.ops.aten.sin.default(load); load = None
@@ -869,7 +869,7 @@ def device_ir():
     # File: .../test_type_propagation.py:824 in fn, code: out[tile] = x[tile].sin()
     convert_element_type: "i32[u0, u1]" = torch.ops.prims.convert_element_type.default(sin, torch.int32); sin = None
     out: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('out')
-    store = helion_language_memory_ops_store(out, [block_size0, block_size1], convert_element_type); out = block_size0 = block_size1 = convert_element_type = store = None
+    store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], convert_element_type); out = block_size_0 = block_size_1 = convert_element_type = store = None
     return None""",
         )

@@ -920,7 +920,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
     # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=), key=0)
     # Name: SymIntType(s94) GetItemOrigin(value=SourceOrigin(location=), key=1)
     for tile_m, tile_n in hl.tile([m, n]):
-        # Call: TensorType([block_size0, block_size1], torch.float32) DeviceOrigin(location=)
+        # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
         # Attribute: CallableType(zeros) AttributeOrigin(value=GlobalOrigin(name='hl'), key='zeros')
         # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
         # List: SequenceType([TileIndexType(0), TileIndexType(1)]) DeviceOrigin(location=)
@@ -935,24 +935,24 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
         # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
         # Name: SymIntType(s27) GetItemOrigin(value=SourceOrigin(location=), key=1)
         for tile_k in hl.tile(k):
-            # Call: TensorType([block_size0, block_size1], torch.float32) DeviceOrigin(location=)
+            # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
             # Attribute: CallableType(_VariableFunctionsClass.addmm) AttributeOrigin(value=GlobalOrigin(name='torch'), key='addmm')
             # Name: PythonModuleType(torch) GlobalOrigin(name='torch')
-            # Name: TensorType([block_size0, block_size1], torch.float32) DeviceOrigin(location=)
-            # Subscript: TensorType([block_size0, block_size2], torch.float32) DeviceOrigin(location=)
+            # Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
+            # Subscript: TensorType([block_size_0, block_size_2], torch.float32) DeviceOrigin(location=)
             # Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
             # Name: TileIndexType(0) SourceOrigin(location=)
             # Name: TileIndexType(2) DeviceOrigin(location=)
-            # Subscript: TensorType([block_size2, block_size1], torch.float32) DeviceOrigin(location=)
+            # Subscript: TensorType([block_size_2, block_size_1], torch.float32) DeviceOrigin(location=)
             # Name: TensorType([y_size0, y_size1], torch.float32) ArgumentOrigin(name='y')
             # Name: TileIndexType(2) DeviceOrigin(location=)
             # Name: TileIndexType(1) SourceOrigin(location=)
             acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
-        # Subscript: TensorType([block_size0, block_size1], torch.float32) DeviceOrigin(location=)
+        # Subscript: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
         # Name: TensorType([x_size0, y_size1], torch.float32) SourceOrigin(location=)
         # Name: TileIndexType(0) SourceOrigin(location=)
         # Name: TileIndexType(1) SourceOrigin(location=)
-        # Name: TensorType([block_size0, block_size1], torch.float32) DeviceOrigin(location=)
+        # Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
         out[tile_m, tile_n] = acc
     return out

@@ -960,13 +960,13 @@ def subgraph_0(arg0_1: "f32[u1, u2]"):
     # File: .../matmul.py:20 in matmul, code: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
     x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
     sym_size_int: "Sym(u1)" = torch.ops.aten.sym_size.int(arg0_1, 0)
-    block_size2: "Sym(u3)" = helion_language__tracing_ops__get_symnode('block_size2')
-    load: "f32[u1, u3]" = helion_language_memory_ops_load(x, [sym_size_int, block_size2]); x = sym_size_int = None
+    block_size_2: "Sym(u3)" = helion_language__tracing_ops__get_symnode('block_size_2')
+    load: "f32[u1, u3]" = helion_language_memory_ops_load(x, [sym_size_int, block_size_2]); x = sym_size_int = None

     # File: .../matmul.py:20 in matmul, code: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
     y: "f32[s17, s94]" = helion_language__tracing_ops__host_tensor('y')
     sym_size_int_1: "Sym(u2)" = torch.ops.aten.sym_size.int(arg0_1, 1)
-    load_1: "f32[u3, u2]" = helion_language_memory_ops_load(y, [block_size2, sym_size_int_1]); y = block_size2 = sym_size_int_1 = None
+    load_1: "f32[u3, u2]" = helion_language_memory_ops_load(y, [block_size_2, sym_size_int_1]); y = block_size_2 = sym_size_int_1 = None

     # File: .../matmul.py:20 in matmul, code: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
     acc: "f32[u1, u2]" = torch.ops.aten.addmm.default(arg0_1, load, load_1); arg0_1 = load = load_1 = None
@@ -974,9 +974,9 @@ def subgraph_0(arg0_1: "f32[u1, u2]"):

 def device_ir():
     # File: .../matmul.py:18 in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
-    block_size0: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size0')
-    block_size1: "Sym(u2)" = helion_language__tracing_ops__get_symnode('block_size1')
-    acc: "f32[u1, u2]" = helion_language_creation_ops_full([block_size0, block_size1], 0.0, torch.float32)
+    block_size_0: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_0')
+    block_size_1: "Sym(u2)" = helion_language__tracing_ops__get_symnode('block_size_1')
+    acc: "f32[u1, u2]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32)

     # File: .../matmul.py:19 in matmul, code: for tile_k in hl.tile(k):
     _for_loop = helion_language__tracing_ops__for_loop(0, [acc])
@@ -985,7 +985,7 @@ def device_ir():

     # File: .../matmul.py:21 in matmul, code: out[tile_m, tile_n] = acc
     out: "f32[s77, s94]" = helion_language__tracing_ops__host_tensor('out')
-    store = helion_language_memory_ops_store(out, [block_size0, block_size1], _phi); out = block_size0 = block_size1 = _phi = store = None
+    store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], _phi); out = block_size_0 = block_size_1 = _phi = store = None
     return None""",
         )