Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: SDFGConvertible Program for dace_fieldview backend #1742

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

DropD
Copy link
Contributor

@DropD DropD commented Nov 19, 2024

Description

Add a decrator.Program subclass, which implements SDFGConvertible to dace_fieldview backend, analogous to the one in dace_iterator. Conditionally shadow decorator.Program with it and reactivate the orchestration tests by using dace_fieldview instead of dace_iterator.

One caveat: The toolchain is not ready for pure CompileTimeConnectivities in all cases yet, so the test_sdfgConvertible_connectivities had to be adjusted for the moment.

Requirements

  • All fixes and/or new features come with corresponding tests.
  • Important design decisions have been documented in the approriate ADR inside the docs/development/ADRs/ folder.

If this PR contains code authored by new contributors please make sure:

  • The PR contains an updated version of the AUTHORS.md file adding the names of all the new contributors.

src/gt4py/next/ffront/decorator.py Show resolved Hide resolved
if conn_id not in self.connectivity_tables_data_descriptors:
conn = self.connectivities[name]
self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
dtype=dace.int64 if conn.index_type == np.int64 else dace.int32,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that dace dtypes has a utility to parse numpy types:
dace.dtypes.typeclass(conn.index_type)

Copy link
Contributor Author

@DropD DropD Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying that lead to KeyError: dtype('int64') in typeclass().

Copy link
Contributor Author

@DropD DropD Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and trying dace.dtypes.typeclass(conn.index_type.type) leads to wrong stencil results. Edit: this might have instead been caused by auto_optimize=False.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edit: after merging main, it passes. However according to type hints of NeighborTableOffsetProvider, your version should work and mine should fail. Yet, in the tests, the connectivities are set up in a way so that my version works and yours fails. The current version on the other hand is safe in both cases.

If I can get the test to produce the correct types so the official version works, I will change it. Otherwise I would prefer to leave this one until the neighbor table construction is anyway changed by what @havogt is working on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could not get it to work, it seems like the type hints are wrong, but there is no point investing time in this when the API will change heavily anyway.

@kotsaloscv kotsaloscv self-requested a review November 20, 2024 09:34
for in_field in closure.inputs
if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert
]
sdfg.gt4py_program_input_fields = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please port this dynamic property to GTIR as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Working on it, Program nodes will require a different approach.

for arg in output.args:
if str(arg.id) in fields: # type: ignore[attr-defined]
output_fields.append(str(arg.id)) # type: ignore[attr-defined]
sdfg.gt4py_program_output_fields = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for this property.

Copy link
Contributor Author

@DropD DropD Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*edit, moved this comment to the offset_provider_per_input_field property below.

continue
if param.id not in sdfg.gt4py_program_input_fields:
continue
sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

All these properties are needed for the automatic halo exchange placement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Results from discussion in a separate channel for "offset_providers_per_input_field":

  • it is correct, as only used in icon4py which uses "unstructured" fields. There is max one horizontal dimension per field.
  • trace_shifts should work™️ partially on GTIR (but not accross as_fieldop).
  • currently not actively used in icon4py, so defer implementation and put a TODO instead.

return (args, [])


def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function also needs to be ported to GTIR as well.

# Add them as dynamic properties to the SDFG
program = typing.cast(
itir.Program, gtir_stage.data
) # we already checked that our backend uses GTIR
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we already checked that our backend uses GTIR

where? otherwise we could make an assert isinstance().

}

input_fields = (field_params[name] for name in InputNamesExtractor.only_fields(program))
# TODO (ricoh): This will associate the last horizontal dimension of each field with it's name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suspect is that this works under the assumption that all fields have a single horizontal dimension.
@kotsaloscv If this is the case, should we make a check and raise an exception if this assumption is not met?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants