Add static_range by joydddd · Pull Request #235 · pytorch-labs/helion · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add static_range #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ Contains one entry per loop dimension, controlling the `flatten`
parameter for `tl.range()` calls. `True` sets `flatten=True`,
`False` sets `flatten=False`, and `None` omits the parameter.

* **static\_ranges** (`list[bool]`):
Contains one entry per loop dimension, controlling whether to use
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should mention this is only legal for loops with static bounds

`tl.static_range()` calls. `True` uses `tl.static_range()`, `False` uses `tl.range()`. Only legal for loops whose bounds are static (compile-time constants).

* **range\_warp\_specializes** (`list[bool | None]`):
Contains one entry per loop dimension, controlling the `warp_specialize`
parameter for `tl.range()` calls. `True` sets `warp_specialize=True`,
Expand Down
15 changes: 13 additions & 2 deletions helion/_compiler/tile_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ def fn(self) -> DeviceFunction:
assert fn is not None
return fn

def get_range_fn_name(self, state: CodegenState, block_idx: int) -> str:
    """Return the Triton range function to use for this loop dimension.

    Looks up the ``static_ranges`` entry for ``block_idx`` in the active
    config; a ``True`` value selects ``tl.static_range`` (fully unrolled,
    compile-time loop), anything else falls back to ``tl.range``.
    """
    env = CompileEnvironment.current()
    use_static = env.config_spec.static_ranges.config_get(
        state.config.static_ranges, block_idx, None
    )
    return "tl.static_range" if use_static is True else "tl.range"

def offset_var(self, block_idx: int) -> str:
return self.offset_vars[block_idx]

Expand Down Expand Up @@ -400,11 +409,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
dtype = CompileEnvironment.current().triton_index_type()
lid = self.new_var("lid")
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0])
range_fn = self.get_range_fn_name(state, self.block_ids[0])
for_node = create(
ast.For,
target=create(ast.Name, id=lid, ctx=ast.Store()),
iter=expr_from_string(
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
f"{range_fn}(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
),
body=(
body := [
Expand Down Expand Up @@ -610,11 +620,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
)

range_extra = self.get_tl_range_kwargs(state, block_idx)
range_fn = self.get_range_fn_name(state, self.block_ids[0])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be block_idx? Or do we always want to match the first loop in an multidimension loopnest?

for_node = create(
ast.For,
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
iter=expr_from_string(
f"tl.range(begin, end, {block_size_var}{range_extra})",
f"{range_fn}(begin, end, {block_size_var}{range_extra})",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should refactor the {range_extra} helper to generate the entire expression (take in begin/end/step as args). That would put all the loop generation logic in one place.

begin=self._to_ast(begin, to_dtype=dtype),
end=self._to_ast(end, to_dtype=dtype),
),
Expand Down
4 changes: 4 additions & 0 deletions helion/autotuner/block_id_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def block_id_lookup(self, block_id: int) -> _BlockIdItemT:
"""Return the index of the block_id in the config."""
return self._data[self._block_id_to_index[block_id]]

def valid_block_ids(self) -> list[int]:
    """Return every block_id that currently has an entry in this sequence."""
    return [*self._block_id_to_index]

def disable_block_id(self, block_id: int) -> None:
"""Remove configuration choice for the given block_id."""
self._data = [x for x in self._data if block_id not in x.block_ids]
Expand Down
60 changes: 60 additions & 0 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"range_num_stages",
"range_multi_buffers",
"range_flattens",
"static_ranges",
"num_warps",
"num_stages",
"use_yz_grid",
Expand Down Expand Up @@ -81,6 +82,9 @@ class ConfigSpec:
range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
static_ranges: BlockIdSequence[StaticRangeSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
default_factory=dict
)
Expand All @@ -95,6 +99,7 @@ def _remove_duplicates(self) -> None:
self.range_num_stages._remove_duplicates()
self.range_multi_buffers._remove_duplicates()
self.range_flattens._remove_duplicates()
self.static_ranges._remove_duplicates()

def normalize(self, config: helion.Config | dict[str, object]) -> None:
"""Normalize the config to match the block_sizes and validate the config."""
Expand All @@ -113,6 +118,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_num_stage",
"range_multi_buffer",
"range_flatten",
"static_range",
):
if name in config:
names = f"{name}s"
Expand All @@ -131,11 +137,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
("range_num_stages", self.range_num_stages, True),
("range_multi_buffers", self.range_multi_buffers, True),
("range_flattens", self.range_flattens, True),
("static_ranges", self.static_ranges, True),
]:
config[name] = mapping._normalize(
name, config.get(name, ()), flatten=flatten
)

for block_id in self.static_ranges.valid_block_ids():
use_static_range = self.static_ranges.config_get(
config.get("static_ranges", ()), # pyre-ignore[6]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a property so you don't need to typing ignore here:

Suggested change
config.get("static_ranges", ()), # pyre-ignore[6]
config.static_ranges

block_id,
)

if use_static_range:
for name, mapping in (
("range_unroll_factors", self.range_unroll_factors),
("range_warp_specializes", self.range_warp_specialize),
("range_num_stages", self.range_num_stages),
("range_multi_buffers", self.range_multi_buffers),
("range_flattens", self.range_flattens),
):
if config[name]: # The config is non empty
# pyre-ignore[16]
config[name][mapping.block_id_to_index(block_id)] = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The block_id may not exist in every one of these, since you can disable choices via the config_spec API. Handle the case of it missing.

Maybe add a mapping.set_config_to_default(config[name], block_id) helper for this.

mapping.block_id_lookup(block_id)._fill_missing()
)

