Add static_range #235
base: main
stack-info: PR: #235, branch: joydddd/stack/9
stack-info: PR: pytorch-labs#235, branch: joydddd/stack/9
@@ -209,6 +209,10 @@ Contains one entry per loop dimension, controlling the `flatten`
parameter for `tl.range()` calls. `True` sets `flatten=True`,
`False` sets `flatten=False`, and `None` omits the parameter.

* **static\_ranges** (`list[bool]`):
Contains one entry per loop dimension, controlling whether to use
Should mention that this is only legal for loops with static bounds.
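Triton's `tl.static_range` unrolls the loop at compile time and requires `tl.constexpr` bounds, which is why a `True` entry only makes sense for loops whose bounds are static. A minimal sketch of the check the docs could describe, assuming a hypothetical representation in which each loop dimension's `(begin, end)` bounds are ints when known at compile time and `None` when only known at runtime; `validate_static_ranges` is an illustrative name, not part of the project's API.

```python
from typing import Optional, Sequence


def validate_static_ranges(
    static_ranges: Sequence[bool],
    loop_bounds: Sequence[tuple[Optional[int], Optional[int]]],
) -> None:
    """Raise if static_range is requested for a loop with dynamic bounds.

    Hypothetical helper: each loop_bounds entry is (begin, end), where None
    marks a bound that is only known at runtime.
    """
    if len(static_ranges) != len(loop_bounds):
        raise ValueError(
            f"expected {len(loop_bounds)} static_ranges entries, "
            f"got {len(static_ranges)}"
        )
    for dim, (use_static, (begin, end)) in enumerate(zip(static_ranges, loop_bounds)):
        # tl.static_range needs compile-time constant bounds to unroll.
        if use_static and (begin is None or end is None):
            raise ValueError(
                f"static_ranges[{dim}]=True requires static bounds, "
                f"got ({begin}, {end})"
            )
```

For example, `validate_static_ranges([True, False], [(0, 128), (0, None)])` passes, while requesting `True` for the second loop would raise.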
@@ -610,11 +620,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
)

range_extra = self.get_tl_range_kwargs(state, block_idx)
range_fn = self.get_range_fn_name(state, self.block_ids[0])
Should this be block_idx? Or do we always want to match the first loop in a multidimensional loop nest?
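To make the question concrete: choosing the range function from `self.block_ids[0]` gives every loop in a nest the first loop's setting, while keying on the current loop's block id (assuming that is what `block_idx` holds) lets each dimension choose independently. The sketch below models this with a plain dict from block id to its `static_ranges` flag; `range_fn_name` is a hypothetical stand-in for `get_range_fn_name`.

```python
def range_fn_name(static_ranges: dict[int, bool], block_id: int) -> str:
    """Pick the Triton loop iterator for one loop dimension.

    Hypothetical: static_ranges maps a block id to its config flag; block ids
    with no entry (e.g. the choice was disabled) fall back to tl.range.
    """
    return "tl.static_range" if static_ranges.get(block_id, False) else "tl.range"


# A two-level loop nest where only the inner loop (block id 1) is static.
static_ranges = {0: False, 1: True}
block_ids = [0, 1]

# Keyed on each loop's own block id.
per_dimension = [range_fn_name(static_ranges, b) for b in block_ids]
# Keyed on the first loop for every dimension, as in the current diff.
first_loop_only = [range_fn_name(static_ranges, block_ids[0]) for _ in block_ids]
```

Here `per_dimension` is `["tl.range", "tl.static_range"]`, whereas `first_loop_only` silently drops the inner loop's choice.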
for_node = create(
ast.For,
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
iter=expr_from_string(
- f"tl.range(begin, end, {block_size_var}{range_extra})",
+ f"{range_fn}(begin, end, {block_size_var}{range_extra})",
Maybe we should refactor the {range_extra} helper to generate the entire expression (take in begin/end/step as args). That would put all the loop generation logic in one place.
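A minimal sketch of the suggested refactor, in which one helper builds the whole iterator expression from begin/end/step. Everything here is hypothetical: `range_loop_expr` is an illustrative name, `extra_kwargs` stands in for whatever `get_tl_range_kwargs` returns, and `_BLOCK_SIZE_0` in the usage is a made-up variable name. The helper drops the extra keyword arguments for `tl.static_range`, which in current Triton releases accepts only start, end, and step.

```python
import ast
from typing import Optional


def range_loop_expr(
    range_fn: str,
    begin: str,
    end: str,
    step: str,
    extra_kwargs: Optional[dict[str, object]] = None,
) -> str:
    """Build the full iterator expression for a device loop.

    Hypothetical helper: extra_kwargs holds tl.range options such as
    num_stages or flatten; they are only emitted for tl.range.
    """
    args = [begin, end, step]
    if range_fn == "tl.range" and extra_kwargs:
        args += [f"{key}={value!r}" for key, value in extra_kwargs.items()]
    expr = f"{range_fn}({', '.join(args)})"
    ast.parse(expr, mode="eval")  # fail early if the generated code is malformed
    return expr
```

Keeping the `tl.range`-only options and the choice of iterator in one function means codegen cannot accidentally pass `tl.range` keywords to `tl.static_range`.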
):
if config[name]: # The config is non empty
# pyre-ignore[16]
config[name][mapping.block_id_to_index(block_id)] = (
The block_id may not exist in every one of these, since you can disable choices via the config_spec API. Handle the case of it missing.
Maybe add a mapping.set_config_to_default(config[name], block_id) helper for this.
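One way the suggested helper could look, as a sketch under assumptions: `BlockIdSequence` is a hypothetical stand-in for the per-block spec lists in `config_spec`, and "set to default" is assumed to mean resetting that block's entry to the spec's default value. The key behavior is that a block id whose choice was disabled is skipped rather than raising from `block_id_to_index`.

```python
from dataclasses import dataclass


@dataclass
class BlockIdSequence:
    """Hypothetical stand-in for a per-block config spec list.

    Holds the block ids that still expose a choice (others may have been
    disabled) plus the default value used when resetting an entry.
    """

    block_ids: list[int]
    default: object = None

    def block_id_to_index(self, block_id: int) -> int:
        return self.block_ids.index(block_id)

    def set_config_to_default(self, values: list, block_id: int) -> None:
        """Reset the entry for block_id to the default, if it is tracked."""
        if block_id not in self.block_ids:
            return  # choice was disabled via the config_spec API; nothing to reset
        index = self.block_id_to_index(block_id)
        if index < len(values):
            values[index] = self.default
```

A caller can then invoke the helper for every block id without first checking which choices are still enabled.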
]:
config[name] = mapping._normalize(
name, config.get(name, ()), flatten=flatten
)

for block_id in self.static_ranges.valid_block_ids():
use_static_range = self.static_ranges.config_get(
config.get("static_ranges", ()), # pyre-ignore[6]
Add a property so you don't need the pyre-ignore here:
Suggested change:
- config.get("static_ranges", ()), # pyre-ignore[6]
+ config.static_ranges
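A sketch of what such a typed property could look like. The `Config` class below is a hypothetical simplification that stores values in a dict; the real config class may be structured differently. Because the property declares `list[bool]`, call sites no longer pass an untyped `config.get(...)` result and need no type-checker suppression.

```python
from typing import Any


class Config:
    """Hypothetical config wrapper; the project's real class may differ."""

    def __init__(self, **values: Any) -> None:
        self._values: dict[str, Any] = values

    def get(self, key: str, default: Any = None) -> Any:
        return self._values.get(key, default)

    @property
    def static_ranges(self) -> list[bool]:
        """Typed accessor; returns an empty list when the key is unset."""
        return list(self.get("static_ranges", ()))
```

Usage: `Config(static_ranges=[True, False]).static_ranges` returns `[True, False]`, and a config without the key yields `[]`.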
is_static: bool,
) -> None:
super().__init__([block_id])
self.is_static = is_static
If is_static is False, don't create the StaticRangeSpec. We shouldn't add config values with only a single choice.
@@ -253,6 +279,7 @@ def _add_config_choices(
else:
params = inspect.signature(triton.language.range).parameters
for block_id in block_ids:
config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static))
Suggested change:
- config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static))
+ if is_static:
+     config_spec.static_ranges.append(StaticRangeSpec(block_id))
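The reasoning behind the suggestion can be sketched with simplified, hypothetical classes: `StaticRangeSpec`, `ConfigSpec`, and `add_static_range_choices` below only model the idea and are not the project's definitions. The sketch assumes a spec for a static loop offers both `False` (use `tl.range`) and `True` (use `tl.static_range`). For a dynamic loop the only legal value is `False`, a single-choice config that adds nothing to the autotuning search space, so no spec is created at all.

```python
from dataclasses import dataclass, field


@dataclass
class StaticRangeSpec:
    """Hypothetical simplified spec: one boolean choice for one block id."""

    block_id: int
    choices: tuple[bool, ...] = (False, True)


@dataclass
class ConfigSpec:
    static_ranges: list[StaticRangeSpec] = field(default_factory=list)


def add_static_range_choices(
    config_spec: ConfigSpec, block_ids: list[int], is_static: bool
) -> None:
    # Only loops with static bounds can use tl.static_range. For dynamic loops
    # the sole legal value is False, so a spec would be a single-choice config.
    if is_static:
        for block_id in block_ids:
            config_spec.static_ranges.append(StaticRangeSpec(block_id))
```

With this shape the spec no longer needs an `is_static` field, because its existence already implies the loop is static.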
Stacked PRs:
Add static_range