Add support for hl.tile(begin, end) and hl.tile(begin, end, block_size) by jansel · Pull Request #82 · pytorch-labs/helion

Merged
merged 1 commit on May 28, 2025

13 changes: 10 additions & 3 deletions helion/_compiler/compile_environment.py
@@ -295,7 +295,8 @@ class NoCurrentEnvironment(RuntimeError):
pass


class BlockSizeInfo(typing.NamedTuple):
@dataclasses.dataclass
class BlockSizeInfo:
"""
Information about a block size.
Used to track the block size for a given dimension.
@@ -320,8 +321,14 @@ def known_multiple(self, block_size: int | torch.SymInt) -> bool:
return CompileEnvironment.current().known_multiple(self.numel, block_size)

def size_hint(self) -> int:
assert self.size is not None
return CompileEnvironment.current().size_hint(self.size)
size = self.size
assert size is not None
return CompileEnvironment.current().size_hint(size)

def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
"""If a block size is used with a different size, we need to clear the hint to enable masking."""
if size is None or self.size is None or self.size != size:
self.size = None

def symbol(self) -> sympy.Symbol:
return self.var._sympy_()
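
BlockSizeInfo becomes a mutable dataclass (previously a NamedTuple) so that mark_alternate_size can clear the recorded size in place when the same block dimension is reused with a different extent, which forces masked codegen. A minimal standalone sketch of that behavior, assuming nothing beyond the hunk above (the real class carries additional fields such as var and numel):

from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class BlockSizeInfoSketch:
    """Illustrative stand-in for BlockSizeInfo; only the size-hint logic is modeled."""

    size: int | None

    def mark_alternate_size(self, size: int | None) -> None:
        # A differing (or unknown) size invalidates the hint, so later codegen
        # must mask rather than assume the original extent.
        if size is None or self.size is None or self.size != size:
            self.size = None


info = BlockSizeInfoSketch(size=1024)
info.mark_alternate_size(1024)  # same extent: hint survives
assert info.size == 1024
info.mark_alternate_size(512)   # different extent: hint cleared
assert info.size is None
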
2 changes: 1 addition & 1 deletion helion/_compiler/type_propagation.py
@@ -979,7 +979,7 @@ def allocate(

@staticmethod
def allocate_fixed(
numel: int | torch.SymInt, block_size: int | torch.SymInt, origin: Origin
numel: int | torch.SymInt | None, block_size: int | torch.SymInt, origin: Origin
) -> TileIndexType:
env = CompileEnvironment.current()
return TileIndexType(
225 changes: 137 additions & 88 deletions helion/language/loops.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING
from typing import Iterator
from typing import Sequence
from typing import TypeGuard
from typing import overload

import torch
@@ -12,14 +13,12 @@
from .._compiler.ast_extension import ExtendedAST
from .._compiler.ast_extension import LoopType
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.tile_index_proxy import TileIndexProxy
from .._compiler.type_propagation import GridIndexType
from .._compiler.type_propagation import IterType
from .._compiler.type_propagation import LiteralType
from .._compiler.type_propagation import Origin
from .._compiler.type_propagation import SequenceType
from .._compiler.type_propagation import SymIntType
from .._compiler.type_propagation import TensorType
from .._compiler.type_propagation import TileIndexType
from .._compiler.type_propagation import TypeInfo
from .._compiler.type_propagation import UnknownType
@@ -40,23 +39,32 @@
@_decorators.api(
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
def tile(sizes: int, /, block_size: object = None) -> Iterator[TileOutput]: ...
def tile(
begin_or_end: int,
end_or_none: int | None = None,
/,
block_size: object = None,
) -> Iterator[TileOutput]: ...


@overload
@_decorators.api(
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
def tile(
sizes: Sequence[int], /, block_size: object = None
begin_or_end: Sequence[int],
end_or_none: Sequence[int] | None = None,
/,
block_size: object = None,
) -> Iterator[Sequence[TileOutput]]: ...


@_decorators.api(
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
def tile(
sizes: int | Sequence[int],
begin_or_end: int | Sequence[int],
end_or_none: int | Sequence[int] | None = None,
/,
block_size: object = None,
) -> Iterator[TileOutput] | Iterator[Sequence[TileOutput]]:
@@ -73,6 +81,16 @@ def tile(
If used at the top level of a function, this becomes the grid of the kernel.
Otherwise, it becomes a loop in the output kernel.

Similar to `range()`, there are multiple forms of this function:
tile(end) iterates from 0 to `end - 1`, with autotuned block_size.
tile(begin, end) iterates from `begin` to `end - 1`, with autotuned block_size.
tile(begin, end, block_size) iterates from `begin` to `end - 1`, with the given block_size.
tile(end, block_size=block_size) iterates from 0 to `end - 1`, with the given block_size.

begin/end/block_size can be a single integer or a sequence of integers to specify
multidimensional iteration. Block sizes can be explicitly registered for autotuning
with `hl.register_block_size()`.

Examples:

for tile in hl.tile(1000):
@@ -81,51 +99,116 @@ def tile(
for tile0, tile1 in hl.tile([1000, 1000]):
...

:param sizes: An integer or a sequence of integers representing the sizes for tiling.
:param begin_or_end: If 2 or more positional arguments are provided, the start of the iteration space. Otherwise, the end of the iteration space.
:param end_or_none: If 2 or more positional arguments are provided, the end of the iteration space.
:return: A TileIndexProtocol object if a single size is provided, or a sequence of TileIndexProtocol objects if a sequence of sizes is provided.
"""
raise exc.NotInsideKernel
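
As a usage sketch of the new call forms, modeled on helion's basic elementwise-add example (the kernel below is hypothetical and only illustrates the begin/end shapes this PR adds):

import torch
import helion
import helion.language as hl


@helion.kernel()
def add_from(x: torch.Tensor, y: torch.Tensor, begin: int) -> torch.Tensor:
    out = x.clone()
    (n,) = x.shape
    # hl.tile(begin, end): iterate from `begin` to `n - 1` with an autotuned block size.
    for tile in hl.tile(begin, n):
        out[tile] = x[tile] + y[tile]
    return out

The fixed-block-size variant would be hl.tile(begin, n, block_size=64), and hl.tile(n, block_size=64) keeps the implicit zero begin.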


def _not_none(value: TypeInfo | None) -> TypeGuard[TypeInfo]:
return not (value is None or value.is_literal() and value.as_literal() is None)


def _to_proxy(value: TypeInfo) -> object:
try:
return value.proxy()
except NotImplementedError:
raise exc.IncorrectTileUsage(
f"expected IntLike or list[IntLike], got {value!s}"
) from None


def _check_matching(a: object, b: object) -> None:
"""Check that the types of `a` and `b` match for use in hl.tile."""
if isinstance(a, (list, tuple)):
if not isinstance(b, (list, tuple)):
raise exc.IncorrectTileUsage(
f"expected type hl.tile args to match, got {type(a)} and {type(b)}"
)
if len(a) != len(b):
raise exc.IncorrectTileUsage(
f"expected dims for hl.tile args to match, got {len(a)} and {len(b)}"
)
elif isinstance(a, (int, torch.SymInt, torch.Tensor)):
if not isinstance(b, (int, torch.SymInt, torch.Tensor)):
raise exc.IncorrectTileUsage(
f"expected type hl.tile args to match, got {type(a)} and {type(b)}"
)
else:
raise exc.IncorrectTileUsage(
f"expected type hl.tile args to be IntLike or list[IntLike], got {type(a)}"
)


def _normalize_begin_end(
begin_or_end: TypeInfo,
end_or_none: TypeInfo | None,
origin: Origin,
) -> tuple[TypeInfo, TypeInfo]:
"""Fill in defaults for begin if it is not provided."""
if _not_none(end_or_none):
begin = begin_or_end
end = end_or_none
else:
try:
begin = TypeInfo.from_example(begin_or_end.tree_map(lambda n: 0), origin)
except NotImplementedError:
raise exc.TypePropagationError(
UnknownType(
origin,
f"expected IntLike or list[IntLike], got {begin_or_end!s}",
chained_from=begin_or_end,
)
) from None
end = begin_or_end
return begin, end
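
_normalize_begin_end fills begin with zeros in the same structure as the lone argument when no end is passed; a rough sketch of that defaulting rule in plain Python (not the real TypeInfo/tree_map machinery):

def default_begin(begin_or_end):
    """Mirrors the zero-fill: hl.tile(1000) behaves like hl.tile(0, 1000)."""
    if isinstance(begin_or_end, (list, tuple)):
        return type(begin_or_end)(0 for _ in begin_or_end)
    return 0


assert default_begin(1000) == 0
assert default_begin([128, 256]) == [0, 0]
assert default_begin((128, 256)) == (0, 0)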


@_decorators.type_propagation(tile)
def _(
sizes: TypeInfo, block_size: TypeInfo | None = None, *, origin: Origin
begin_or_end: TypeInfo,
end_or_none: TypeInfo | None = None,
/,
block_size: TypeInfo | None = None,
*,
origin: Origin,
) -> TypeInfo:
parent = ExtendedAST.current()[-2]
if not isinstance(parent, ast.For):
raise exc.LoopFunctionNotInFor("tile")
begin, end = _normalize_begin_end(begin_or_end, end_or_none, origin=origin)
proxy_begin = _to_proxy(begin)
proxy_end = _to_proxy(end)
_check_matching(proxy_begin, proxy_end)
if _not_none(block_size):
proxy_block_size = TileIndexProxy.tiles_to_sizes(_to_proxy(block_size))
_check_matching(proxy_end, proxy_block_size)
else:
proxy_block_size = begin.tree_map(lambda n: None)

if unpack := not isinstance(proxy_end, (list, tuple)):
proxy_begin = [proxy_begin]
proxy_end = [proxy_end]
proxy_block_size = [proxy_block_size]

if (
block_size is None
or block_size.is_literal()
and block_size.as_literal() is None
all(bs is None for bs in proxy_block_size)
and all(isinstance(s, (int, torch.SymInt)) for s in proxy_begin)
and all(isinstance(s, (int, torch.SymInt)) for s in proxy_end)
):
result = _register_block_size_types(sizes, origin)
proxy_size = [e - b for b, e in zip(proxy_begin, proxy_end, strict=True)]
results = TileIndexType.allocate(proxy_size, origin)
else:
try:
proxy_sizes = sizes.proxy()
proxy_block_size = TileIndexProxy.tiles_to_sizes(block_size.proxy())
except NotImplementedError:
raise exc.IncorrectTileUsage(
f"expected int or list[int], got {sizes!s} and {block_size!s}"
) from None
if isinstance(proxy_sizes, (list, tuple)):
if not isinstance(proxy_block_size, (list, tuple)) or len(
proxy_sizes
) != len(proxy_block_size):
raise exc.IncorrectTileUsage(
f"expected dims for sizes and block_sizes to match, got {sizes!s} and {block_size!s}"
)
unpack = False
else:
if not isinstance(proxy_block_size, int | torch.SymInt):
raise exc.IncorrectTileUsage(
f"expected type for sizes and block_sizes to match, got {sizes!s} and {block_size!s}"
)
proxy_sizes = [proxy_sizes]
proxy_block_size = [proxy_block_size]
unpack = True
# we must allocate the block sizes individually due to data dependent size or pre-allocated block sizes
# TODO(jansel): this flattens the structure of the config, which we should avoid
results = []
for size, bs in zip(proxy_sizes, proxy_block_size, strict=True):
for begin_part, end_part, bs in zip(
proxy_begin, proxy_end, proxy_block_size, strict=True
):
size = end_part - begin_part
if isinstance(size, torch.Tensor):
size = None # data dependent size
if bs is None:
results.append(TileIndexType.allocate([size], origin)[0])
elif isinstance(bs, int):
@@ -138,59 +221,14 @@ def _(
results.append(TileIndexType.allocate_fixed(size, bs, origin))
else:
results.append(TileIndexType(origin=origin, block_size_idx=index))
if unpack:
(result,) = results
else:
result = SequenceType(origin, results)
return IterType(origin, result)


def _register_block_size_types(sizes: TypeInfo, origin: Origin) -> TypeInfo:
if isinstance(sizes, SequenceType):
unpacked = sizes.unpack()
CompileEnvironment.current().block_sizes[index].mark_alternate_size(
size
)
if unpack:
(result,) = results
else:
unpacked = [sizes]
has_data_dependency = False
for size in unpacked:
if isinstance(size, TensorType) and size.origin.is_device():
has_data_dependency = True
elif isinstance(size, (LiteralType, SymIntType)) and isinstance(
size.proxy(), (int, torch.SymInt)
):
pass
else:
raise exc.TypePropagationError(
UnknownType(
origin,
f"tile() expected int or list[int], got {size!s}",
chained_from=size,
)
)
if has_data_dependency:
# TODO(jansel): support flatten/reorder for data dependencies
inner_types: list[TypeInfo] = []
for size in unpacked:
if isinstance(size, TensorType) and size.origin.is_device():
proxy = None
else:
proxy = size.proxy()
assert isinstance(proxy, (int, torch.SymInt))
inner_types.append(TileIndexType.allocate([proxy], origin)[0])
if isinstance(sizes, SequenceType):
return SequenceType(
origin=origin,
element_types=inner_types,
)
assert len(inner_types) == 1
return inner_types[0]
proxy_sizes = sizes.proxy()
if isinstance(proxy_sizes, (int, torch.SymInt)):
return TileIndexType.allocate([proxy_sizes], origin)[0]
return SequenceType(
origin=origin,
# pyre-fixme[6]
element_types=TileIndexType.allocate(proxy_sizes, origin),
)
result = SequenceType(origin, results)
return IterType(origin, result)


def _get_block_indices(type_info: TypeInfo) -> list[int]:
@@ -334,6 +372,17 @@ def register_block_size(size: int | Sequence[int]) -> TileOutput | Sequence[Tile
raise exc.NotInsideKernel


def _register_block_size_types(sizes: TypeInfo, origin: Origin) -> TypeInfo:
proxy_sizes = sizes.proxy()
if isinstance(proxy_sizes, (int, torch.SymInt)):
return TileIndexType.allocate([proxy_sizes], origin)[0]
return SequenceType(
origin=origin,
# pyre-fixme[6]
element_types=TileIndexType.allocate(proxy_sizes, origin),
)


@_decorators.type_propagation(register_block_size)
def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
return _register_block_size_types(sizes, origin)
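
The tile docstring notes that block sizes can be registered explicitly for autotuning; a hedged sketch of combining hl.register_block_size() with the new begin/end form (again modeled on the add example; the kernel and the shared-block-size pattern are assumptions, not code from this PR):

import torch
import helion
import helion.language as hl


@helion.kernel()
def double_from(x: torch.Tensor, begin: int) -> torch.Tensor:
    out = x.clone()
    (n,) = x.shape
    # Register the block size up front so the autotuner owns it, then reuse it
    # in the explicit begin/end form of hl.tile.
    block = hl.register_block_size(n)
    for tile in hl.tile(begin, n, block_size=block):
        out[tile] = x[tile] * 2
    return out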