for name in (
"loop_orders",
"l2_groupings",
Expand All @@ -146,6 +173,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_num_stages",
"range_multi_buffers",
"range_flattens",
"static_ranges",
):
if not config[name]:
config.pop(name)
Expand Down Expand Up @@ -180,6 +208,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"range_num_stages": self.range_num_stages._flat_config(self, fn),
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
"range_flattens": self.range_flattens._flat_config(self, fn),
"static_ranges": self.static_ranges._flat_config(self, fn),
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
"indexing": fn(
Expand Down Expand Up @@ -211,6 +240,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"range_num_stages",
"range_multi_buffers",
"range_flattens",
"static_ranges",
):
if not config[name]:
config.pop(name)
Expand Down Expand Up @@ -399,6 +429,36 @@ class RangeFlattenSpec(_OptionalBoolSpec):
pass


class StaticRangeSpec(_BlockIdItem):
    """Config spec choosing ``tl.static_range`` vs ``tl.range`` for one loop."""

    def __init__(self, block_id: int, is_static: bool) -> None:
        super().__init__([block_id])
        # True only when the loop's parameters are compile-time constants;
        # tl.static_range is not legal otherwise.
        self.is_static = is_static
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If is_static is False, don't create the StaticRangeSpec.

We shouldn't add config values with only a single choice.


def _fragment(self, base: ConfigSpec) -> ConfigSpecFragment:
    """Tunable search fragment: a real boolean choice only for static loops."""
    # Non-static loops can never use tl.static_range, so the only legal
    # value is False; static loops expose a genuine True/False choice.
    if not self.is_static:
        return EnumFragment((False,))
    return BooleanFragment()

def _normalize(self, name: str, value: object) -> bool:
if not isinstance(value, bool):
raise InvalidConfig(f"{name} must be a boolean, got {value!r}")
if value is True and self.is_static is False:
raise InvalidConfig(
f"Got {name}=Ture for non-static loop #{self.block_id}\n Do you forget to call hl.specialize() on the loop dim? "
)
return value

def _fill_missing(self) -> bool:
"""Provide a value when not provided by the user."""
return False


def _product(seq: Sequence[int]) -> int:
"""Return the product of the elements in the sequence."""
return functools.reduce(operator.mul, seq, 1)
33 changes: 32 additions & 1 deletion helion/language/loops.py
10000
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..autotuner.config_spec import RangeNumStagesSpec
from ..autotuner.config_spec import RangeUnrollFactorSpec
from ..autotuner.config_spec import RangeWarpSpecializeSpec
from ..autotuner.config_spec import StaticRangeSpec
from . import _decorators
from helion.language.tile_proxy import Tile

Expand Down Expand Up @@ -151,6 +152,23 @@ def _check_matching(a: object, b: object) -> None:
)


def _is_constexpr_int(a: object) -> bool:
"""Check if the arg is specialized."""
return isinstance(a, int)
# TODO(joydddd): render SymInt backed by Int as constexpr.
# Now the specialized constexpr is assigned to a dynamic variable first
# and then used as a variable. However args to static_range must be constexpr.
# e.g.
# hl.specialize(x.size(0))
# for i in hl.grid(x.size(0))
# ->
# symbol_0 = 64
# for i in tl.static_range(symbol_0):
#
# if isinstance(a, torch.SymInt):
# return isinstance(a._sympy_(), sympy.Integer)


