-
Notifications
You must be signed in to change notification settings - Fork 13
Add static_range #235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add static_range #235
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,6 +111,15 @@ def fn(self) -> DeviceFunction: | |
assert fn is not None | ||
return fn | ||
|
||
def get_range_fn_name(self, state: CodegenState, block_idx: int) -> str:
    """Pick the Triton range callable used for the loop over ``block_idx``.

    The ``static_ranges`` value for ``block_idx`` is read from
    ``state.config`` through the current compile environment's config spec.
    ``None`` is passed as the third ``config_get`` argument (presumably the
    fallback for a missing entry -- confirm against ``config_get``).

    Args:
        state: Codegen state whose ``config.static_ranges`` is consulted.
        block_idx: Block id of the loop being generated.

    Returns:
        ``"tl.static_range"`` when the looked-up value is exactly ``True``,
        otherwise ``"tl.range"``.
    """
    static_ranges_spec = CompileEnvironment.current().config_spec.static_ranges
    use_static = static_ranges_spec.config_get(
        state.config.static_ranges, block_idx, None
    )
    # Identity check on purpose: only a literal True selects static_range.
    return "tl.static_range" if use_static is True else "tl.range"
|
||
def offset_var(self, block_idx: int) -> str:
    """Return the entry of ``self.offset_vars`` stored under ``block_idx``."""
    offset_vars = self.offset_vars
    return offset_vars[block_idx]
|
||
|
@@ -400,11 +409,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: | |
dtype = CompileEnvironment.current().triton_index_type() | ||
lid = self.new_var("lid") | ||
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0]) | ||
range_fn = self.get_range_fn_name(state, self.block_ids[0]) | ||
for_node = create( | ||
ast.For, | ||
target=create(ast.Name, id=lid, ctx=ast.Store()), | ||
iter=expr_from_string( | ||
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})" | ||
f"{range_fn}(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})" | ||
), | ||
body=( | ||
body := [ | ||
|
@@ -610,11 +620,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: | |
) | ||
|
||
range_extra = self.get_tl_range_kwargs(state, block_idx) | ||
range_fn = self.get_range_fn_name(state, self.block_ids[0]) | ||
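There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this be `block_idx` rather than `self.block_ids[0]`, to match the `get_tl_range_kwargs` call on the line above? |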
||
for_node = create( | ||
ast.For, | ||
target=create(ast.Name, id=offset_var, ctx=ast.Store()), | ||
iter=expr_from_string( | ||
f"tl.range(begin, end, {block_size_var}{range_extra})", | ||
f"{range_fn}(begin, end, {block_size_var}{range_extra})", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should refactor the {range_extra} helper to generate the entire expression (take in begin/end/step as args). That would put all the loop generation logic in one place. |
||
begin=self._to_ast(begin, to_dtype=dtype), | ||
end=self._to_ast(end, to_dtype=dtype), | ||
), | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -41,6 +41,7 @@ | |||||
"range_num_stages", | ||||||
"range_multi_buffers", | ||||||
"range_flattens", | ||||||
"static_ranges", | ||||||
"num_warps", | ||||||
"num_stages", | ||||||
"use_yz_grid", | ||||||
|
@@ -81,6 +82,9 @@ class ConfigSpec: | |||||
range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field( | ||||||
default_factory=BlockIdSequence | ||||||
) | ||||||
static_ranges: BlockIdSequence[StaticRangeSpec] = dataclasses.field( | ||||||
default_factory=BlockIdSequence | ||||||
) | ||||||
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field( | ||||||
default_factory=dict | ||||||
) | ||||||
|
@@ -95,6 +99,7 @@ def _remove_duplicates(self) -> None: | |||||
self.range_num_stages._remove_duplicates() | ||||||
self.range_multi_buffers._remove_duplicates() | ||||||
self.range_flattens._remove_duplicates() | ||||||
self.static_ranges._remove_duplicates() | ||||||
|
||||||
def normalize(self, config: helion.Config | dict[str, object]) -> None: | ||||||
"""Normalize the config to match the block_sizes and validate the config.""" | ||||||
|
@@ -113,6 +118,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: | |||||
"range_num_stage", | ||||||
"range_multi_buffer", | ||||||
"range_flatten", | ||||||
"static_range", | ||||||
): | ||||||
if name in config: | ||||||
names = f"{name}s" | ||||||
|
@@ -131,11 +137,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: | |||||
("range_num_stages", self.range_num_stages, True), | ||||||
("range_multi_buffers", self.range_multi_buffers, True), | ||||||
("range_flattens", self.range_flattens, True), | ||||||
("static_ranges", self.static_ranges, True), | ||||||
]: | ||||||
config[name] = mapping._normalize( | ||||||
name, config.get(name, ()), flatten=flatten | ||||||
) | ||||||
|
||||||
for block_id in self.static_ranges.valid_block_ids(): | ||||||
use_static_range = self.static_ranges.config_get( | ||||||
config.get("static_ranges", ()), # pyre-ignore[6] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add a property so you don't need a type-ignore comment here:
Suggested change
|
||||||
block_id, | ||||||
) | ||||||
|
||||||
if use_static_range: | ||||||
for name, mapping in ( | ||||||
("range_unroll_factors", self.range_unroll_factors), | ||||||
("range_warp_specializes", self.range_warp_specialize), | ||||||
("range_num_stages", self.range_num_stages), | ||||||
("range_multi_buffers", self.range_multi_buffers), | ||||||
("range_flattens", self.range_flattens), | ||||||
): | ||||||
if config[name]: # The config is non empty | ||||||
# pyre-ignore[16] | ||||||
config[name][mapping.block_id_to_index(block_id)] = ( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The block_id may not exist in every one of these, since you can disable choices via the config_spec API. Handle the case of it missing. Maybe add a |
||||||
mapping.block_id_lookup(block_id)._fill_missing() | ||||||
) | ||||||
|
||||||
for name in ( | ||||||
"loop_orders", | ||||||
"l2_groupings", | ||||||
|
@@ -146,6 +173,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: | |||||
"range_num_stages", | ||||||
"range_multi_buffers", | ||||||
"range_flattens", | ||||||
"static_ranges", | ||||||
): | ||||||
if not config[name]: | ||||||
config.pop(name) | ||||||
|
@@ -180,6 +208,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf | |||||
"range_num_stages": self.range_num_stages._flat_config(self, fn), | ||||||
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn), | ||||||
"range_flattens": self.range_flattens._flat_config(self, fn), | ||||||
"static_ranges": self.static_ranges._flat_config(self, fn), | ||||||
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)), | ||||||
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)), | ||||||
"indexing": fn( | ||||||
|
@@ -211,6 +240,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf | |||||
"range_num_stages", | ||||||
"range_multi_buffers", | ||||||
"range_flattens", | ||||||
"static_ranges", | ||||||
): | ||||||
if not config[name]: | ||||||
config.pop(name) | ||||||
|
@@ -399,6 +429,36 @@ class RangeFlattenSpec(_OptionalBoolSpec): | |||||
pass | ||||||
|
||||||
|
||||||
class StaticRangeSpec(_BlockIdItem): | ||||||
def __init__(
    self,
    block_id: int,
    is_static: bool,
) -> None:
    """Create the static-range spec for a single loop.

    Args:
        block_id: Block id of the loop; handed to the base class as a
            one-element list.
        is_static: Stored as ``self.is_static``. The other methods of this
            class only permit a ``True`` config value when this flag is set.
    """
    owned_block_ids = [block_id]
    super().__init__(owned_block_ids)
    self.is_static = is_static
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If `is_static` is False, we should skip adding this config entirely. We shouldn't add config values with only a single choice. |
||||||
|
||||||
def _fragment(self, base: ConfigSpec) -> ConfigSpecFragment:
    """Return the config-spec fragment describing this loop's static-range flag.

    Args:
        base: The owning config spec; not used by this implementation.

    Returns:
        An ``EnumFragment`` whose only choice is ``False`` when the loop is
        not static, otherwise a ``BooleanFragment``.
    """
    if not self.is_static:
        # tl.static_range is only enabled when loop parameters are static,
        # so a non-static loop gets the single choice False.
        return EnumFragment((False,))
    return BooleanFragment()
|
||||||
def _normalize(self, name: str, value: object) -> bool:
    """Validate a user-supplied static-range flag for this loop.

    Args:
        name: Config key being validated; used only in error messages.
        value: Candidate config value.

    Returns:
        ``value`` unchanged once it has passed validation.

    Raises:
        InvalidConfig: If ``value`` is not a ``bool``, or if it is ``True``
            while this loop is not static.
    """
    if not isinstance(value, bool):
        raise InvalidConfig(f"{name} must be a boolean, got {value!r}")
    # A static range can only be requested for a loop whose parameters are static.
    if value and not self.is_static:
        raise InvalidConfig(
            f"Got {name}=True for non-static loop #{self.block_id}\n"
            "Did you forget to call hl.specialize() on the loop dim?"
        )
    return value
|
||||||
def _fill_missing(self) -> bool:
    """Return the value used when the user's config omits this entry.

    The result is always ``False``.
    """
    default_choice = False
    return default_choice
|
||||||
|
||||||
def _product(seq: Sequence[int]) -> int:
    """Return the product of the elements in the sequence.

    An empty sequence yields ``1``.
    """
    # Imported locally so this helper does not depend on a module-level import.
    import math

    return math.prod(seq)
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -33,6 +33,7 @@ | |||||||
from ..autotuner.config_spec import RangeNumStagesSpec | ||||||||
from ..autotuner.config_spec import RangeUnrollFactorSpec | ||||||||
from ..autotuner.config_spec import RangeWarpSpecializeSpec | |||||||
from ..autotuner.config_spec import StaticRangeSpec | ||||||||
from . import _decorators | ||||||||
from helion.language.tile_proxy import Tile | ||||||||
|
||||||||
|
@@ -151,6 +152,23 @@ def _check_matching(a: object, b: object) -> None: | |||||||
) | ||||||||
|
||||||||
|
||||||||
def _is_constexpr_int(a: object) -> bool:
    """Report whether ``a`` is a plain Python ``int``.

    This is the current stand-in for "the argument is specialized": only
    values that are already ``int`` instances count.
    """
    # TODO(joydddd): render SymInt backed by Int as constexpr.
    # Today a specialized constexpr is first assigned to a dynamic variable
    # and then used through that variable, but args to static_range must be
    # constexpr. For example
    #     hl.specialize(x.size(0))
    #     for i in hl.grid(x.size(0))
    # is rendered as
    #     symbol_0 = 64
    #     for i in tl.static_range(symbol_0):
    # Once that is resolved, a torch.SymInt whose _sympy_() is a
    # sympy.Integer could be accepted here as well.
    is_plain_int = isinstance(a, int)
    return is_plain_int
|
||||||||
|
||||||||
def _normalize_begin_end( | ||||||||
begin_or_end: TypeInfo, | ||||||||
end_or_none: TypeInfo | None, | ||||||||
|
@@ -225,6 +243,10 @@ def _( | |||||||
[x.block_id for x in results], | ||||||||
is_tile=True, | ||||||||
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin), | ||||||||
is_static=all( | ||||||||
_is_constexpr_int(x) or x is None | ||||||||
for x in (*proxy_begin, *proxy_end, *proxy_block_size) | ||||||||
), | ||||||||
) | ||||||||
if unpack: | ||||||||
(result,) = results | ||||||||
|
@@ -234,7 +256,11 @@ def _( | |||||||
|
||||||||
|
||||||||
def _add_config_choices( | ||||||||
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False | ||||||||
block_ids: list[int], | ||||||||
*, | ||||||||
is_tile: bool = False, | ||||||||
has_begin: bool = False, | ||||||||
is_static: bool = False, | ||||||||
) -> None: | ||||||||
config_spec = CompileEnvironment.current().config_spec | ||||||||
|
||||||||
|
@@ -253,6 +279,7 @@ def _add_config_choices( | |||||||
else: | ||||||||
params = inspect.signature(triton.language.range).parameters | ||||||||
for block_id in block_ids: | ||||||||
config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static)) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
if "loop_unroll_factor" in params: | ||||||||
config_spec.range_unroll_factors.append( | ||||||||
RangeUnrollFactorSpec([block_id]) | ||||||||
|
@@ -419,6 +446,10 @@ def _( | |||||||
[x.block_id for x in results], | ||||||||
is_tile=False, | ||||||||
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin), | ||||||||
is_static=all( | ||||||||
_is_constexpr_int(x) or x is None | ||||||||
for x in (*proxy_begin, *proxy_end, *proxy_step) | ||||||||
), | ||||||||
) | ||||||||
if unpack: | ||||||||
(result,) = results | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should mention this is only legal for loops with static bounds