Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new itir.Program everywhere #596

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tools/src/icon4pytools/common/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

@dataclass(frozen=True)
class StencilInfo:
fendef: itir.FencilDefinition
program: itir.Program
fields: dict[str, FieldInfo]
connectivity_chains: list[eve.concepts.SymbolRef]
offset_provider: dict
Expand Down Expand Up @@ -206,7 +206,7 @@ def get_stencil_info(
"""Generate StencilInfo dataclass from a fencil definition."""
fvprog = _get_fvprog(fencil_def)
offsets = _scan_for_offsets(fvprog)
fendef = fvprog.itir
program = fvprog.itir

fields = _get_field_infos(fvprog)

Expand All @@ -216,7 +216,7 @@ def get_stencil_info(
offset_provider[offset] = _provide_offset(offset, is_global)
if offset != dims.Koff.value:
connectivity_chains.append(offset)
return StencilInfo(fendef, fields, connectivity_chains, offset_provider)
return StencilInfo(program, fields, connectivity_chains, offset_provider)


def _calc_num_neighbors(dim_list: list[Dimension], includes_center: bool) -> int:
Expand Down
57 changes: 29 additions & 28 deletions tools/src/icon4pytools/icon4pygen/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from pathlib import Path
from typing import Any, Iterable, List

from gt4py.next.common import Connectivity, Dimension
from gt4py.next.common import Connectivity, Dimension, DimensionKind
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import LiftMode
from gt4py.next.program_processors.codegens.gtfn import gtfn_module
from gt4py.next.type_system import type_specifications as ts
from icon4py.model.common import dimension as dims
Expand All @@ -33,26 +32,27 @@


def transform_and_configure_fencil(
fencil: itir.FencilDefinition,
) -> itir.FencilDefinition:
program: itir.Program,
) -> itir.Program:
"""Transform the domain representation and configure the FencilDefinition parameters."""
grid_size_symbols = [itir.Sym(id=arg, type=_SIZE_TYPE) for arg in GRID_SIZE_ARGS]

for closure in fencil.closures:
if not len(closure.domain.args) == 2:
raise TypeError(f"Output domain of '{fencil.id}' must be 2-dimensional.")
assert isinstance(closure.domain.args[0], itir.FunCall) and isinstance(
closure.domain.args[1], itir.FunCall
for stmt in program.body:
assert isinstance(stmt, itir.SetAt)
if not len(stmt.domain.args) == 2:
raise TypeError(f"Output domain of '{program.id}' must be 2-dimensional.")
assert isinstance(stmt.domain.args[0], itir.FunCall) and isinstance(
stmt.domain.args[1], itir.FunCall
)
horizontal_axis = closure.domain.args[0].args[0]
vertical_axis = closure.domain.args[1].args[0]
horizontal_axis = stmt.domain.args[0].args[0]
vertical_axis = stmt.domain.args[1].args[0]
assert isinstance(horizontal_axis, itir.AxisLiteral) and isinstance(
vertical_axis, itir.AxisLiteral
)
assert horizontal_axis.value in ["Vertex", "Edge", "Cell"]
assert vertical_axis.value == "K"

closure.domain = itir.FunCall(
stmt.domain = itir.FunCall(
fun=itir.SymRef(id="unstructured_domain"),
args=[
itir.FunCall(
Expand All @@ -66,7 +66,7 @@ def transform_and_configure_fencil(
itir.FunCall(
fun=itir.SymRef(id="named_range"),
args=[
itir.AxisLiteral(value=dims.Koff.source.value),
itir.AxisLiteral(value=dims.Koff.source.value, kind=DimensionKind.VERTICAL),
itir.SymRef(id=V_START),
itir.SymRef(id=V_END),
],
Expand All @@ -75,16 +75,18 @@ def transform_and_configure_fencil(
)

fencil_params = [
*(p for p in fencil.params if not is_size_param(p) and p not in grid_size_symbols),
*(p for p in get_missing_domain_params(fencil.params)),
*(p for p in program.params if not is_size_param(p) and p not in grid_size_symbols),
*(p for p in get_missing_domain_params(program.params)),
*grid_size_symbols,
]

return itir.FencilDefinition(
id=fencil.id,
function_definitions=fencil.function_definitions,
return itir.Program(
id=program.id,
function_definitions=program.function_definitions,
params=fencil_params,
closures=fencil.closures,
declarations=program.declarations,
body=program.body,
implicit_domain=program.implicit_domain
)


Expand All @@ -100,32 +102,32 @@ def get_missing_domain_params(params: List[itir.Sym]) -> Iterable[itir.Sym]:
return (itir.Sym(id=p, type=_SIZE_TYPE) for p in missing_args)


def check_for_domain_bounds(fencil: itir.FencilDefinition) -> None:
def check_for_domain_bounds(program: itir.Program) -> None:
"""Check that fencil params contain domain boundaries, emit warning otherwise."""
param_ids = {param.id for param in fencil.params}
param_ids = {param.id for param in program.params}
all_domain_params_present = all(
param in param_ids for param in [H_START, H_END, V_START, V_END]
)
if not all_domain_params_present:
warnings.warn(
f"Domain boundaries are missing or have non-standard names for '{fencil.id}'. "
f"Domain boundaries are missing or have non-standard names for '{program.id}'. "
"Adapting domain to use the standard names. This feature will be removed in the future.",
DeprecationWarning,
stacklevel=2,
)


def generate_gtheader(
fencil: itir.FencilDefinition,
program: itir.Program,
offset_provider: dict[str, Connectivity | Dimension],
imperative: bool,
temporaries: bool,
**kwargs: Any,
) -> str:
"""Generate a GridTools C++ header for a given stencil definition using specified configuration parameters."""
check_for_domain_bounds(fencil)
check_for_domain_bounds(program)

transformed_fencil = transform_and_configure_fencil(fencil)
transformed_fencil = transform_and_configure_fencil(program)

translation = gtfn_module.GTFNTranslationStep(
enable_itir_transforms=True,
Expand All @@ -134,7 +136,6 @@ def generate_gtheader(

if temporaries:
translation = translation.replace(
lift_mode=LiftMode.USE_TEMPORARIES,
symbolic_domain_sizes={
"Cell": "num_cells",
"Edge": "num_edges",
Expand All @@ -159,9 +160,9 @@ def __init__(self, stencil_info: StencilInfo) -> None:
def __call__(self, outpath: Path, imperative: bool, temporaries: bool) -> None:
"""Generate C++ code using the GTFN backend and write it to a file."""
gtheader = generate_gtheader(
fencil=self.stencil_info.fendef,
program=self.stencil_info.program,
offset_provider=self.stencil_info.offset_provider,
imperative=imperative,
temporaries=temporaries,
)
write_string(gtheader, outpath, f"{self.stencil_info.fendef.id}.hpp")
write_string(gtheader, outpath, f"{self.stencil_info.program.id}.hpp")
2 changes: 1 addition & 1 deletion tools/src/icon4pytools/icon4pygen/bindings/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class PyBindGen:
"""

def __init__(self, stencil_info: StencilInfo, levels_per_thread: int, block_size: int) -> None:
self.stencil_name = stencil_info.fendef.id
self.stencil_name = stencil_info.program.id
self.fields, self.offsets = self._stencil_info_to_binding_type(stencil_info)
self.levels_per_thread = levels_per_thread
self.block_size = block_size
Expand Down
Loading