Better types in codegen by inducer · Pull Request #941 · inducer/loopy · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Better types in codegen #941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63,452 changes: 27,550 additions & 35,902 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"islpy": ("https://documen.tician.de/islpy", None),
"pyopencl": ("https://documen.tician.de/pyopencl", None),
"cgen": ("https://documen.tician.de/cgen", None),
"genpy": ("https://documen.tician.de/genpy", None),
"pymbolic": ("https://documen.tician.de/pymbolic", None),
"constantdict": ("https://matthiasdiener.github.io/constantdict/", None),
}
Expand Down
5 changes: 5 additions & 0 deletions doc/ref_internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ References

Mostly things that Sphinx (our documentation tool) should resolve but won't.

.. class:: ASTType

A type variable, representing an AST node. For now, either :class:`cgen.Generable`
or :class:`genpy.Generable`.

.. class:: constantdict

See :class:`constantdict.constantdict`.
Expand Down
18 changes: 9 additions & 9 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,16 @@
)
from loopy.kernel.array import (
ArrayBase,
ArrayDimImplementationTag,
FixedStrideArrayDimTag,
SeparateArrayArrayDimTag,
)
from loopy.kernel.data import (
AddressSpace,
ArrayArg,
ArrayDimImplementationTag,
AxisTag,
InameImplementationTag,
TemporaryVariable,
auto,
)
from loopy.kernel.function_interface import CallableKernel
from loopy.kernel.instruction import (
Expand All @@ -67,12 +66,13 @@
)
from loopy.symbolic import CombineMapper, ResolvedFunction, SubArrayRef, WalkMapper
from loopy.translation_unit import (
CallableId,
CallablesTable,
TranslationUnit,
check_each_kernel,
)
from loopy.type_inference import TypeReader
from loopy.typing import not_none
from loopy.typing import auto, not_none


if TYPE_CHECKING:
Expand Down Expand Up @@ -118,7 +118,7 @@

# {{{ sanity checks run before preprocessing

def check_identifiers_in_subst_rules(knl):
def check_identifiers_in_subst_rules(knl: LoopKernel):
"""Substitution rules may only refer to kernel-global quantities or their
own arguments.
"""
Expand All @@ -139,7 +139,7 @@ def check_identifiers_in_subst_rules(knl):
", ".join(deps-rule_allowed_identifiers)))


class UnresolvedCallCollector(CombineMapper):
class UnresolvedCallCollector(CombineMapper[frozenset[CallableId], []]):
"""
Collects all the unresolved calls within a kernel.

Expand Down Expand Up @@ -659,7 +659,7 @@ def check_for_data_dependent_parallel_bounds(kernel: LoopKernel) -> None:
from loopy.kernel.data import ConcurrentTag

for i, dom in enumerate(kernel.domains):
dom_inames = set(dom.get_var_names(dim_type.set))
dom_inames = set(dom.get_var_names_not_none(dim_type.set))
par_inames = {
iname for iname in dom_inames
if kernel.iname_tags_of_type(iname, ConcurrentTag)}
Expand Down Expand Up @@ -1938,7 +1938,7 @@ def check_implemented_domains(
non_lid_inames, [dim_type.set])

insn_domain = kernel.get_inames_domain(insn_inames)
insn_parameters = frozenset(insn_domain.get_var_names(dim_type.param))
insn_parameters = frozenset(insn_domain.get_var_names_not_none(dim_type.param))
assumptions, insn_domain = align_two(assumption_non_param, insn_domain)
desired_domain = ((insn_domain & assumptions)
.project_out_except(insn_inames, [dim_type.set])
Expand All @@ -1962,7 +1962,7 @@ def check_implemented_domains(
not_none(insn_domain.get_dim_name(dim_type.param, i))
for i in range(insn_impl_domain.dim(dim_type.param))}

lines = []
lines: list[str] = []
for bigger, smaller, diff_set, gist_domain in [
("implemented", "desired", i_minus_d,
desired_domain.gist(insn_impl_domain)),
Expand All @@ -1981,7 +1981,7 @@ def check_implemented_domains(
# lines.append("point desired: %s" % (pt_set <= desired_domain))

iname_to_dim = pt.get_space().get_var_dict()
point_axes = []
point_axes: list[str] = []
for iname in insn_inames | parameter_inames:
tp, dim = iname_to_dim[iname]
point_axes.append("%s=%d" % (
Expand Down
52 changes: 31 additions & 21 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,22 @@
from pytools.persistent_dict import WriteOncePersistentDict

from loopy.diagnostic import LoopyError, warn
from loopy.kernel.function_interface import CallableKernel
from loopy.kernel.function_interface import CallableKernel, InKernelCallable
from loopy.tools import LoopyKeyBuilder, caches
from loopy.version import DATA_MODEL_VERSION


if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence

from pymbolic import Expression

from loopy.codegen.result import CodeGenerationResult, GeneratedProgram
from loopy.codegen.tools import CodegenOperationCacheManager
from loopy.kernel import LoopKernel
from loopy.library.reduction import ReductionOpFunction
from loopy.target import TargetBase
from loopy.translation_unit import CallablesTable, TranslationUnit
from loopy.target import ASTType, TargetBase
from loopy.translation_unit import CallableId, CallablesTable, TranslationUnit
from loopy.types import LoopyType


Expand Down Expand Up @@ -163,7 +163,7 @@ class CodeGenerationState:
i.e. all constraints that have been enforced so far.
"""

