Add the core properties to Config object by drisspg · Pull Request #49 · pytorch-labs/helion · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add the core properties to Config object #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.9
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ refers to hydrogen-3.

[Triton]: https://github.com/triton-lang/triton

> ⚠️ **Early Development Warning**
> ⚠️ **Early Development Warning**
> Helion is currently in an experimental stage. You should expect bugs, incomplete features, and APIs that may change in future versions. Feedback and bug reports are welcome and appreciated!

## Example
Expand All @@ -27,13 +27,13 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
m, k = x.size()
k, n = y.size()
out = torch.empty([m, n], dtype=x.dtype, device=x.device)

for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
out[tile_m, tile_n] = acc

return out
```

Expand Down Expand Up @@ -204,7 +204,8 @@ Alternatively, you may install from source for development purposes:
```bash
git clone https://github.com/pytorch-labs/helion.git
cd helion
python setup.py develop
# To install in editable w/ required dev packages
pip install -e .'[dev]'
````
This installs Helion in "editable" mode so that changes to the source
code take effect without needing to reinstall.
Expand Down
2 changes: 1 addition & 1 deletion helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def code_and_output(
**kwargs: object,
) -> tuple[str, object]:
if kwargs:
config = Config(**kwargs)
config = Config(**kwargs) # pyre-ignore[6]
elif fn.configs:
(config,) = fn.configs
else:
Expand Down
2 changes: 1 addition & 1 deletion helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
use_yz_grid = fn(BooleanFragment())
if config.get("l2_grouping", 1) == 1 and isinstance(block_sizes[0], list):
config["use_yz_grid"] = use_yz_grid
return helion.Config(config)
return helion.Config(**config) # pyre-ignore[6]


class BlockSizeSpec:
Expand Down
57 changes: 47 additions & 10 deletions helion/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,50 @@
class Config(Mapping[str, object]):
config: dict[str, object]

def __init__(self, config: object = None, **kwargs: object) -> None:
if config is not None:
assert not kwargs
assert isinstance(config, (dict, Config))
self.config = {**config}
else:
self.config = kwargs
def __init__(
self,
*,
# Core properties
block_sizes: list[int | list[int]] | None = None,
loop_orders: list[list[int]] | None = None,
reduction_loops: list[int | None] | None = None,
num_warps: int | None = None,
num_stages: int | None = None,
l2_grouping: int | None = None,
use_yz_grid: bool | None = None,
indexing: IndexingLiteral | None = None,
# For user-defined properties
**kwargs: object,
) -> None:
"""
Initialize a Config object.

Args:
block_sizes: Controls tile sizes for hl.tile invocations.
loop_orders: Permutes iteration order of tiles.
reduction_loops: Configures reduction loop behavior.
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
l2_grouping: Reorders program IDs for L2 cache locality.
use_yz_grid: Whether to use yz grid dimensions.
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
**kwargs: Additional user-defined configuration parameters.
"""
self.config = {}
core_props = {
"block_sizes": block_sizes,
"loop_orders": loop_orders,
"reduction_loops": reduction_loops,
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
"l2_grouping": l2_grouping,
"use_yz_grid": use_yz_grid,
}
for key, value in core_props.items():
if value is not None:
self.config[key] = value
self.config.update(kwargs)

def __getitem__(self, key: str) -> object:
return self.config[key]
Expand Down Expand Up @@ -56,7 +93,7 @@ def to_json(self) -> str:
def from_json(cls, json_str: str) -> Config:
"""Create a Config object from a JSON string."""
config_dict = json.loads(json_str)
return cls(config_dict)
return cls(**config_dict) # Changed to use dictionary unpacking

def save(self, path: str | Path) -> None:
"""Save the config to a JSON file."""
Expand Down Expand Up @@ -92,12 +129,12 @@ def l2_grouping(self) -> int:
return cast("int", self.config.get("l2_grouping", 1))

@property
def use_yz_grid(self) -> int:
def use_yz_grid(self) -> bool:
return cast("bool", self.config.get("use_yz_grid", False))

@property
def indexing(self) -> IndexingLiteral:
return cast("IndexingLiteral", self.config.get("indexing", "pointer"))
return self.config.get("indexing", "pointer") # type: ignore


def _list_to_tuple(x: object) -> object:
Expand Down
12 changes: 8 additions & 4 deletions helion/runtime/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def __init__(
self.fn = fn
self.signature: inspect.Signature = inspect.signature(fn)
self.settings: Settings = settings or Settings.default()
self.configs: list[Config] = [*map(Config, configs or ())]
self.configs: list[Config] = [
Config(**c) if isinstance(c, dict) else c for c in configs or []
]
# pyre-fixme[11]: BoundKernel undefined?
self._bound_kernels: dict[Hashable, BoundKernel] = {}
if any(
Expand Down Expand Up @@ -295,7 +297,9 @@ def to_triton_code(self, config: ConfigLike) -> str:
:rtype: str
"""
with self.env:
config = Config(config)
if not isinstance(config, Config):
# pyre-ignore[6]
config = Config(**config)
self.env.config_spec.normalize(config)
root = generate_ast(self.host_fn, config)
return get_needed_imports(root) + unparse(root)
Expand All @@ -310,7 +314,7 @@ def compile_config(self, config: ConfigLike) -> CompiledConfig:
:rtype: Callable[..., object]
"""
if not isinstance(config, Config):
config = Config(config)
config = Config(**config) # pyre-ignore[6]
if (rv := self._compile_cache.get(config)) is not None:
return rv
triton_code = self.to_triton_code(config)
Expand Down Expand Up @@ -375,7 +379,7 @@ def set_config(self, config: ConfigLike) -> None:
:type config: ConfigLike
"""
if not isinstance(config, Config):
config = Config(config)
config = Config(**config) # pyre-ignore[6]
self._run = self.compile_config(config)

def __call__(self, *args: object) -> object:
Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ dependencies = [
"typing_extensions>=4.0.0",
]

[project.optional-dependencies]
dev = [
"expecttest",
"pytest",
"pre-commit"
]

[project.urls]
Homepage = "https://github.com/pytorch-labs/helion"
Issues = "https://github.com/pytorch-labs/helion/issues"
Expand Down Expand Up @@ -67,4 +74,3 @@ force-sort-within-sections = true

[tool.setuptools]
license-files = ["LICENSE"]

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
expecttest
pytest
typing_extensions
pre-commit
Loading
< 31C2 /diff-file-filter>
0