Support persistent reductions by jansel · Pull Request #10 · pytorch-labs/helion · GitHub

Support persistent reductions #10

Merged
merged 1 commit on Apr 26, 2025
2 changes: 1 addition & 1 deletion examples/template_via_closure.py
@@ -52,7 +52,7 @@ def check(n: int, k: int, m: int) -> None:
x = torch.randn([n, k], device="cuda", dtype=torch.float16)
y = torch.randn([k, m], device="cuda", dtype=torch.float16)
bias = torch.randn([1, m], device="cuda", dtype=torch.float16)
# The epilogue can use the captured bias tensor that is implicitly lifted to an arg
# The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
result = matmul_with_epilogue(x, y, lambda acc, tile: torch.relu(acc + bias[tile]))
torch.testing.assert_close(
result,
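The comment fix above concerns Helion's closure handling: a tensor captured by the epilogue lambda (here `bias`) is implicitly turned into an explicit argument of the generated kernel. As a rough, framework-free illustration of what "lifted to a kernel arg" means (plain PyTorch; `epilogue_lifted` is a made-up name, not part of Helion's API):

```python
import torch

bias = torch.randn(1, 8, dtype=torch.float16)

# Closure form, as in the example: `bias` is captured from the enclosing
# scope rather than passed in explicitly.
epilogue = lambda acc, tile: torch.relu(acc + bias[tile])

# "Lifted" form: the captured tensor becomes an explicit parameter, which is
# roughly what the compiler arranges when it adds the capture to the kernel's
# argument list.
def epilogue_lifted(acc, tile, bias_arg):
    return torch.relu(acc + bias_arg[tile])
```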
56 changes: 39 additions & 17 deletions helion/_compiler/compile_environment.py
# SLoc("create_block_var", current_location().format()),
Original file line number Diff line number Diff line change
Expand Up @@ -76,25 +76,20 @@ def finalize_config_spec(self) -> None:
self.config_spec.block_size_specs, shape
)

def allocate_block_size(self, numel: int | torch.SymInt) -> int:
def allocate_block_size(
self, size: int | torch.SymInt, *, reduction: bool = False
) -> int:
idx = len(self.block_sizes)
if isinstance(numel, torch.SymInt):
numel_expr = numel._sympy_()
else:
numel_expr = sympy.sympify(numel)
with self.shape_env.ignore_fresh_unbacked_symbols():
sym = self.shape_env.create_unbacked_symint()
assert isinstance(sym._sympy_(), sympy.Symbol)
self.block_sizes.append(
info := BlockSizeInfo(
block_size_idx=idx,
numel=numel_expr,
var=sym,
size=size,
var=self.create_block_var(
f"block_size_{idx}" if not reduction else f"rdim_{idx}"
),
reduction=reduction,
)
)
self.debug_shape_renames[sym._sympy_()] = sympy.Symbol(
info.name(), integer=True
)

from .host_function import HostFunction
from .host_function import SymbolOrigin
@@ -104,6 +99,31 @@ def allocate_block_size(self, numel: int | torch.SymInt) -> int:
)
return idx

def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
for rdim in self.block_sizes:
if rdim.reduction and rdim.size == size:
return rdim
rdim_idx = self.allocate_block_size(size, reduction=True)
return self.block_sizes[rdim_idx]

def create_block_var(self, debug_name: str) -> torch.SymInt:
with self.shape_env.ignore_fresh_unbacked_symbols():
sym = self.shape_env.create_unbacked_symint()
# self.shape_env.guards.append(
#     ShapeGuard(
#         sympy.Ne(sym._sympy_(), 0),
#         SLoc("create_block_var", current_location().format()),
#         True,
#     )
# )
# TODO(jansel): I was hoping the above would work, seems like some decomps require concrete values
# to determine zeroness. Figure out a better way to do this.
# pyre-ignore[29]
self.shape_env.var_to_val[sym._sympy_()] = sympy.Integer(64)
assert isinstance(sym._sympy_(), sympy.Symbol)
self.debug_shape_renames[sym._sympy_()] = sympy.Symbol(debug_name, integer=True)
return sym

def to_fake(self, obj: object, origin: Origin) -> object:
if isinstance(obj, torch.Tensor):
return self._to_fake_tensor(obj, origin.to_source())
@@ -233,15 +253,17 @@ class BlockSizeInfo(typing.NamedTuple):
"""

block_size_idx: int
numel: sympy.Expr
size: torch.SymInt | int
var: torch.SymInt
reduction: bool

@property
def numel(self) -> sympy.Expr:
return _to_sympy(self.size)

def symbol(self) -> sympy.Symbol:
return self.var._sympy_()

def name(self) -> str:
return f"block_size{self.block_size_idx}"


def warning(warning: exc.BaseWarning | type[exc.BaseWarning]) -> None:
CompileEnvironment.current().errors.add(warning)
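The two main additions in this file are `create_block_var`, which now backs both loop block sizes (`block_size_{idx}`) and reduction dimensions (`rdim_{idx}`), and `allocate_reduction_dimension`, which reuses an existing reduction dim of the same size instead of allocating a fresh one. A minimal, self-contained sketch of that bookkeeping, using plain ints in place of `torch.SymInt`/SymPy symbols (an illustration of the dedup logic, not Helion's real `CompileEnvironment`):

```python
from typing import NamedTuple


class BlockSizeInfo(NamedTuple):
    block_size_idx: int
    size: int              # stand-in for torch.SymInt | int
    reduction: bool


class Env:
    def __init__(self) -> None:
        self.block_sizes: list[BlockSizeInfo] = []

    def allocate_block_size(self, size: int, *, reduction: bool = False) -> int:
        idx = len(self.block_sizes)
        self.block_sizes.append(BlockSizeInfo(idx, size, reduction))
        return idx

    def allocate_reduction_dimension(self, size: int) -> BlockSizeInfo:
        # Reuse an existing reduction dim of the same size so that two
        # reductions over the same extent share one rdim/block index.
        for rdim in self.block_sizes:
            if rdim.reduction and rdim.size == size:
                return rdim
        return self.block_sizes[self.allocate_block_size(size, reduction=True)]


env = Env()
a = env.allocate_reduction_dimension(1024)
b = env.allocate_reduction_dimension(1024)  # deduplicated: same rdim as `a`
c = env.allocate_reduction_dimension(2048)  # different size: new rdim
assert a is b and c.block_size_idx == 1
```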
2 changes: 1 addition & 1 deletion helion/_compiler/device_function.py
@@ -148,7 +148,7 @@ def __init__(self, name: str, config: Config) -> None:
self.dce_vars: list[str] = []

from .indexing_strategy import IndexingStrategy
from .tile_strategy import TileStrategyDispatch
from .tile_dispatch import TileStrategyDispatch

self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config)
self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)
1 change: 0 additions & 1 deletion helion/_compiler/host_function.py
@@ -97,7 +97,6 @@ def __init__(self, fn: types.FunctionType, fake_args: list[object]) -> None:
env.errors.raise_if_errors()
env.finalize_config_spec()
self.device_ir = lower_to_device_ir(self)

# TODO(jansel): assert we don't have any extra decorators
# TODO(jansel): check type annotations for hl.constexpr/hl.specialize

22 changes: 22 additions & 0 deletions helion/_compiler/indexing_strategy.py
@@ -13,6 +13,7 @@
from .ast_extension import expr_from_string
from .compile_environment import CompileEnvironment
from .host_function import HostFunction
from .tile_strategy import TileStrategy
from .variable_origin import BlockSizeOrigin

if TYPE_CHECKING:
@@ -172,6 +173,15 @@ def compute_shape(
output_size.append(k)
else:
output_size.append(1)
elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
size = input_size.popleft()
if size != 1:
rdim = CompileEnvironment.current().allocate_reduction_dimension(
size
)
output_size.append(rdim.var)
else:
output_size.append(1)
else:
raise exc.InvalidIndexingType(k)
assert len(input_size) == 0, "invalid subscript"
@@ -211,6 +221,18 @@ def create(
else:
val = state.device_function.literal_expr(k)
index_values.append(f"tl.full([1], {val}, {dtype}){expand}")
elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
expand = tile_strategy.expand_str(output_size, output_idx)
if fake_value.size(len(index_values)) != 1:
block_idx = TileStrategy.get_block_index(output_size[output_idx])
assert block_idx is not None
index_var = tile_strategy.index_var(block_idx)
index_values.append(f"({index_var}){expand}")
if mask := tile_strategy.mask_var(block_idx):
mask_values.setdefault(f"({mask}){expand}")
else:
index_values.append(f"tl.zeros([1], {dtype}){expand}")
output_idx += 1
else:
raise exc.InvalidIndexingType(k)
assert len(output_size) == output_idx
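Both new branches key off a full slice, which is how a subscript like `x[tile_m, :]` keeps the reduced dimension alive: `compute_shape` maps it to a reduction dim (reusing one of the same size), and `create` emits the corresponding index and mask variables. A toy, self-contained version of just the shape rule, with no masks or codegen (the function below is a stand-in for illustration, not the method above):

```python
import collections


def compute_shape(input_size: list[int], index: tuple) -> list:
    # Toy shape rule: an integer index consumes a dimension; a full slice
    # keeps it, allocating (or reusing) a reduction dim when its size != 1.
    sizes = collections.deque(input_size)
    reduction_dims: dict[int, int] = {}  # size -> rdim index (deduplicated)
    output: list = []
    for k in index:
        if isinstance(k, int):
            sizes.popleft()  # dimension is indexed away
        elif isinstance(k, slice) and k == slice(None):
            size = sizes.popleft()
            if size != 1:
                rdim = reduction_dims.setdefault(size, len(reduction_dims))
                output.append(("rdim", rdim, size))
            else:
                output.append(1)
        else:
            raise TypeError(f"unsupported index: {k!r}")
    assert not sizes, "invalid subscript"
    return output


# A [row, :] style subscript keeps the 1024-wide dim as reduction dim 0.
print(compute_shape([128, 1024], (0, slice(None))))  # [('rdim', 0, 1024)]
```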
109 changes: 93 additions & 16 deletions helion/_compiler/inductor_lowering.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import contextlib
import dataclasses
import functools
from operator import getitem
@@ -39,12 +40,14 @@
from .ast_extension import statement_from_string
from .compile_environment import CompileEnvironment
from .tile_strategy import TileStrategy
from .tile_strategy import TileStrategyDispatch

if TYPE_CHECKING:
from collections.abc import Iterator

from .. import Config
from .device_function import DeviceFunction
from .generate_ast import GenerateAST
from .tile_dispatch import TileStrategyDispatch

CodegenHandler = Callable[["GraphInterpreter", torch.fx.Node], object]

@@ -63,12 +66,7 @@ def prepare_graph_lowerings(gm: torch.fx.GraphModule) -> None:
"output",
}, node.op
if node.op == "call_function":
prior_buffers = len(graph_lowering.buffers)
node.meta["lowering"] = prepare_node_lowering(graph_lowering, node)
if len(graph_lowering.buffers) > prior_buffers + 1:
raise InductorLoweringError(
f"Lowering {node.op} resulted in {len(graph_lowering.buffers) - prior_buffers} buffers, expected 1."
)


def prepare_node_lowering(
Expand Down Expand Up @@ -117,6 +115,7 @@ def convert_arg(arg: Node) -> TensorBox:
)
)

prior_buffers = len(graph_lowering.buffers)
input_names: list[str] = []
result = graph_lowering.call_function(
# pyre-ignore[6]
@@ -133,6 +132,16 @@ def convert_arg(arg: Node) -> TensorBox:
raise InductorLoweringError(
f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer: {buffer}"
)

new_buffers = graph_lowering.buffers[prior_buffers:]
if len(new_buffers) != 1:
# TODO(jansel): handle multiple buffers more generally
raise InductorLoweringError(
f"Lowering {node.op} resulted in {len(new_buffers)} buffers, expected 1."
)

assert new_buffers[0] is buffer
assert isinstance(buffer, ComputedBuffer)
if isinstance(buffer.data, Pointwise):
return PointwiseLowering(buffer, input_names)
if isinstance(buffer.data, Reduction):
@@ -169,21 +178,33 @@ def input_asts(self, ctx: GraphInterpreter, node: torch.fx.Node) -> list[ast.AST
assert len(input_asts) == len(self.input_names)
return input_asts

@staticmethod
def input_fake_tensors(node: torch.fx.Node) -> list[torch.Tensor]:
def visit(n: torch.fx.Node) -> torch.fx.Node:
if isinstance(val := n.meta["val"], torch.Tensor):
result.append(val)
return n

result: list[torch.Tensor] = []
map_arg((node.args, node.kwargs), visit)
return result

def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
raise NotImplementedError(
f"codegen not implemented for {type(self).__name__}: {self.buffer}"
)


@functools.cache
def dummy_gm() -> torch.fx.GraphModule:
return torch.fx.symbolic_trace(lambda: None)


class PointwiseLowering(InductorLowering):
def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
@contextlib.contextmanager
def install_kernel_handlers(
self, ctx: GraphInterpreter, node: torch.fx.Node
) -> Iterator[None]:
with (
inductor_config.patch("triton.codegen_upcast_to_fp32", False),
inductor_config.patch(
{
"triton.codegen_upcast_to_fp32": False,
"split_reductions": False,
}
),
# pyre-ignore[19]
V.set_graph_handler(
GraphLowering(
@@ -201,6 +222,17 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
TritonKernel({}, features=SIMDKernelFeatures([], sympy.S.One))
),
):
yield


@functools.cache
def dummy_gm() -> torch.fx.GraphModule:
return torch.fx.symbolic_trace(lambda: None)


class PointwiseLowering(InductorLowering):
def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
with self.install_kernel_handlers(ctx, node):
indices = [
sympy.Symbol(f"i{n}") for n in range(len(self.buffer.data.ranges))
]
@@ -209,7 +241,52 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:


class ReductionLowering(InductorLowering):
pass
def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
reduction = self.buffer.data
assert isinstance(reduction, Reduction)
indices = [sympy.Symbol(f"i{n}") for n in range(len(reduction.ranges))]
reduction_indices = [
sympy.Symbol(f"i{n}")
for n in range(len(indices), len(indices) + len(reduction.reduction_ranges))
]
with self.install_kernel_handlers(ctx, node):
# codegen the pointwise part before reduction
output_name = _unpack_opsvalue(
self.buffer.data.inner_fn(indices, reduction_indices)
)

reduction_ranges = reduction.reduction_ranges
if len(reduction_ranges) != 1:
# TODO(jansel): can this happen?
raise NotImplementedError("multiple reduction dimensions")
reduction_var = reduction_ranges[0]
assert isinstance(reduction_var, sympy.Symbol)

block_idx = TileStrategy.get_block_index(reduction_var)
assert block_idx is not None
strategy = ctx.cg.device_function.tile_strategy.get_reduction_strategy(
block_idx
)

inputs = self.input_fake_tensors(node)
if len(inputs) != 1:
# TODO(jansel): combine multiple inputs into a single fake value
raise NotImplementedError("reductions with >1 input")

# TODO(jansel): find a better way to get dim
(dim,) = [
i
for i, v in enumerate(inputs[0].shape)
if TileStrategy.get_block_index(v) == block_idx
]

return strategy.codegen_reduction(
output_name,
reduction.reduction_type,
dim,
inputs[0],
node.meta["val"],
)


@dataclasses.dataclass
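For context on what `codegen_reduction` ultimately has to produce: a persistent reduction assigns the whole reduction extent to one block, so the generated Triton performs a single masked load followed by `tl.sum`/`tl.max`/etc., with no loop over the reduction dimension. A hand-written sketch of a kernel with that shape (a row sum; this illustrates the strategy named in the PR title, not the code Helion actually emits, and it requires Triton plus a CUDA device to run):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def row_sum_persistent(x_ptr, out_ptr, n_cols, x_stride, BLOCK_N: tl.constexpr):
    # One program per row; the entire reduction dimension fits in a single
    # block (BLOCK_N >= n_cols), so there is no loop over the reduction dim.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < n_cols
    vals = tl.load(x_ptr + row * x_stride + cols, mask=mask, other=0.0)
    tl.store(out_ptr + row, tl.sum(vals, axis=0))


x = torch.randn(64, 300, device="cuda")
out = torch.empty(64, device="cuda")
row_sum_persistent[(x.shape[0],)](
    x, out, x.shape[1], x.stride(0), BLOCK_N=triton.next_power_of_2(x.shape[1])
)
torch.testing.assert_close(out, x.sum(dim=1), rtol=1e-4, atol=1e-4)
```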