Support data-dependent loop bounds by jansel · Pull Request #81 · pytorch-labs/helion
Merged · 1 commit · May 27, 2025
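This PR threads optional `begin`/`end` loop bounds from `hl.tile`/`hl.grid` calls through the device IR and into Triton codegen, so a loop extent no longer has to be a compile-time tensor shape. As a rough orientation before the diff, here is a hypothetical sketch of the kind of kernel this enables (assumed usage, not code from this PR; `masked_copy` and its signature are invented for illustration):

```python
import torch
import helion
import helion.language as hl


@helion.kernel()
def masked_copy(x: torch.Tensor, out: torch.Tensor, n: int) -> torch.Tensor:
    # `n` is a data-dependent bound: it is not derived from a tensor
    # shape, so the compiler cannot prove divisibility by the block
    # size and must mask the tail of the loop against the end bound.
    for tile in hl.tile(n):
        out[tile] = x[tile]
    return out
```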
16 changes: 14 additions & 2 deletions helion/_compiler/compile_environment.py
@@ -92,7 +92,7 @@ def finalize_config_spec(self) -> None:

     def allocate_block_size(
         self,
-        size: int | torch.SymInt,
+        size: int | torch.SymInt | None,
         *,
         reduction: bool = False,
         source: BlockSizeSource,
@@ -302,15 +302,27 @@ class BlockSizeInfo(typing.NamedTuple):
     """

     block_size_idx: int
-    size: torch.SymInt | int
+    size: torch.SymInt | int | None
     var: torch.SymInt
     reduction: bool
     block_size_source: BlockSizeSource

     @property
     def numel(self) -> sympy.Expr:
+        assert self.size is not None
         return _to_sympy(self.size)

+    def known_multiple(self, block_size: int | torch.SymInt) -> bool:
+        if block_size == 1:
+            return True
+        if self.size is None:
+            return False
+        return CompileEnvironment.current().known_multiple(self.numel, block_size)
+
+    def size_hint(self) -> int:
+        assert self.size is not None
+        return CompileEnvironment.current().size_hint(self.size)
+
     def symbol(self) -> sympy.Symbol:
         return self.var._sympy_()

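`BlockSizeInfo.size` is now optional: a `None` size marks a data-dependent dimension. The key consequence is in `known_multiple`, which can never prove divisibility for an unknown extent, so the tail mask stays on. A standalone illustration of that logic (plain Python, not Helion code):

```python
def known_multiple(size: int | None, block_size: int) -> bool:
    """Mirror of the None-size semantics above (simplified)."""
    if block_size == 1:
        return True   # every extent is a multiple of 1: no mask needed
    if size is None:
        return False  # data-dependent extent: divisibility is unprovable
    return size % block_size == 0


assert known_multiple(None, 64) is False  # data-dependent -> keep the mask
assert known_multiple(None, 1) is True    # block size 1 never needs a mask
assert known_multiple(128, 64) is True    # static multiple -> mask elided
```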
42 changes: 33 additions & 9 deletions helion/_compiler/device_ir.py
@@ -40,6 +40,7 @@
 from .source_location import current_location
 from .tile_index_proxy import CheckForIndexCalls
 from .tile_index_proxy import TileIndexProxy
+from .type_propagation import CallableType
 from .type_propagation import GridIndexType
 from .type_propagation import IterType
 from .type_propagation import SequenceType
@@ -163,7 +164,7 @@ def kwargs(self) -> dict[str, object]:
         }

     def codegen(self, state: CodegenState) -> list[object]:
-        args = state.ast_args[1]
+        args = state.ast_args[-1]
         assert isinstance(args, list)
         assert all(isinstance(x, ast.AST) for x in args)
         with state.codegen.add_device_loop(
@@ -294,7 +295,7 @@ def build_rolled_reductions(self) -> None:
                 graph_to_info[graph_id] = reduction_info
                 env.config_spec.reduction_loop_specs.append(
                     ReductionLoopSpec(
-                        size_hint=env.size_hint(rdim.size),
+                        size_hint=rdim.size_hint(),
                         # TODO(jansel): we should add support for rolling multiple dims at once
                         allow_loop=allow_loop and first,
                     )
@@ -412,6 +413,24 @@ def should_become_arg(value: object) -> bool:
             return origin.is_device()
         return True

+    def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
+        call_node = for_node.iter
+        assert isinstance(call_node, ast.Call)
+        func_node = call_node.func
+        assert isinstance(func_node, ExtendedAST)
+        func_type = func_node._type_info
+        assert isinstance(func_type, CallableType)
+        assert func_type.value in (hl.tile, hl.grid)
+        args = call_node.args
+        assert len(args) >= 1
+        if len(args) == 1:
+            begin = None
+            end = self.visit(args[0])
+        else:
+            begin = self.visit(args[0])
+            end = self.visit(args[1])
+        return begin, end
+
     def visit_For(self, node: ast.For) -> None:
         assert isinstance(node, ExtendedAST)
         assert not node.orelse
@@ -432,6 +451,16 @@ def visit_For(self, node: ast.For) -> None:
             }
         )
         outputs: LiftTensorArgs | None = None
+        begin, end = self._extract_tile_begin_end(node)
+        if isinstance(inner_type, SequenceType):
+            iter_vars = inner_type.unpack()
+            if begin is None:
+                begin = [0] * len(iter_vars)
+        else:
+            iter_vars = [inner_type]
+            begin = [0] if begin is None else [begin]
+            end = [end]
+        assert all(isinstance(x, (TileIndexType, GridIndexType)) for x in iter_vars)

         def run_subgraph(*args: object) -> list[object]:
             nonlocal outputs
@@ -461,20 +490,15 @@ def run_subgraph(*args: object) -> list[object]:
         graph = proxy_tensor.make_fx(
             run_subgraph, decomposition_table=select_decomp_table()
         )(*inputs.get_tensor_args())
-        if isinstance(inner_type, SequenceType):
-            iter_vars = inner_type.unpack()
-        else:
-            iter_vars = [inner_type]
-        assert all(
-            isinstance(x, (TileIndexType, GridIndexType)) for x in iter_vars
-        )
         graph_idx = self.device_ir.add_graph(
             graph,
             ForLoopGraphInfo,
             block_indices=[x.block_size_idx for x in iter_vars],
         )
         args = (
             graph_idx,
+            begin,
+            end,
             inputs.get_tensor_args(),
         )
         proxy_out = tracer.create_proxy(
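`_extract_tile_begin_end` pattern-matches the `for` statement's iterator: a call to `hl.tile`/`hl.grid` with one argument is treated as `end` only, and with two arguments as `(begin, end)`. A stdlib-only sketch of the same arity handling, omitting Helion's type checks:

```python
import ast

tree = ast.parse("for tile in hl.tile(lo, hi):\n    pass")
for_node = tree.body[0]
assert isinstance(for_node, ast.For)
call = for_node.iter
assert isinstance(call, ast.Call)

# One argument means (begin=None, end=args[0]); two means (begin, end).
args = call.args
begin, end = (None, args[0]) if len(args) == 1 else (args[0], args[1])
assert isinstance(end, ast.Name) and end.id == "hi"
```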
15 changes: 9 additions & 6 deletions helion/_compiler/indexing_strategy.py
@@ -85,7 +85,7 @@ class BlockPtrIndexingStrategy(IndexingStrategy):
     def codegen_load(
         self, state: CodegenState, fake_tensor: torch.Tensor, subscript: list[object]
     ) -> ast.AST:
-        if not BlockedSubscriptIndexing.is_supported(state, subscript):
+        if not BlockedSubscriptIndexing.is_supported(state, fake_tensor, subscript):
             return PointerIndexingStrategy().codegen_load(state, fake_tensor, subscript)
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
         return indexing.reshape_load(
@@ -99,7 +99,7 @@ def codegen_load(
     def codegen_store(
         self, state: CodegenState, fake_tensor: torch.Tensor, subscript: list[object]
     ) -> ast.AST:
-        if not BlockedSubscriptIndexing.is_supported(state, subscript):
+        if not BlockedSubscriptIndexing.is_supported(state, fake_tensor, subscript):
             return PointerIndexingStrategy().codegen_store(
                 state, fake_tensor, subscript
             )
@@ -117,7 +117,7 @@ class TensorDescriptorIndexingStrategy(IndexingStrategy):
     def codegen_load(
         self, state: CodegenState, fake_tensor: torch.Tensor, subscript: list[object]
     ) -> ast.AST:
-        if not BlockedSubscriptIndexing.is_supported(state, subscript):
+        if not BlockedSubscriptIndexing.is_supported(state, fake_tensor, subscript):
             return PointerIndexingStrategy().codegen_load(state, fake_tensor, subscript)
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
         return indexing.reshape_load(
@@ -130,7 +130,7 @@ def codegen_load(
     def codegen_store(
         self, state: CodegenState, fake_tensor: torch.Tensor, subscript: list[object]
     ) -> ast.AST:
-        if not BlockedSubscriptIndexing.is_supported(state, subscript):
+        if not BlockedSubscriptIndexing.is_supported(state, fake_tensor, subscript):
             return PointerIndexingStrategy().codegen_store(
                 state, fake_tensor, subscript
             )
@@ -375,7 +375,9 @@ def reshape_store(self, state: CodegenState, node: ast.AST) -> ast.AST:
         return expr_from_string(f"tl.reshape(node, {shape})", node=node)

     @staticmethod
-    def is_supported(state: CodegenState, index: list[object]) -> bool:
+    def is_supported(
+        state: CodegenState, fake_tensor: torch.Tensor, index: list[object]
+    ) -> bool:
         for k in index:
             if isinstance(k, torch.SymInt):
                 symbol = k._sympy_()
@@ -390,7 +392,8 @@ def is_supported(state: CodegenState, index: list[object]) -> bool:
             if isinstance(k, torch.Tensor):
                 # indirect loads don't work with block_ptr
                 return False
-        return True
+        output_shape = SubscriptIndexing.compute_shape(fake_tensor, index)
+        return len(output_shape) != 0

     def validate(self) -> None:
         n = self.ndim
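`is_supported` now also receives the fake tensor so it can reject subscripts whose result is rank-0: a block pointer describes an N-dimensional block, so scalar loads and stores must fall back to plain pointer indexing. A simplified stand-in for that rank check (not Helion's actual `compute_shape`):

```python
import torch


def output_rank(tensor: torch.Tensor, index: list[object]) -> int:
    # Scalar indices consume a dimension; slices keep one (simplified;
    # the tensor itself is unused in this toy version).
    return sum(1 for k in index if isinstance(k, slice))


x = torch.empty(64, 64)
assert output_rank(x, [3, 7]) == 0            # rank-0 -> pointer fallback
assert output_rank(x, [slice(None), 7]) == 1  # rank-1 -> block_ptr is fine
```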
28 changes: 22 additions & 6 deletions helion/_compiler/roll_reduction.py
@@ -54,6 +54,7 @@ def __init__(
         self.seen: set[torch.fx.Node] = set()
         self.available: set[torch.fx.Node] = set()
         self.graphs_added: list[int] = []
+        self._size_node: torch.fx.Node | None = None

     def is_reduction(self, node: torch.fx.Node) -> bool:
         """Check if a node is a reduction"""
@@ -71,7 +72,7 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:

         if node.target in (_for_loop, _if):
             if node.target is _for_loop:
-                graph_id, _ = node.args
+                graph_id, *_ = node.args
             else:
                 _, graph_id, _ = node.args
             assert isinstance(graph_id, int)
@@ -111,6 +112,21 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:

         return num_rdims > 0

+    def size_node(self, meta: dict[str, object]) -> torch.fx.Node:
+        """Create a node that represents the size of the reduction dimension"""
+        if self._size_node is not None:
+            return self._size_node
+        self._size_node = node = self.outer_graph.call_function(
+            _get_symnode,
+            (f"rdim{self.rdim.block_size_idx}",),
+            {},
+        )
+        node.meta.update(meta)
+        node.meta["val"] = self.rdim.size
+        # pyre-ignore[6]
+        node.meta["lowering"] = APIFuncLowering(_get_symnode)
+        return node
+
     def start_new_graph(self) -> None:
         if self.inner_count == 0:
             return
@@ -130,15 +146,15 @@ def start_new_graph(self) -> None:
         )
         self.graphs_added.append(graph_id)

-        output_node = self.outer_graph.call_function(
-            _for_loop,
-            (graph_id, self.inner_args),
-            {},
-        )
         location_meta = {
             "location": next(iter(inner_nodes)).meta["location"],
             "stack_trace": next(iter(inner_nodes)).meta["stack_trace"],
         }
+        output_node = self.outer_graph.call_function(
+            _for_loop,
+            (graph_id, [0], [self.size_node(location_meta)], self.inner_args),
+            {},
+        )
         output_node.meta.update(location_meta)
         output_node.meta["val"] = [n.meta["val"] for n in outputs]
         assert is_api_func(_for_loop)
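Because `_for_loop` nodes now carry explicit bounds, their args grow from `(graph_id, inner_args)` to `(graph_id, begins, ends, inner_args)`; the reduction roller synthesizes the `end` from a cached `_get_symnode` node for the rdim extent, and readers that only need the graph id switch to star-unpacking. A toy illustration of that arity change (placeholder values, not real FX nodes):

```python
old_style = (3, ["inner_args"])
new_style = (3, [0], ["rdim_size_node"], ["inner_args"])

for node_args in (old_style, new_style):
    graph_id, *rest = node_args  # `graph_id, *_` tolerates both arities
    assert graph_id == 3
```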
64 changes: 43 additions & 21 deletions helion/_compiler/tile_strategy.py
@@ -400,12 +400,11 @@ def codegen_grid(self, state: CodegenState) -> None:
                 state.add_statement(
                     f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
                 )
-            if hasattr(self, "_setup_mask"):
-                mask_statement = self._setup_mask(  # pyre-ignore[16]
-                    state, block_idx, block_size, index_var
-                )
-                if mask_statement is not None:
-                    state.add_statement(mask_statement)
+            mask_statement = self._setup_mask(  # pyre-ignore[16]
+                state, block_idx, block_size, index_var, numel
+            )
+            if mask_statement is not None:
+                state.add_statement(mask_statement)
             pids.append(ProgramID(pid_var, block_size_var, numel))
         pids.codegen(state)

@@ -414,20 +413,32 @@ def select_pid_strategy(self) -> ProgramIDs:
             return GridProgramIDs()
         return VirtualProgramIDs()

+    def _to_ast(self, x: object) -> ast.AST:
+        if isinstance(x, ast.AST):
+            return x
+        if isinstance(x, int):
+            return expr_from_string(repr(x))
+        if isinstance(x, sympy.Expr):
+            from .device_function import DeviceFunction
+
+            return expr_from_string(DeviceFunction.current().sympy_expr(x))
+        raise NotImplementedError(f"{type(x)} is not implemented.")
+
     def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
         # TODO(jansel): refactor this to share code with codegen_grid
         block_indices = self.block_indices
         env = CompileEnvironment.current()
-        device_function = state.device_function
         dtype = env.triton_index_type()
         block_sizes = self.block_size
         body = innermost_body = []
         for_node: ast.For | None = None
         assert len(block_sizes) == len(block_indices)
-        for block_idx, block_size in self._reorder(
-            [*zip(block_indices, block_sizes, strict=True)]
+        _, begins, ends, _ = state.ast_args
+        assert isinstance(begins, list)
+        assert isinstance(ends, list)
+        for block_idx, block_size, begin, end in self._reorder(
+            [*zip(block_indices, block_sizes, begins, ends, strict=True)]
         ):
-            numel = env.block_sizes[block_idx].numel
             offset_var = self.offset_var(block_idx)
             index_var = self.index_var(block_idx)
             if block_size != 1:
@@ ... @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
                     ast.For,
                     target=create(ast.Name, id=offset_var, ctx=ast.Store()),
                     iter=expr_from_string(
-                        f"range(0, ({device_function.sympy_expr(numel)}), {block_size_var})"
+                        f"range(begin, end, {block_size_var})",
+                        begin=self._to_ast(begin),
+                        end=self._to_ast(end),
                     ),
                     body=body,
                     orelse=[],
@@ ... @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
                     f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})"
                 ),
             ]
-            if hasattr(self, "_setup_mask"):
-                mask_statement = self._setup_mask(  # pyre-ignore[16]
-                    state, block_idx, block_size, index_var
-                )
-                if mask_statement is not None:
-                    extra_body.append(mask_statement)
+            mask_statement = self._setup_mask(  # pyre-ignore[16]
+                state, block_idx, block_size, index_var, end
+            )
+            if mask_statement is not None:
+                extra_body.append(mask_statement)
             body[:] = [*extra_body, *body]
             body = [for_node]
         assert for_node is not None
@@ -501,17 +513,20 @@ def _setup_mask(
         block_idx: int,
         block_size: SymIntLike,
         index_var: str,
+        end: object,
     ) -> ast.stmt | None:
-        env = CompileEnvironment.current()
-        numel = env.block_sizes[block_idx].numel
-        if block_size == 1 or env.known_multiple(numel, block_size):
+        if (
+            CompileEnvironment.current()
+            .block_sizes[block_idx]
+            .known_multiple(block_size)
+        ):
             self.mask_vars[block_idx] = None
             return None
         self.mask_vars[block_idx] = mask_var = self.fn.new_var(
             f"mask_{block_idx}", dce=True
         )
         return statement_from_string(
-            f"{mask_var} = ({index_var} < ({state.device_function.sympy_expr(numel)}))"
+            f"{mask_var} = ({index_var}) < end", end=self._to_ast(end)
        )

     def select_pid_strategy(self) -> ProgramIDs:
@@ -537,6 +552,13 @@ def __init__(
     def mask_var(self, block_idx: int) -> str | None:
         return None

+    def _setup_mask(
+        self,
+        *args: object,
+        **kwargs: object,
+    ) -> None:
+        return None
+

 class CompactedShape(NamedTuple):
     size_str: str
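Put together, the device loop that codegen emits changes shape: the `range` runs from `begin` to `end` instead of from `0` to a static `numel`, and the tail mask compares against `end`. A hand-written approximation of the generated Triton (assumed output shape, not actual compiler output; the pointer argument and dtype are invented):

```python
import triton
import triton.language as tl


@triton.jit
def device_loop_sketch(ptr, begin, end, BLOCK: tl.constexpr):
    # range(begin, end, BLOCK): both bounds may be runtime values now.
    for offset in range(begin, end, BLOCK):
        index = offset + tl.arange(0, BLOCK)
        mask = index < end  # mask against `end`, not a static numel
        x = tl.load(ptr + index, mask=mask, other=0.0)
        tl.store(ptr + index, x + 1.0, mask=mask)
```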
8 changes: 6 additions & 2 deletions helion/_compiler/type_propagation.py
@@ -948,13 +948,17 @@ def proxy(self) -> object:

     @staticmethod
     def allocate(
-        numels: list[int | torch.SymInt], origin: Origin
+        numels: list[int | torch.SymInt | None], origin: Origin
     ) -> list[TileIndexType]:
         env = CompileEnvironment.current()
         spec_id = len(env.config_spec.block_size_specs)
         env.config_spec.block_size_specs.append(
             BlockSizeSpec(
-                size_hints=[env.size_hint(x) for x in numels],
+                size_hints=[
+                    # For data-dependent sizes, use max block size of 8192
+                    env.size_hint(x) if x is not None else 8192
+                    for x in numels
+                ],
                 allow_flattened=len(numels) > 1,
                 allow_reorder=len(numels) > 1,
                 # TODO(jansel): implement N-D l2 grouping
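Since a data-dependent dimension has no usable size hint, the autotuner's block-size search needs a stand-in; the PR picks 8192 as the assumed maximum block size. A one-line illustration of the fallback:

```python
def size_hint(numel: int | None, default: int = 8192) -> int:
    # None marks a data-dependent extent; fall back to the fixed hint.
    return numel if numel is not None else default


assert [size_hint(n) for n in (1024, None)] == [1024, 8192]
```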
2 changes: 1 addition & 1 deletion helion/exc.py
@@ -82,7 +82,7 @@ class NestedGridLoop(BaseError):


 class RankMismatch(BaseError):
-    message = "Expected rank {0} tensor, but got rank {1} tensor."
+    message = "Expected ndim={0}, but got ndim={1}"


 class InvalidIndexingType(BaseError):