Add static_range by joydddd · Pull Request #235 · pytorch-labs/helion · GitHub

Add static_range #235


Open · wants to merge 1 commit into base: main


@joydddd (Contributor) commented Jul 3, 2025

joydddd added a commit that referenced this pull request Jul 3, 2025
stack-info: PR: #235, branch: joydddd/stack/9
@joydddd force-pushed the joydddd/stack/9 branch from f8caad0 to 8668247, July 3, 2025 00:10
@facebook-github-bot added the CLA Signed label Jul 3, 2025
@joydddd force-pushed the joydddd/stack/9 branch from 8668247 to 3975adc, July 3, 2025 00:13
joydddd added a commit to joydddd/helion that referenced this pull request Jul 3, 2025
stack-info: PR: pytorch-labs#235, branch: joydddd/stack/9
@joydddd force-pushed the joydddd/stack/9 branch from 3975adc to 4e3c357, July 3, 2025 18:55
@@ -209,6 +209,10 @@ Contains one entry per loop dimension, controlling the `flatten`
parameter for `tl.range()` calls. `True` sets `flatten=True`,
`False` sets `flatten=False`, and `None` omits the parameter.
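The `True`/`False`/`None` mapping described above can be sketched as a small string-building helper. This is a minimal illustration that assumes the extra keywords are rendered as source text appended after the step argument; `render_flatten_kwarg` and `render_range_call` are hypothetical names, not Helion's `get_tl_range_kwargs`.

```python
from typing import Optional


def render_flatten_kwarg(flatten: Optional[bool]) -> str:
    """Return the text appended after the step argument of tl.range().

    Hypothetical helper: True -> ", flatten=True", False -> ", flatten=False",
    None -> "" (the parameter is omitted entirely).
    """
    if flatten is None:
        return ""
    return f", flatten={flatten}"


def render_range_call(begin: str, end: str, step: str, flatten: Optional[bool]) -> str:
    # Assemble a tl.range(...) call as generated source text.
    return f"tl.range({begin}, {end}, {step}{render_flatten_kwarg(flatten)})"


# One config entry per loop dimension; each value yields one loop header.
for value in [True, False, None]:
    print(render_range_call("begin", "end", "_BLOCK_SIZE_0", value))
```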

* **static\_ranges** (`list[bool]`):
Contains one entry per loop dimension, controlling whether to use
Contributor:

Should mention this is only legal for loops with static bounds
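Background for this comment: `tl.static_range` asks the Triton compiler to unroll the loop, and Triton requires its arguments to be compile-time constants (`tl.constexpr`). The pure-Python sketch below imitates unrolling on source text to show why: a fixed list of iterations exists only when `begin`, `end`, and `step` are concrete integers. `unroll_static_range` is a hypothetical illustration, not Triton's implementation, and strings stand in for runtime (symbolic) bounds.

```python
from typing import List, Union

# An int stands for a compile-time constant; a str names a runtime value.
Bound = Union[int, str]


def unroll_static_range(begin: Bound, end: Bound, step: Bound, body: str) -> List[str]:
    """Expand `for i in static_range(begin, end, step): body` into straight-line code.

    Hypothetical illustration, not Triton code: unrolling needs a known trip
    count, so every bound must be a concrete int. Symbolic bounds are rejected,
    which is why a static range only applies to loops with static bounds.
    """
    for name, value in (("begin", begin), ("end", end), ("step", step)):
        if not isinstance(value, int):
            raise TypeError(f"{name}={value!r} is not a compile-time constant")
    # One copy of the body per iteration, with {i} replaced by the loop value.
    return [body.format(i=i) for i in range(begin, end, step)]


print(unroll_static_range(0, 64, 32, "acc += tl.load(ptr + {i})"))
# ['acc += tl.load(ptr + 0)', 'acc += tl.load(ptr + 32)']

try:
    unroll_static_range(0, "n", 32, "acc += tl.load(ptr + {i})")
except TypeError as err:
    print(err)  # end='n' is not a compile-time constant
```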

@@ -610,11 +620,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
)

range_extra = self.get_tl_range_kwargs(state, block_idx)
range_fn = self.get_range_fn_name(state, self.block_ids[0])
Contributor:

Should this be block_idx? Or do we always want to match the first loop in a multidimensional loop nest?
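The question can be illustrated with a hypothetical lookup, assuming each dimension of the loop nest emits its own `for` loop and `static_ranges` maps block ids to booleans. `range_fn_name` and `loop_headers` are stand-ins, not Helion's `get_range_fn_name`: looking up `block_ids[0]` applies the first dimension's choice to every loop, while a per-dimension lookup lets an inner static-bound loop use `tl.static_range` on its own.

```python
from typing import Dict, List


def range_fn_name(static_ranges: Dict[int, bool], block_id: int) -> str:
    """Hypothetical stand-in: choose the Triton range function for one loop."""
    return "tl.static_range" if static_ranges.get(block_id, False) else "tl.range"


def loop_headers(block_ids: List[int], static_ranges: Dict[int, bool], per_dimension: bool) -> List[str]:
    headers = []
    for block_id in block_ids:
        # per_dimension=False mimics looking up block_ids[0] for every loop.
        lookup_id = block_id if per_dimension else block_ids[0]
        fn = range_fn_name(static_ranges, lookup_id)
        headers.append(f"for offset_{block_id} in {fn}(begin, end, step):")
    return headers


# Only the inner loop (block id 1) has static bounds.
choices = {0: False, 1: True}
print(loop_headers([0, 1], choices, per_dimension=False))
print(loop_headers([0, 1], choices, per_dimension=True))
```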

for_node = create(
ast.For,
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
iter=expr_from_string(
- f"tl.range(begin, end, {block_size_var}{range_extra})",
+ f"{range_fn}(begin, end, {block_size_var}{range_extra})",
Contributor:

Maybe we should refactor the {range_extra} helper to generate the entire expression (take in begin/end/step as args). That would put all the loop generation logic in one place.
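A minimal sketch of the suggested refactor, using a hypothetical `build_range_expr` helper (not part of Helion) that owns the whole iterator expression and takes `begin`/`end`/`step` as arguments. Keeping this logic in one place also makes it easy to drop `tl.range`-only tuning keywords such as `num_stages` or `flatten` when emitting `tl.static_range`, which takes only start, end, and step.

```python
from typing import Dict, Optional


def build_range_expr(
    begin: str,
    end: str,
    step: str,
    use_static_range: bool,
    range_kwargs: Optional[Dict[str, object]] = None,
) -> str:
    """Build the full iterator expression for a generated device loop.

    Hypothetical helper: tl.static_range only accepts (begin, end, step),
    so tl.range-only tuning kwargs are dropped when a static range is used.
    """
    args = [begin, end, step]
    if use_static_range:
        return f"tl.static_range({', '.join(args)})"
    for key, value in (range_kwargs or {}).items():
        if value is not None:  # None means "omit this parameter"
            args.append(f"{key}={value!r}")
    return f"tl.range({', '.join(args)})"


print(build_range_expr("begin", "end", "_BLOCK_SIZE_1", False, {"num_stages": 2, "flatten": None}))
print(build_range_expr("begin", "end", "_BLOCK_SIZE_1", True, {"num_stages": 2}))
```

With this shape, the codegen site would pass the result straight to `expr_from_string` instead of concatenating `range_fn` and `range_extra` itself.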

):
if config[name]: # The config is non empty
# pyre-ignore[16]
config[name][mapping.block_id_to_index(block_id)] = (
Contributor:

The block_id may not exist in every one of these, since you can disable choices via the config_spec API. Handle the case of it missing.

Maybe add a mapping.set_config_to_default(config[name], block_id) helper for this.
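One way to handle the missing case, sketched with a hypothetical `BlockIdMapping` class and `StaticRangeChoice` spec (stand-ins, not Helion's actual types or signatures): `block_id_to_index` returns `None` for a block id that was disabled through the config_spec API, and `set_config_to_default` then leaves the config list untouched instead of raising.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StaticRangeChoice:
    """Hypothetical per-block spec carrying a default value."""

    block_id: int
    default: bool = False


class BlockIdMapping:
    """Hypothetical stand-in for a list of per-block specs in a config spec."""

    def __init__(self, specs: List[StaticRangeChoice]) -> None:
        self._specs = specs

    def block_id_to_index(self, block_id: int) -> Optional[int]:
        # Return None instead of raising when the block id was disabled.
        for index, spec in enumerate(self._specs):
            if spec.block_id == block_id:
                return index
        return None

    def set_config_to_default(self, values: list, block_id: int) -> None:
        """Reset the entry for block_id to its default; no-op if it is missing."""
        index = self.block_id_to_index(block_id)
        if index is None or index >= len(values):
            return
        values[index] = self._specs[index].default


mapping = BlockIdMapping([StaticRangeChoice(0), StaticRangeChoice(2)])
values = [True, True]
mapping.set_config_to_default(values, 2)  # resets the entry for block 2
mapping.set_config_to_default(values, 1)  # block 1 was disabled: no change
print(values)
```

Returning `None` rather than raising keeps each call site a single unconditional line.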

]:
config[name] = mapping._normalize(
name, config.get(name, ()), flatten=flatten
)

for block_id in self.static_ranges.valid_block_ids():
use_static_range = self.static_ranges.config_get(
config.get("static_ranges", ()), # pyre-ignore[6]
Contributor:

Add a property so you don't need the pyre-ignore here:

Suggested change
- config.get("static_ranges", ()), # pyre-ignore[6]
+ config.static_ranges,
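A minimal sketch of the suggested property, assuming a hypothetical dict-backed `Config` class (Helion's real `Config` may differ). The single `cast` inside the property gives call sites a typed `list[bool]`, so they no longer need `# pyre-ignore[6]`; the empty-list default plays the role of the `()` default in the original call.

```python
from typing import Dict, List, cast


class Config:
    """Hypothetical config wrapper storing autotuner choices in a dict."""

    def __init__(self, **values: object) -> None:
        self.config: Dict[str, object] = dict(values)

    def get(self, key: str, default: object = None) -> object:
        # Untyped access: callers of this need a cast or a type-ignore.
        return self.config.get(key, default)

    @property
    def static_ranges(self) -> List[bool]:
        # Centralize the cast here so callers get a typed list[bool].
        return cast(List[bool], self.config.get("static_ranges", []))


print(Config(static_ranges=[True, False]).static_ranges)
print(Config().static_ranges)
```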

is_static: bool,
) -> None:
super().__init__([block_id])
self.is_static = is_static
Contributor:

If is_static is False, don't create the StaticRangeSpec.

We shouldn't add config values with only a single choice.

@@ -253,6 +279,7 @@ def _add_config_choices(
else:
params = inspect.signature(triton.language.range).parameters
for block_id in block_ids:
config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static))
Contributor:

Suggested change
- config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static))
+ if is_static:
+     config_spec.static_ranges.append(StaticRangeSpec(block_id))
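The reasoning behind the suggestion, that a config value with only one possible choice gives the autotuner nothing to search, can be sketched with hypothetical stand-ins for `StaticRangeSpec` and the loop metadata (`LoopInfo`, `add_static_range_choices`, and the `choices` method are assumptions for illustration, not Helion's API). Only static-bound loops get a spec, and each spec contributes a real two-way choice.

```python
from dataclasses import dataclass
from itertools import product
from typing import List, Tuple


@dataclass
class StaticRangeSpec:
    """Hypothetical spec: created only for loops whose bounds are static."""

    block_id: int

    def choices(self) -> Tuple[bool, ...]:
        # A static-bound loop may use either tl.static_range or tl.range.
        return (True, False)


@dataclass
class LoopInfo:
    block_id: int
    is_static: bool  # True when begin/end/step are compile-time constants


def add_static_range_choices(loops: List[LoopInfo]) -> List[StaticRangeSpec]:
    specs = []
    for loop in loops:
        if loop.is_static:  # dynamic loops get no spec, so no one-choice entry
            specs.append(StaticRangeSpec(loop.block_id))
    return specs


specs = add_static_range_choices([LoopInfo(0, False), LoopInfo(1, True)])
search_space = list(product(*(spec.choices() for spec in specs)))
print([spec.block_id for spec in specs])  # [1]
print(search_space)  # [(True,), (False,)]
```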

Labels: CLA Signed
3 participants