Skip to content

compiler: Rework /petsc/iet and PETScArrays#38

Merged
ZoeLeibowitz merged 20 commits intoFieldFromPointerfrom
cleanup2
Feb 18, 2025
Merged

compiler: Rework /petsc/iet and PETScArrays#38
ZoeLeibowitz merged 20 commits intoFieldFromPointerfrom
cleanup2

Conversation

@ZoeLeibowitz
Copy link
Owner

@ZoeLeibowitz ZoeLeibowitz commented Jan 9, 2025

  • Refactors petsc/iet by introducing more classes to reduce the number of free functions, and enable flexibility to accommodate different types of solvers with varying setups, callbacks, and other configurations etc. (many free functions inside petsc/iet/passes.py have moved into new classes inside petsc/iet/routines.py)
  • Simplifies the PETScArray class, as it is consistently constructed from Function-like objects.
  • Ensure that a unique DMDA and matrix context struct are created for each SNES solve.

@ZoeLeibowitz ZoeLeibowitz marked this pull request as ready for review January 10, 2025 08:13
Copy link
Collaborator

@mloubout mloubout left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly minor comments.

The iet layer could use more docstring and comments for maintainability but it can wait for later updates.

for iters, (injectsolve,) in injectsolve_mapper.items():

builder = PETScCallbackBuilder(**kwargs)
builder_classes = get_builder_classes(injectsolve)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could make it an object or namedtuple so you can can have buillder.objbuilder and such, might make it easier to extend/maintain as well.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now created the Builder class inside petsc/iet/passes to encapsulate this. What do you think?

solver_objs = build_solver_objs(injectsolve, iters, **kwargs)
time_dep = dep(injectsolve, **kwargs)

solver_objs = ObjBuilder(dep=time_dep, **kwargs).build(injectsolve, iters)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They all seem to take the same input so you'd be able to just init the builder with all its components at once

subs.update({space_iter: List(body=runsolve)})
# Transform the spatial iteration loop with the calls to execute the solver
space_iter, = spatial_loop_nest(iters, injectsolve)
runsolve = SolverRun(dep=time_dep).runsolve(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

personal nitpicking: i'd avoid using "run"

mapper = {solver_objs['dummyctx']: solver_objs['localctx']}
return Uxreplace(mapper).visit(efunc)

return {k: replace(v) for k, v in efuncs.items()}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why define an inner function instead of a plain loop

mapper = {}
vistor  = Uxreplace({solver_objs['dummyctx']: solver_objs['localctx']})
for k, v in efuncs.items():
    mapper.update(vistor.visit(v))
    
return mapper

Either way Uxreplace(mapper) is the same for all efuncs so should only need to make it once

obj = object.__new__(cls)
obj.rcompile = rcompile
obj.sregistry = sregistry
obj.concretize_mapper = kwargs.get('concretize_mapper', {})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is tied to the OSS PR right?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've actually dropped it now since I use the same SymbolRegistry throughout, even with rcompile

}


class SetupSolver:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this really need to be a class if it's stateless instead of three plain functions?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now added a hook for subclasses to add additional setup calls

if not time_iter:
return mapper

for d in time_iter[0].dimensions:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why only time_iter[0] ?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've corrected this now


@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = kwargs['target'].space_dimensions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just dimensions ?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PETScArray objects correspond to PETSc Vec objects, which represent the spatial domain and are used at each time step.

'ksp_type': 'gmres',
'pc_type': 'jacobi',
'ksp_rtol': 1e-7, # Relative tolerance
'ksp_rtol': 1e-5, # Relative tolerance
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These need to be documented in docstring


matvecs, formfuncs, formrhs = (
[eq.subs(time_mapper) for eq in lst] for lst in (matvecs, formfuncs, formrhs)
[eq.xreplace(time_mapper) for eq in lst] for lst in (matvecs, formfuncs, formrhs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity why this change?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think both work fine. I can't remember exactly why I changed it. Should it stay as subs?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not worth another spin of CI imho -- up to you :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xreplace is what introduced a bug with custom coeff legacy API, it sometimes only replaces the first occurrence it finds instead of all so if there is more than one occurrence it's likely to break

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok, will change it back in my next PR.

calls = [call_struct_callback] + calls_set_app_ctx
def _spatial_loop_nest(self, iters, injectsolve):
spatial_body = []
for tree in retrieve_iteration_tree(iters[0]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is special about iters[0] here?

These are not set in the `PopulateMatContext` callback.
"""
return [f for f in self.fields
if isinstance(f, (ModuloDimension, TimeDimension))]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if isinstance(f, (ModuloDimension, TimeDimension))]
could be replaced with
if f.is_Time
?

matvec_callback, formfunc_callback, formrhs_callback = self.make_all(
injectsolve, objs, solver_objs
)
@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: should be cached_property


# Global dimensions
args.extend(list(grid.shape)[::-1])
# No.of processors in each dimension
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: missing space in comment

- Modulo dimensions are updated in the matrix context struct at each time
step and can be accessed in the callback functions where needed.
"""
@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: should probably be cached_property, but tbh is so cheap it doesn't really matter

}
subs = {symb: true_dims[time_mapper[symb]] for symb in time_mapper}
return Uxreplace(subs).visit(body)
@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: should use cached_property

target_time = target_time.pop()
return target_time

@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

@@ -107,11 +115,6 @@ def _C_ctype(self):
@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

def symbolic_shape(self):
field_from_composites = [
FieldFromComposite('g%sm' % d.name, self.dmda.info) for d in self.dimensions]
FieldFromComposite('g%sm' % d.name, self.localinfo) for d in self.dimensions]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: fstring this?

@@ -160,9 +182,21 @@ def fields(self):

@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: probably should be cached

return [f for f in self.fields
if isinstance(f, (ModuloDimension, TimeDimension))]

@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto


def create_formrhs_body(self, injectsolve, body, solver_objs, objs):
linsolveexpr = injectsolve.expr.rhs
def _create_formrhs_body(self, injectsolve, body, solver_objs, objs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _create_formrhs_body(self, injectsolve, body, solver_objs, objs):
def _create_form_rhs_body(self, injectsolve, body, solver_objs, objs):

?

@ZoeLeibowitz ZoeLeibowitz merged commit 3b80b8b into FieldFromPointer Feb 18, 2025
29 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants