Support namedtuple and dataclass by oulgen · Pull Request #41 · pytorch-labs/helion · GitHub
Support namedtuple and dataclass #41

Closed · wants to merge 4 commits
17 changes: 16 additions & 1 deletion helion/_compiler/compile_environment.py
@@ -164,11 +164,26 @@ def to_fake(self, obj: object, origin: Origin) -> object:
return obj.value
if isinstance(obj, list):
return [self.to_fake(e, origin) for e in obj]
if isinstance(obj, tuple) and hasattr(obj, "_fields"):
return type(obj)(
**{ # pyre-ignore[6]
k: self.to_fake(e, origin)
for k, e in obj._asdict().items() # pyre-ignore[16]
}
)
if isinstance(obj, tuple):
return tuple(self.to_fake(e, origin) for e in obj)
if isinstance(obj, dict):
return {k: self.to_fake(e, origin) for k, e in obj.items()}
# TODO(jansel): support other types of args
if dataclasses.is_dataclass(obj):
return dataclasses.replace(
obj,
**{
k: self.to_fake(getattr(obj, k), origin)
for k in obj.__dataclass_fields__ # pyre-ignore[16]
},
)

raise TypeError(f"unsupported argument type {type(obj)} ({origin})")

def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
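The two new branches in `to_fake` lean on standard-library round-trip idioms: a namedtuple is rebuilt as `type(obj)(**obj._asdict())` and a dataclass via `dataclasses.replace(obj, **fields)`, so the fake object keeps the caller's concrete type. A minimal self-contained sketch of the same pattern (`fake` below is a hypothetical stand-in for `to_fake`, not helion code):

import dataclasses
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])

@dataclasses.dataclass
class Pair:
    a: int
    b: int

def fake(obj):
    # Namedtuples expose _fields/_asdict; rebuild the same type field by field.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return type(obj)(**{k: fake(v) for k, v in obj._asdict().items()})
    # dataclasses.replace returns a new instance of the same dataclass type.
    if dataclasses.is_dataclass(obj):
        return dataclasses.replace(
            obj, **{k: fake(getattr(obj, k)) for k in obj.__dataclass_fields__}
        )
    return obj  # leaf values pass through unchanged in this sketch

assert type(fake(Point(1, 2))) is Point
assert type(fake(Pair(3, 4))) is Pair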
35 changes: 35 additions & 0 deletions helion/_compiler/type_propagation.py
@@ -225,6 +225,36 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo:
zip(value.keys(), cls._unpack_example(items, origin), strict=False)
),
)
if isinstance(value, tuple) and hasattr(value, "_asdict"):
# namedtuple
return ClassType(
origin,
dict(
zip(
value._fields, # pyre-ignore[16]
cls._unpack_example(
value._asdict().items(), # pyre-ignore[16]
origin,
),
strict=False,
)
),
)
if dataclasses.is_dataclass(value):
keys = value.__dataclass_fields__.keys() # pyre-ignore[16]
return ClassType(
origin,
dict(
zip(
keys,
cls._unpack_example(
tuple((k, getattr(value, k)) for k in keys),
origin,
),
strict=False,
)
),
)
return UnknownType(
debug_msg=f"{type(value).__name__} is not supported",
origin=origin,
@@ -1122,6 +1152,11 @@ def tree_map(self, fn: Callable[[TypeInfo], object]) -> dict[str | int, object]:
return {k: v.tree_map(fn) for k, v in self.element_types.items()}


class ClassType(DictType):
def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
return self.element_types[attr]


class SliceType(CollectionType):
element_types: slice

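Both new `from_example` branches reduce to the same step: produce (field name, example value) pairs in declaration order and feed them through `_unpack_example`, so a `ClassType` maps field names to element types and `propagate_attribute` becomes a plain dictionary lookup. A standalone sketch of that pairing, assuming nothing about helion internals (`field_items` is a hypothetical helper):

import dataclasses
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])

@dataclasses.dataclass
class Box:
    lo: float
    hi: float

def field_items(value):
    # Namedtuple: _asdict() yields fields in declaration order.
    if isinstance(value, tuple) and hasattr(value, "_asdict"):
        return list(value._asdict().items())
    # Dataclass: __dataclass_fields__ preserves declaration order too.
    if dataclasses.is_dataclass(value):
        keys = value.__dataclass_fields__.keys()
        return [(k, getattr(value, k)) for k in keys]
    raise TypeError(f"unsupported example type {type(value)}")

assert field_items(Point(1, 2.0)) == [("x", 1), ("y", 2.0)]
assert field_items(Box(0.0, 1.0)) == [("lo", 0.0), ("hi", 1.0)]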
31 changes: 24 additions & 7 deletions helion/runtime/kernel.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Callable
import dataclasses
import functools
import inspect
import logging
@@ -140,9 +141,15 @@ def _specialization_key(self, obj: object) -> Hashable:
try:
extractor = _specialization_extractors[type(obj)]
except KeyError:
raise TypeError(
f"unsupported argument type: {type(obj).__name__}"
) from None
if isinstance(obj, tuple) and hasattr(obj, "_fields"):
# this is a namedtuple
extractor = _specialization_extractors["namedtuple"]
elif dataclasses.is_dataclass(obj):
extractor = _specialization_extractors["dataclass"]
else:
raise TypeError(
f"unsupported argument type: {type(obj).__name__}"
) from None
return extractor(self, obj)

def normalize_args(self, *args: object, **kwargs: object) -> tuple[object, ...]:
@@ -462,6 +469,14 @@ def _sequence_key(fn: Kernel, obj: Sequence) -> Hashable:
return type(obj), tuple([fn._specialization_key(item) for item in obj])


def _mapping_key(
fn: Kernel, obj: dict[str | int, object], real_type: type[object]
) -> Hashable:
return real_type, tuple(
sorted((k, fn._specialization_key(v)) for k, v in obj.items())
)


def _number_key(fn: Kernel, n: float | bool) -> object:
return type(n)

@@ -475,7 +490,9 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
return obj.__code__


_specialization_extractors: dict[type[object], Callable[[Kernel, object], Hashable]] = {
_specialization_extractors: dict[
type[object] | str, Callable[[Kernel, object], Hashable]
] = {
torch.Tensor: _tensor_key,
torch.nn.Parameter: _tensor_key,
torch.dtype: lambda fn, x: x,
@@ -486,9 +503,9 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
str: lambda fn, x: x,
list: _sequence_key,
tuple: _sequence_key,
dict: lambda fn, x: tuple(
sorted((k, fn._specialization_key(v)) for k, v in x.items())
),
dict: lambda fn, x: _mapping_key(fn, x, type(x)),
"namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)),
"dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)),
types.FunctionType: _function_key,
types.BuiltinFunctionType: lambda fn, x: x,
ConstExpr: lambda fn, x: x.value,
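With these changes, the specialization key for a namedtuple or dataclass argument is the concrete class plus a sorted tuple of per-field keys, so calls with a different class or different field types compile separately. A rough standalone sketch of that keying, with `leaf_key` as a hypothetical stand-in for the per-value logic in `Kernel._specialization_key`:

import dataclasses
from collections import namedtuple

def leaf_key(value):
    # Stand-in: the real extractor dispatches on the value's type.
    return type(value)

def mapping_key(mapping, real_type):
    # Same shape as _mapping_key above: (concrete class, sorted field keys).
    return real_type, tuple(sorted((k, leaf_key(v)) for k, v in mapping.items()))

Point = namedtuple("Point", ["x", "y"])

@dataclasses.dataclass
class Pair:
    a: int
    b: float

k1 = mapping_key(Point(1, 2.0)._asdict(), Point)
k2 = mapping_key(dataclasses.asdict(Pair(1, 2.0)), Pair)
assert k1 != k2  # different classes produce different cache keys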
69 changes: 48 additions & 21 deletions test/test_misc.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections import namedtuple
from dataclasses import dataclass
import unittest

from expecttest import TestCase
@@ -58,20 +60,33 @@ def add3(x, y):

def test_inputs(self):
@helion.kernel
def kernel(a_list, b_dict, b_tuple):
def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
a0, a1 = a_list
b0 = b_dict["b0"]
(b1,) = b_tuple
c0, c1 = torch.empty_like(a0), torch.empty_like(a1)
c0, c1 = c_named_tuple.x, c_named_tuple.y
d0, d1 = d_dataclass.x, d_dataclass.y

o0, o1 = torch.empty_like(a0), torch.empty_like(a1)
for tile in hl.tile(a0.size()):
c0[tile] = a0[tile] + b0[tile]
c1[tile] = a1[tile] + b1[tile]
return [c0, c1]
o0[tile] = a0[tile] + b0[tile] + c0[tile] + d0[tile]
o1[tile] = a1[tile] + b1[tile] + c1[tile] + d1[tile]
return [o0, o1]

x = torch.randn(4, device=DEVICE)
code, result = code_and_output(kernel, ([x, x], {"b0": x}, (x,)))
torch.testing.assert_close(result[0], 2 * x)
torch.testing.assert_close(result[1], 2 * x)
x = torch.ones(4, device=DEVICE)
Point = namedtuple("Point", ["x", "y"]) # noqa: PYI024
p = Point(x, x)

@dataclass(frozen=True)
class Point2:
x: torch.Tensor
y: torch.Tensor

p2 = Point2(x, x)

code, result = code_and_output(kernel, ([x, x], {"b0": x}, (x,), p, p2))
torch.testing.assert_close(result[0], 4 * x)
torch.testing.assert_close(result[1], 4 * x)
self.assertExpectedInline(
code,
"""\
@@ -82,37 +97,49 @@ def kernel(a_list, b_dict, b_tuple):
import triton.language as tl

@triton.jit
def _kernel_kernel(a0, c0, c1, a0_size_0, a0_stride_0, c0_stride_0, c1_stride_0, _BLOCK_SIZE_0: tl.constexpr):
def _kernel_kernel(a0, o0, o1, a0_size_0, a0_stride_0, o0_stride_0, o1_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < a0_size_0
load = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
load_1 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
v_0 = load + load_1
tl.store(c0 + indices_0 * c0_stride_0, v_0, mask_0)
load_2 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
v_1 = v_0 + load_2
load_3 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
v_1 = load_2 + load_3
tl.store(c1 + indices_0 * c1_stride_0, v_1, mask_0)

def kernel(a_list, b_dict, b_tuple):
v_2 = v_1 + load_3
tl.store(o0 + indices_0 * o0_stride_0, v_2, mask_0)
load_4 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
load_5 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
v_3 = load_4 + load_5
load_6 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
v_4 = v_3 + load_6
load_7 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
v_5 = v_4 + load_7
tl.store(o1 + indices_0 * o1_stride_0, v_5, mask_0)

def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
a0, a1 = a_list
b0 = b_dict['b0']
b1, = b_tuple
c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
c0, c1 = (c_named_tuple.x, c_named_tuple.y)
d0, d1 = (d_dataclass.x, d_dataclass.y)
o0, o1 = (torch.empty_like(a0), torch.empty_like(a1))
_BLOCK_SIZE_0 = 4
_kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return [c0, c1]
_kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return [o0, o1]

def _kernel_make_precompiler(a_list, b_dict, b_tuple):
def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
a0, a1 = a_list
b0 = b_dict['b0']
b1, = b_tuple
c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
c0, c1 = (c_named_tuple.x, c_named_tuple.y)
d0, d1 = (d_dataclass.x, d_dataclass.y)
o0, o1 = (torch.empty_like(a0), torch.empty_like(a1))
_BLOCK_SIZE_0 = 4
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_kernel_kernel)(a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)

