Pymbolic follow 2025 07 by inducer · Pull Request #944 · inducer/loopy · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Pymbolic follow 2025 07 #944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
736 changes: 35 additions & 701 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions loopy/auto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def auto_test_vs_ref(
logger.info("%s (ref): run done" % ref_entrypoint)

ref_evt.wait()
ref_elapsed_event = 1e-9*(ref_evt.profile.END-ref_evt.profile.START)
ref_elapsed_event = 1e-9*(ref_evt.profile.end-ref_evt.profile.start)

break

Expand Down Expand Up @@ -649,8 +649,8 @@ def auto_test_vs_ref(
- 1e-9*events[0].profile.START) \
/ timing_rounds
try:
elapsed_event_marker = ((1e-9*evt_end.profile.START
- 1e-9*evt_start.profile.START)
elapsed_event_marker = ((1e-9*evt_end.profile.start
- 1e-9*evt_start.profile.start)
/ timing_rounds)
except cl.RuntimeError:
elapsed_event_marker = None
Expand Down
11 changes: 7 additions & 4 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic.primitives import Variable, is_arithmetic_expression
from pymbolic.primitives import AlgebraicLeaf, Variable, is_arithmetic_expression
from pytools import memoize_method

from loopy.diagnostic import (
Expand Down Expand Up @@ -79,6 +79,7 @@
from collections.abc import Mapping, Sequence

import pymbolic.primitives as p
from pymbolic import ArithmeticExpression
from pymbolic.typing import Expression

from loopy.kernel import LoopKernel
Expand Down Expand Up @@ -696,16 +697,18 @@ def _align_and_intersect_with_caller_assumption(callee_assumptions,
caller_assumptions)


def _mark_variables_from_caller(expr):
def _mark_variables_from_caller(expr: ArithmeticExpression):
import pymbolic.primitives as prim

from loopy.symbolic import SubstitutionMapper

def subst_func(x):
def subst_func(x: AlgebraicLeaf):
if isinstance(x, prim.Variable):
return prim.Variable(f"_lp_caller_{x.name}")

return SubstitutionMapper(subst_func)(expr)
res = SubstitutionMapper(subst_func)(expr)
assert is_arithmetic_expression(res)
return res

# }}}

Expand Down
8 changes: 6 additions & 2 deletions loopy/expression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from pymbolic import ArithmeticExpression


__copyright__ = "Copyright (C) 2012-15 Andreas Kloeckner"

Expand All @@ -22,7 +24,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np

Expand Down Expand Up @@ -121,7 +123,9 @@ def map_subscript(self, expr: p.Subscript) -> bool:

index = expr.index_tuple

index = tuple(simplify_using_aff(self.kernel, idx_i) for idx_i in index)
index = tuple(
simplify_using_aff(self.kernel, cast("ArithmeticExpression", idx_i))
for idx_i in index)

from pymbolic.primitives import Variable

Expand Down
14 changes: 8 additions & 6 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ def convert_computed_to_fixed_dim_tags(

new_dim_tags = list(dim_tags)

stride_so_far: ArithmeticExpression | None

for target_axis in range(num_target_axes):
if vector_dim is None:
stride_so_far = 1
Expand Down Expand Up @@ -610,9 +612,9 @@ def convert_computed_to_fixed_dim_tags(
stride_so_far *= shape_axis

if dim_tag.pad_to is not None:
from pytools import div_ceil
assert stride_so_far is not None
stride_so_far = (
div_ceil(stride_so_far, dim_tag.pad_to)
-(-stride_so_far // dim_tag.pad_to)
* stride_so_far)

elif isinstance(dim_tag, FixedStrideArrayDimTag):
Expand Down Expand Up @@ -1238,10 +1240,10 @@ def get_strides(array: ArrayBase) -> tuple[Expression, ...]:
class AccessInfo(ImmutableRecord):
array_name: str
vector_index: int | None
subscripts: tuple[Expression, ...]
subscripts: tuple[ArithmeticExpression, ...]


def _apply_offset(sub: Expression, ary: ArrayBase) -> Expression:
def _apply_offset(sub: ArithmeticExpression, ary: ArrayBase) -> ArithmeticExpression:
"""
Helper for :func:`get_access_info`.
Augments *ary*'s subscript index expression (*sub*) with its offset info.
Expand Down Expand Up @@ -1278,7 +1280,7 @@ def _apply_offset(sub: Expression, ary: ArrayBase) -> Expression:

def get_access_info(kernel: LoopKernel,
ary: ArrayArg | TemporaryVariable,
index: Expression | tuple[Expression, ...],
index: ArithmeticExpression | tuple[ArithmeticExpression, ...],
eval_expr: Callable[[Expression], int],
vectorization_info: VectorizationInfo | None
) -> AccessInfo:
Expand Down Expand Up @@ -1334,7 +1336,7 @@ def eval_expr_assert_integer_constant(i, expr) -> int:
num_target_axes = ary.num_target_axes()

vector_index = None
subscripts: list[Expression] = [0] * num_target_axes
subscripts: list[ArithmeticExpression] = [0] * num_target_axes

vector_size = ary.vector_size(kernel.target)

Expand Down
29 changes: 19 additions & 10 deletions loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import numpy as np
from constantdict import constantdict
from typing_extensions import override

import islpy as isl
from islpy import dim_type
Expand Down Expand Up @@ -62,7 +63,14 @@
from loopy.tools import Optional, intern_frozenset_of_ids
from loopy.translation_unit import TranslationUnit, for_each_kernel
from loopy.types import NumpyType
from loopy.typing import InameStr, PreambleGenerator, SymbolMangler, auto, not_none
from loopy.typing import (
InameStr,
PreambleGenerator,
SymbolMangler,
auto,
is_integer,
not_none,
)


if TYPE_CHECKING:
Expand All @@ -71,7 +79,7 @@

from numpy.typing import DTypeLike

from pymbolic import Expression
from pymbolic import ArithmeticExpression, Expression
from pytools.tag import ToTagSetConvertible

from loopy.options import Options
Expand Down Expand Up @@ -1873,16 +1881,13 @@ def apply_single_writer_dependency_heuristic(kernel, warn_if_used=True,

# {{{ slice to sub array ref

def normalize_slice_params(slice, dimension_length):
def normalize_slice_params(slice: Slice, dimension_length: ArithmeticExpression):
"""
Returns the normalized slice parameters ``(start, stop, step)``.

:arg slice: An instance of :class:`pymbolic.primitives.Slice`.
:arg dimension_length: Length of the axis being sliced.
"""
from numbers import Integral

from pymbolic.primitives import Slice

assert isinstance(slice, Slice)
start, stop, step = slice.start, slice.stop, slice.step
Expand All @@ -1909,14 +1914,14 @@ def normalize_slice_params(slice, dimension_length):

# }}}

if not isinstance(step, Integral):
if not is_integer(step):
raise LoopyError("Non-integral step sizes lead to non-affine domains =>"
" not supported")

return start, stop, step


class SliceToInameReplacer(IdentityMapper):
class SliceToInameReplacer(IdentityMapper[[]]):
"""
Converts slices to instances of :class:`loopy.symbolic.SubArrayRef`.

Expand Down Expand Up @@ -1944,9 +1949,12 @@ def __init__(self, knl):
self.var_name_gen = knl.get_var_name_generator()
super().__init__()

def map_subscript(self, expr):
@override
def map_subscript(self, expr: Subscript):
subscript_iname_bounds = {}

assert isinstance(expr.aggregate, Variable)

new_index = []
swept_inames = []
for i, index in enumerate(expr.index_tuple):
Expand Down Expand Up @@ -1982,7 +1990,8 @@ def map_subscript(self, expr):

return result

def map_call(self, expr):
@override
def map_call(self, expr: Call):

def _convert_array_to_slices(arg):
# FIXME: We do not support something like A[1] should point to the
Expand Down
7 changes: 5 additions & 2 deletions loopy/library/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np
from constantdict import constantdict
Expand All @@ -34,6 +34,8 @@


if TYPE_CHECKING:
from pymbolic import ArithmeticExpression

from loopy.translation_unit import CallablesTable


Expand Down Expand Up @@ -84,7 +86,8 @@ def emit_call(self, expression_to_code_mapper, expression, target):

from loopy.kernel.array import get_access_info
access_info = get_access_info(expression_to_code_mapper.kernel,
ary, arg.index, lambda expr: evaluate(expr,
ary, cast("ArithmeticExpression", arg.index),
lambda expr: evaluate(expr,
expression_to_code_mapper.codegen_state.var_subst_map),
expression_to_code_mapper.codegen_state.vectorization_info)

Expand Down
10 changes: 9 additions & 1 deletion loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from typing_extensions import override

from pymbolic import ArithmeticExpression, Expression, var
from pymbolic.primitives import expr_dataclass
from pymbolic.primitives import expr_dataclass, is_arithmetic_expression

from loopy.diagnostic import LoopyError
from loopy.kernel.function_interface import ScalarCallable
Expand Down Expand Up @@ -205,6 +205,10 @@ def __call__(self,
) -> tuple[Expression, CallablesTable]:
assert not isinstance(operand1, tuple)
assert not isinstance(operand2, tuple)
if not is_arithmetic_expression(operand1):
raise ValueError("operand 1 must be arithmetic")
if not is_arithmetic_expression(operand2):
raise ValueError("operand 2 must be arithmetic")
return operand1 + operand2, callables_table


Expand Down Expand Up @@ -232,6 +236,10 @@ def __call__(self,
) -> tuple[Expression, CallablesTable]:
assert not isinstance(operand1, tuple)
assert not isinstance(operand2, tuple)
if not is_arithmetic_expression(operand1):
raise ValueError("operand 1 must be arithmetic")
if not is_arithmetic_expression(operand2):
raise ValueError("operand 2 must be arithmetic")
return operand1 * operand2, callables_table


Expand Down
22 changes: 13 additions & 9 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
)
from pymbolic.mapper.dependency import (
CachedDependencyMapper as DependencyMapperBase,
DependenciesT,
)
from pymbolic.mapper.evaluator import CachedEvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.flattener import FlattenMapper as FlattenMapperBase
Expand All @@ -93,6 +92,7 @@
if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterable, Mapping, Sequence

from pymbolic.mapper.dependency import Dependencies
from pymbolic.typing import ArithmeticOrExpressionT

from loopy.kernel import LoopKernel
Expand Down Expand Up @@ -149,6 +149,10 @@

See :data:`pymbolic.typing.Expression`.

.. class:: ArithmeticExpression

See :data:`pymbolic.ArithmeticExpression`.

.. class:: _Expression

See :class:`pymbolic.primitives.ExpressionNode`.
Expand Down Expand Up @@ -506,34 +510,34 @@ class DependencyMapper(DependencyMapperBase[P]):
def map_group_hw_index(
self,
expr: GroupHardwareAxisIndex, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
return set()

def map_local_hw_index(
self,
expr: LocalHardwareAxisIndex, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
return set()

def map_call(
self,
expr: p.Call, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
# Loopy does not have first-class functions. Do not descend
# into 'function' attribute of Call.
return self.rec(expr.parameters, *args, **kwargs)

def map_reduction(
self,
expr: Reduction, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
deps = self.rec(expr.expr, *args, **kwargs)
return deps - {Variable(iname) for iname in expr.inames}

def map_tagged_variable(
self,
expr: TaggedVariable, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
return {expr}

def map_loopy_function_identifier(self, expr, *args: P.args, **kwargs: P.kwargs):
Expand Down Expand Up @@ -1069,7 +1073,7 @@ def __init__(self, *args, **kwargs) -> None:
def map_reduction(
self,
expr: Reduction, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
self.reduction_inames.update(expr.inames)
return super().map_reduction(expr, *args, **kwargs)

Expand Down Expand Up @@ -2165,8 +2169,8 @@ def simplify_via_aff(expr):
@memoize_on_first_arg
def simplify_using_aff(
kernel: LoopKernel,
expr: Expression
) -> Expression:
expr: ArithmeticExpression
) -> ArithmeticExpression:
"""
Simplifies *expr* on *kernel*'s domain.

Expand Down
Loading
Loading
0