Add `hl.grid(...)` support by yf225 · Pull Request #59 · pytorch-labs/helion · GitHub
Add hl.grid(...) support #59


Merged · 14 commits · May 21, 2025
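For orientation, here is a minimal usage sketch of the construct this PR adds. It illustrates the semantics implied by the diff below (one program per index, block size pinned to 1, no bounds mask); it is not code from the PR, and the kernel and shapes are invented.

```python
import torch

import helion
import helion.language as hl


@helion.kernel()
def row_copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # hl.grid(n) maps the loop onto the kernel grid with one program per
    # index: `i` is a scalar-like grid index, not a tile, so loads along
    # this axis add no dimension to the loaded block and need no mask.
    for i in hl.grid(x.size(0)):
        out[i, :] = x[i, :]
    return out
```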
15 changes: 15 additions & 0 deletions helion/_compiler/compile_environment.py
@@ -311,6 +311,9 @@ def from_config_assert(self, config: Config) -> int | torch.SymInt:
def is_flattened(self, config: Config) -> bool:
return self.block_size_source.is_flattened(config)

def is_grid(self) -> bool:
return self.block_size_source.is_grid()

def get_order(self, config: Config, count: int) -> list[int]:
return self.block_size_source.get_order(config, count)

@@ -330,6 +333,9 @@ def from_config(self, config: Config) -> int | torch.SymInt | None:
def is_flattened(self, config: Config) -> bool:
return False

def is_grid(self) -> bool:
return False

def get_order(self, config: Config, count: int) -> list[int]:
return [*range(count)]

@@ -348,6 +354,15 @@ def from_config(self, config: Config) -> int | torch.SymInt:
return self.value


@dataclasses.dataclass
class GridBlockSizeSource(BlockSizeSource):
def from_config(self, config: Config) -> int:
raise NotImplementedError

def is_grid(self) -> bool:
return True


@dataclasses.dataclass
class LoopSpecBlockSizeSource(BlockSizeSource):
loop_spec: int
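`GridBlockSizeSource` marks a dimension driven directly by the launch grid: its size is always 1 and is never read from an autotuning config, which is why `from_config` raises. A hedged sketch of how it is wired up (mirroring `GridIndexType.allocate` later in this diff; it only works inside an active `CompileEnvironment`):

```python
# Sketch, not standalone: assumes compilation is in progress so that
# CompileEnvironment.current() succeeds, and `numel` is the grid extent.
env = CompileEnvironment.current()
block_idx = env.allocate_block_size(numel, source=GridBlockSizeSource())
assert env.block_sizes[block_idx].is_grid()  # delegates to the source
```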
5 changes: 4 additions & 1 deletion helion/_compiler/device_ir.py
@@ -40,6 +40,7 @@
from .source_location import current_location
from .tile_index_proxy import CheckForIndexCalls
from .tile_index_proxy import TileIndexProxy
from .type_propagation import GridIndexType
from .type_propagation import IterType
from .type_propagation import SequenceType
from .type_propagation import TensorType
@@ -464,7 +465,9 @@ def run_subgraph(*args: object) -> list[object]:
iter_vars = inner_type.unpack()
else:
iter_vars = [inner_type]
assert all(isinstance(x, TileIndexType) for x in iter_vars)
assert all(
isinstance(x, (TileIndexType, GridIndexType)) for x in iter_vars
)
graph_idx = self.device_ir.add_graph(
graph,
ForLoopGraphInfo,
38 changes: 32 additions & 6 deletions helion/_compiler/indexing_strategy.py
@@ -169,7 +169,13 @@ def compute_shape(
if isinstance(symbol, sympy.Symbol):
origin = HostFunction.current().symbol_to_origin.get(symbol.name)
if origin and isinstance(origin.origin, BlockSizeOrigin):
if tensor.size(tensor.ndim - len(input_size) - 1) != 1:
if (
CompileEnvironment.current()
.block_sizes[origin.origin.block_size_idx]
.is_grid()
):
pass
elif tensor.size(tensor.ndim - len(input_size) - 1) != 1:
output_size.append(k)
else:
output_size.append(1)
@@ -200,6 +206,7 @@ def create(
mask_values = {}
output_size = SubscriptIndexing.compute_shape(fake_value, index)
dtype = CompileEnvironment.current().triton_index_type()
first_non_grid_index = 0
for n, k in enumerate(index):
if k is None:
output_idx += 1
@@ -210,8 +217,18 @@
origin = None
if isinstance(symbol, sympy.Symbol):
origin = HostFunction.current().symbol_to_origin.get(symbol.name)
expand = tile_strategy.expand_str(output_size, output_idx)
if origin and isinstance(origin.origin, BlockSizeOrigin):
if (
CompileEnvironment.current()
.block_sizes[origin.origin.block_size_idx]
.is_grid()
):
first_non_grid_index = n + 1
expand = tile_strategy.expand_str(output_size, output_idx)
else:
expand = tile_strategy.expand_str(
output_size, output_idx - first_non_grid_index
)
index_var = state.codegen.index_var(origin.origin.block_size_idx)
i = len(index_values)
index_values.append(f"({index_var}){expand}")
@@ -221,10 +238,15 @@
mask_values.setdefault(f"({mask}){expand}")
output_idx += 1
else:
expand = tile_strategy.expand_str(
output_size, output_idx - first_non_grid_index
)
val = state.device_function.literal_expr(k)
index_values.append(f"tl.full([1], {val}, {dtype}){expand}")
elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
expand = tile_strategy.expand_str(output_size, output_idx)
expand = tile_strategy.expand_str(
output_size, output_idx - first_non_grid_index
)
size = fake_value.size(len(index_values))
if size != 1:
env = CompileEnvironment.current()
@@ -238,21 +260,25 @@
index_values.append(f"tl.zeros([1], {dtype}){expand}")
output_idx += 1
elif isinstance(k, torch.Tensor) and k.ndim == 1:
expand = tile_strategy.expand_str(output_size, output_idx)
expand = tile_strategy.expand_str(
output_size, output_idx - first_non_grid_index
)
ast_index = state.ast_args[1]
assert isinstance(ast_index, (list, tuple))
assert len(ast_index) == len(index)
index_var = state.codegen.lift(ast_index[n]).id
index_values.append(f"({index_var}){expand}")
if (
block_idx := TileStrategy.get_block_index(output_size[output_idx])
block_idx := TileStrategy.get_block_index(
output_size[output_idx - first_non_grid_index]
)
) is not None:
if mask := state.codegen.mask_var(block_idx):
mask_values.setdefault(f"({mask}){expand}")
output_idx += 1
else:
raise exc.InvalidIndexingType(k)
assert len(output_size) == output_idx
assert len(output_size) == output_idx - first_non_grid_index
assert len(index_values) == fake_value.ndim

index_expr = []
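The thread through all of these hunks is `first_non_grid_index`: a grid index consumes a tensor dimension but, being scalar, contributes nothing to the shape of the loaded block, so every later broadcast slot shifts left by the number of leading grid indices. A self-contained sketch of that arithmetic (`expand_str` here is a stand-in with the behavior the code above relies on, not Helion's implementation):

```python
def expand_str(output_size: list[str], output_idx: int) -> str:
    # Stand-in: build a broadcast subscript such as "[None, :]" that
    # places ":" at slot `output_idx` of the loaded block.
    parts = [":" if i == output_idx else "None" for i in range(len(output_size))]
    return f"[{', '.join(parts)}]"


# For x[i, j] with `i` from hl.grid(...) and `j` a tile: the grid index
# is scalar, so output_size holds only j's block size. With
# first_non_grid_index == 1, the tile at positional index n == 1 lands
# in broadcast slot 1 - 1 == 0, which keeps the final invariant
# len(output_size) == output_idx - first_non_grid_index.
output_size = ["_BLOCK_SIZE_1"]
first_non_grid_index = 1
assert expand_str(output_size, 1 - first_non_grid_index) == "[:]"
```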
9 changes: 8 additions & 1 deletion helion/_compiler/tile_dispatch.py
@@ -15,6 +15,7 @@
from helion._compiler.tile_strategy import DeviceGridState
from helion._compiler.tile_strategy import DeviceLoopState
from helion._compiler.tile_strategy import FlattenedTileStrategy
from helion._compiler.tile_strategy import NDGridTileStrategy
from helion._compiler.tile_strategy import NDTileStrategy
from helion._compiler.tile_strategy import TileStrategy

@@ -60,7 +61,13 @@ def _add_loop_strategy(
env = CompileEnvironment.current()
block_size_infos = [env.block_sizes[i] for i in block_indices]
loop_order = block_size_infos[0].get_order(config, len(block_size_infos))
if block_size_infos[0].is_flattened(config):
if block_size_infos[0].is_grid():
strategy: TileStrategy = NDGridTileStrategy(
fn,
block_indices,
loop_order=loop_order,
)
elif block_size_infos[0].is_flattened(config):
strategy: TileStrategy = FlattenedTileStrategy(
fn,
block_indices,
110 changes: 74 additions & 36 deletions helion/_compiler/tile_strategy.py
@@ -346,9 +346,7 @@ def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
return output


class NDTileStrategy(BlockSizeTileStrategy):
"""Do up to 3D tiling using the kernel grid."""

class _BaseNDTileStrategy(BlockSizeTileStrategy):
block_size: list[SymIntLike]

def __init__(
@@ -357,21 +355,15 @@ def __init__(
block_indices: list[int],
block_size: list[SymIntLike] | SymIntLike,
loop_order: list[int],
l2_grouping: int,
) -> None:
assert isinstance(block_size, list)
super().__init__(fn, block_indices, block_size, loop_order)
self.mask_vars: dict[int, str | None] = {}
self.l2_grouping = l2_grouping
for bs, block_idx in zip(block_size, block_indices, strict=True):
if (block_idx,) not in fn.block_size_var_cache and bs != 1:
fn.block_size_var_cache[(block_idx,)] = fn.new_var(
f"_BLOCK_SIZE_{block_idx}"
)

def mask_var(self, block_idx: int) -> str | None:
return self.mask_vars[block_idx]

def codegen_grid(self, state: CodegenState) -> None:
block_indices = self.block_indices
env = CompileEnvironment.current()
@@ -408,34 +400,16 @@ def codegen_grid(self, state: CodegenState) -> None:
state.add_statement(
f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
)
mask_statement = self._setup_mask(state, block_idx, block_size, index_var)
if mask_statement is not None:
state.add_statement(mask_statement)
if hasattr(self, "_setup_mask"):
mask_statement = self._setup_mask( # pyre-ignore[16]
state, block_idx, block_size, index_var
)
if mask_statement is not None:
state.add_statement(mask_statement)
pids.append(ProgramID(pid_var, block_size_var, numel))
pids.codegen(state)

def _setup_mask(
self,
state: CodegenState,
block_idx: int,
block_size: SymIntLike,
index_var: str,
) -> ast.stmt | None:
env = CompileEnvironment.current()
numel = env.block_sizes[block_idx].numel
if block_size == 1 or env.known_multiple(numel, block_size):
self.mask_vars[block_idx] = None
return None
self.mask_vars[block_idx] = mask_var = self.fn.new_var(
f"mask_{block_idx}", dce=True
)
return statement_from_string(
f"{mask_var} = ({index_var} < ({state.device_function.sympy_expr(numel)}))"
)

def select_pid_strategy(self) -> ProgramIDs:
if self.l2_grouping > 1:
return L2GroupingProgramIDs(group_size=self.l2_grouping)
if 1 < len(self.block_indices) <= 3 and self.fn.config.use_yz_grid:
return GridProgramIDs()
return VirtualProgramIDs()
@@ -483,9 +457,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})"
),
]
mask_statement = self._setup_mask(state, block_idx, block_size, index_var)
if mask_statement is not None:
extra_body.append(mask_statement)
if hasattr(self, "_setup_mask"):
mask_statement = self._setup_mask( # pyre-ignore[16]
state, block_idx, block_size, index_var
)
if mask_statement is not None:
extra_body.append(mask_statement)
body[:] = [*extra_body, *body]
body = [for_node]
assert for_node is not None
@@ -500,6 +477,67 @@
return shapes


class NDTileStrategy(_BaseNDTileStrategy):
"""Do up to 3D tiling using the kernel grid."""

def __init__(
self,
fn: DeviceFunction,
block_indices: list[int],
block_size: list[SymIntLike] | SymIntLike,
loop_order: list[int],
l2_grouping: int,
) -> None:
super().__init__(fn, block_indices, block_size, loop_order)
self.mask_vars: dict[int, str | None] = {}
self.l2_grouping = l2_grouping

def mask_var(self, block_idx: int) -> str | None:
return self.mask_vars[block_idx]

def _setup_mask(
self,
state: CodegenState,
block_idx: int,
block_size: SymIntLike,
index_var: str,
) -> ast.stmt | None:
env = CompileEnvironment.current()
numel = env.block_sizes[block_idx].numel
if block_size == 1 or env.known_multiple(numel, block_size):
self.mask_vars[block_idx] = None
return None
self.mask_vars[block_idx] = mask_var = self.fn.new_var(
f"mask_{block_idx}", dce=True
)
return statement_from_string(
f"{mask_var} = ({index_var} < ({state.device_function.sympy_expr(numel)}))"
)

def select_pid_strategy(self) -> ProgramIDs:
if self.l2_grouping > 1:
return L2GroupingProgramIDs(group_size=self.l2_grouping)
return super().select_pid_strategy()


class NDGridTileStrategy(_BaseNDTileStrategy):
def __init__(
self,
fn: DeviceFunction,
block_indices: list[int],
loop_order: list[int],
) -> None:
super().__init__(
fn=fn,
block_indices=block_indices,
block_size=[1] * len(block_indices), # pyre-ignore[6]
loop_order=loop_order,
)

def mask_var(self, block_idx: int) -> str | None:
return None


class CompactedShape(NamedTuple):
size_str: str
user_indices: list[int]
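The refactor splits the old `NDTileStrategy` into a shared `_BaseNDTileStrategy` plus two concrete strategies, with masking demoted to an optional `_setup_mask` hook that only `NDTileStrategy` defines. Assembled from the format strings above (variable names assumed), the generated Triton index setup differs roughly like this:

```python
# Tiled dimension (NDTileStrategy): a block of indices, plus a bounds
# mask when the extent is not a known multiple of the block size.
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < n

# Grid dimension (NDGridTileStrategy): block size pinned to 1, so the
# index is a size-1 value derived from the program id, and mask_var()
# always returns None.
indices_0 = offset_0 + tl.zeros([1], tl.int32)
```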
33 changes: 33 additions & 0 deletions helion/_compiler/type_propagation.py
@@ -976,6 +976,39 @@ def merge(self, other: TypeInfo) -> TypeInfo:
return super().merge(other)


class GridIndexType(SymIntType):
block_size_idx: int

def __init__(self, origin: Origin, block_size_idx: int) -> None:
from .._compiler.compile_environment import CompileEnvironment

env = CompileEnvironment.current()
super().__init__(origin, env.block_sizes[block_size_idx].var)
self.block_size_idx = block_size_idx

def __str__(self) -> str: # pragma: no cover – debug helper
return f"{type(self).__name__}({self.block_size_idx})"

@staticmethod
def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.compile_environment import GridBlockSizeSource

env = CompileEnvironment.current()
block_idx = env.allocate_block_size(numel, source=GridBlockSizeSource())
return GridIndexType(origin, block_idx)

def merge(self, other: TypeInfo) -> TypeInfo: # type: ignore[override]
if isinstance(other, GridIndexType):
if self.block_size_idx == other.block_size_idx:
return self
return UnknownType(
debug_msg=f"GridIndexType mismatch in control flow: {self.block_size_idx} vs {other.block_size_idx}",
origin=other.origin,
)
return super().merge(other)


class IterType(TypeInfo):
inner: TypeInfo

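`GridIndexType` is to `hl.grid` what `TileIndexType` is to `hl.tile`: it wraps the symbolic value of a block size allocated with `GridBlockSizeSource`, and its `merge` keeps types precise across control flow, collapsing to `UnknownType` only when two branches yield indices of different grid dimensions. A self-contained analogue of that rule (stand-in classes, not the Helion types):

```python
import dataclasses


@dataclasses.dataclass
class Unknown:
    debug_msg: str


@dataclasses.dataclass
class GridIdx:
    block_size_idx: int

    def merge(self, other: object) -> "GridIdx | Unknown":
        # Same grid dimension on both branches: the type survives.
        if isinstance(other, GridIdx) and other.block_size_idx == self.block_size_idx:
            return self
        # Different dimensions (or a non-grid value): ambiguous after the join.
        return Unknown(f"GridIdx mismatch: {self.block_size_idx} vs {other!r}")


assert GridIdx(0).merge(GridIdx(0)) == GridIdx(0)
assert isinstance(GridIdx(0).merge(GridIdx(1)), Unknown)
```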
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -3,6 +3,7 @@
from .constexpr import ConstExpr as constexpr # noqa: F401
from .creation_ops import full as full
from .creation_ops import zeros as zeros
from .loops import grid as grid
from .loops import register_block_size as register_block_size
from .loops import tile as tile
from .memory_ops import load as load
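With this one-line re-export, `grid` joins the public `helion.language` namespace next to `tile`; a quick smoke check (assuming the merged package is installed):

```python
import helion.language as hl
from helion.language import grid

assert hl.grid is grid  # exported alongside tile, full, zeros, load, ...
```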