implemented_predicates: frozenset[str | Expression]
implemented_predicates: frozenset[Expression]

# /!\ mutable
seen_dtypes: set[LoopyType]
Expand Down Expand Up @@ -232,7 +232,11 @@ def fix(self, iname: str, aff: isl.Aff) -> CodeGenerationState:
return self.copy_and_assign(iname, expr).copy(
implemented_domain=new_impl_domain)

def try_vectorized(self, what, func):
def try_vectorized(self,
what: str,
func: Callable[[CodeGenerationState],
CodeGenerationResult[ASTType] | None]
):
"""If *self* is in a vectorizing state (:attr:`vectorization_info` is
not None), tries to call func (which must be a callable accepting a
single :class:`CodeGenerationState` argument). If this fails with
Expand All @@ -255,11 +259,14 @@ def try_vectorized(self, what, func):

return self.unvectorize(func)

def unvectorize(self, func):
def unvectorize(self,
func: Callable[[CodeGenerationState],
CodeGenerationResult[ASTType] | None],
):
vinf = self.vectorization_info
assert vinf is not None

result = []
result: list[CodeGenerationResult[ASTType]] = []
novec_self = self.copy(vectorization_info=None)

for i in range(vinf.length):
Expand All @@ -270,6 +277,8 @@ def unvectorize(self, func):

if isinstance(generated, list):
result.extend(generated)
elif generated is None:
pass
else:
result.append(generated)

Expand All @@ -288,7 +297,7 @@ def ast_builder(self):

code_gen_cache: WriteOncePersistentDict[
TranslationUnit,
CodeGenerationResult
CodeGenerationResult[Any]
] = WriteOncePersistentDict(
"loopy-code-gen-cache-v3-"+DATA_MODEL_VERSION,
key_builder=LoopyKeyBuilder(),
Expand Down Expand Up @@ -322,7 +331,7 @@ def generate_code_for_a_single_kernel(
callables_table: CallablesTable,
target: TargetBase,
is_entrypoint: bool,
) -> CodeGenerationResult:
) -> CodeGenerationResult[Any]:
"""
:returns: a :class:`CodeGenerationResult`

Expand Down Expand Up @@ -430,7 +439,7 @@ def generate_code_for_a_single_kernel(
return codegen_result


def diverge_callee_entrypoints(program):
def diverge_callee_entrypoints(t_unit: TranslationUnit):
"""
If a :class:`loopy.kernel.function_interface.CallableKernel` is both an
entrypoint and a callee, then rename the callee.
Expand All @@ -440,18 +449,19 @@ def diverge_callee_entrypoints(program):
make_callable_name_generator,
rename_resolved_functions_in_a_single_kernel,
)
callable_ids = get_reachable_resolved_callable_ids(program.callables_table,
program.entrypoints)
callable_ids = get_reachable_resolved_callable_ids(t_unit.callables_table,
t_unit.entrypoints)

new_callables = {}
todo_renames = {}
new_callables: dict[CallableId, InKernelCallable] = {}
todo_renames: dict[CallableId, str] = {}

vng = make_callable_name_generator(program.callables_table)
vng = make_callable_name_generator(t_unit.callables_table)

for clbl_id in callable_ids & program.entrypoints:
for clbl_id in callable_ids & t_unit.entrypoints:
assert isinstance(clbl_id, str)
todo_renames[clbl_id] = vng(based_on=clbl_id)

for name, clbl in program.callables_table.items():
for name, clbl in t_unit.callables_table.items():
if name in todo_renames:
name = todo_renames[name]

Expand All @@ -463,7 +473,7 @@ def diverge_callee_entrypoints(program):

new_callables[name] = clbl

return program.copy(callables_table=constantdict.constantdict(new_callables))
return t_unit.copy(callables_table=constantdict.constantdict(new_callables))


@dataclass(frozen=True)
Expand Down Expand Up @@ -528,7 +538,7 @@ def all_code(self):
self.host_programs.values()))


def generate_code_v2(t_unit: TranslationUnit) -> CodeGenerationResult:
def generate_code_v2(t_unit: TranslationUnit) -> CodeGenerationResult[Any]:
# {{{ cache retrieval

from loopy import ABORT_ON_CACHE_MISS, CACHING_ENABLED
Expand Down Expand Up @@ -660,7 +670,7 @@ def generate_code(kernel, device=None):

# {{{ generate function body

def generate_body(kernel):
def generate_body(kernel: TranslationUnit):
codegen_result = generate_code_v2(kernel)

if len(codegen_result.device_programs) != 1:
Expand Down
14 changes: 10 additions & 4 deletions loopy/codegen/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,22 @@


if TYPE_CHECKING:
from collections.abc import Collection

from loopy.codegen.tools import CodegenOperationCacheManager
from loopy.kernel import LoopKernel
from loopy.kernel.tools import SetOperationCacheManager
from loopy.typing import InameStr


# {{{ approximate, convex bounds check generator

def get_approximate_convex_bounds_checks(domain, check_inames,
implemented_domain, op_cache_manager):
if isinstance(domain, isl.BasicSet):
domain = isl.Set.from_basic_set(domain)
def get_approximate_convex_bounds_checks(
domain: isl.Set,
check_inames: Collection[InameStr],
implemented_domain: isl.Set,
op_cache_manager: SetOperationCacheManager
):
domain = domain.remove_redundancies()
result = op_cache_manager.eliminate_except(domain, check_inames,
(dim_type.set,))
Expand Down
Loading
Loading
0