def _normalize_begin_end(
begin_or_end: TypeInfo,
end_or_none: TypeInfo | None,
Expand Down Expand Up @@ -225,6 +243,10 @@ def _(
[x.block_id for x in results],
is_tile=True,
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
is_static=all(
_is_constexpr_int(x) or x is None
for x in (*proxy_begin, *proxy_end, *proxy_block_size)
),
)
if unpack:
(result,) = results
Expand All @@ -234,7 +256,11 @@ def _(


def _add_config_choices(
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False
block_ids: list[int],
*,
is_tile: bool = False,
has_begin: bool = False,
is_static: bool = False,
) -> None:
config_spec = CompileEnvironment.current().config_spec

Expand All @@ -253,6 +279,7 @@ def _add_config_choices(
else:
params = inspect.signature(triton.language.range).parameters
for block_id in block_ids:
config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static))
if is_static:
config_spec.static_ranges.append(StaticRangeSpec(block_id))

if "loop_unroll_factor" in params:
config_spec.range_unroll_factors.append(
RangeUnrollFactorSpec([block_id])
Expand Down Expand Up @@ -419,6 +446,10 @@ def _(
[x.block_id for x in results],
is_tile=False,
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
is_static=all(
_is_constexpr_int(x) or x is None
for x in (*proxy_begin, *proxy_end, *proxy_step)
),
)
if unpack:
(result,) = results
Expand Down
7 changes: 7 additions & 0 deletions helion/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
range_num_stages: list[int] | None = None,
range_multi_buffers: list[bool | None] | None = None,
range_flattens: list[bool | None] | None = None,
static_ranges: list[bool] | None = None,
num_warps: int | None = None,
num_stages: int | None = None,
use_yz_grid: bool | None = None,
Expand All @@ -50,6 +51,7 @@ def __init__(
range_num_stages: Number of stages for tl.range calls.
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
range_flattens: Controls flatten parameter for tl.range calls.
static_ranges: Whether to use tl.static_range instead of tl.range.
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
use_yz_grid: Whether to use yz grid dimensions.
Expand All @@ -68,6 +70,7 @@ def __init__(
"range_num_stages": range_num_stages,
"range_multi_buffers": range_multi_buffers,
"range_flattens": range_flattens,
"static_ranges": static_ranges,
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
Expand Down Expand Up @@ -173,6 +176,10 @@ def range_multi_buffers(self) -> list[bool | None]:
def range_flattens(self) -> list[bool | None]:
return cast("list[bool | None]", self.config.get("range_flattens", []))

@property
def static_ranges(self) -> list[bool]:
    """Per-loop flags selecting tl.static_range over tl.range (empty if unset)."""
    value = self.config.get("static_ranges", [])
    return cast("list[bool]", value)

@property
def indexing(self) -> IndexingLiteral:
return self.config.get("indexing", "pointer") # type: ignore
Expand Down
20 changes: 10 additions & 10 deletions test/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ def test_config_fragment0(self):
self.assertExpectedInline(
"\n".join(map(repr, configs)),
"""\
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=1, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=2, num_stages=8, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], num_warps=32, num_stages=2, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='pointer')
helion.Config(block_sizes=[256, 128, 16], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], num_warps=8, num_stages=4, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[False], num_warps=1, num_stages=8, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[False], num_warps=4, num_stages=4, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[True], num_warps=16, num_stages=4, indexing='block_ptr')
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], num_warps=4, num_stages=1, indexing='tensor_descriptor')""",
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], static_ranges=[False], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], static_ranges=[True], num_warps=1, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[64, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[1], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], static_ranges=[False], num_warps=32, num_stages=8, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], static_ranges=[False], num_warps=16, num_stages=1, indexing='pointer')
helion.Config(block_sizes=[16, 128, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], static_ranges=[True], num_warps=16, num_stages=7, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], static_ranges=[True], num_warps=2, num_stages=3, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_num_stages=[0], range_multi_buffers=[False], range_flattens=[None], static_ranges=[True], num_warps=16, num_stages=3, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[3], range_multi_buffers=[False], range_flattens=[False], static_ranges=[False], num_warps=32, num_stages=5, indexing='pointer')
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[0], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], static_ranges=[False], num_warps=8, num_stages=6, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[None], static_ranges=[True], num_warps=8, num_stages=5, indexing='tensor_descriptor')""",
)

@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
Expand Down
Loading
Loading
0