Skip to content

Commit

Permalink
Pipeline-plan duplicate/remove transformation changing dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Dec 19, 2024
1 parent ad1e2bc commit 3415e92
Show file tree
Hide file tree
Showing 8 changed files with 552 additions and 4 deletions.
106 changes: 104 additions & 2 deletions loki/batch/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from functools import reduce
import sys
from pathlib import Path

from loki.batch.configure import SchedulerConfig, ItemConfig
from loki.frontend import REGEX, RegexParserClass
Expand All @@ -24,6 +25,7 @@
from loki.tools import as_tuple, flatten, CaseInsensitiveDict
from loki.types import DerivedType

# pylint: disable=too-many-lines

__all__ = [
'Item', 'FileItem', 'ModuleItem', 'ProcedureItem', 'TypeDefItem',
Expand Down Expand Up @@ -137,8 +139,21 @@ def __init__(self, name, source, config=None):
self.name = name
self.source = source
self.trafo_data = {}
self.plan_data = {}
super().__init__(config)

def clone(self, **kwargs):
"""
Replicate the object with the provided overrides.
"""
if 'name' not in kwargs:
kwargs['name'] = self.name
if 'source' not in kwargs:
kwargs['source'] = self.source.clone() # self.source.clone()
if self.config is not None and 'config' not in kwargs:
kwargs['config'] = self.config
return type(self)(**kwargs)

def __repr__(self):
return f'loki.batch.{self.__class__.__name__}<{self.name}>'

Expand Down Expand Up @@ -632,10 +647,28 @@ def _dependencies(self):
Return the list of :any:`Import` nodes that constitute dependencies
for this module, filtering out imports to intrinsic modules.
"""
return tuple(
deps = tuple(
imprt for imprt in self.ir.imports
if not imprt.c_import and str(imprt.nature).lower() != 'intrinsic'
)
# potentially add dependencies due to transformations that added some
if 'additional_dependencies' in self.plan_data:
deps += self.plan_data['additional_dependencies']
# potentially remove dependencies due to transformations that removed some of those
if 'removed_dependencies' in self.plan_data:
new_deps = ()
for dep in deps:
if isinstance(dep, Import):
new_symbols = ()
for symbol in dep.symbols:
if str(symbol.name).lower() not in self.plan_data['removed_dependencies']:
new_symbols += (symbol,)
if new_symbols:
new_deps += (dep.clone(symbols=new_symbols),)
else:
new_deps += (dep,)
return new_deps
return deps

@property
def local_name(self):
Expand Down Expand Up @@ -703,7 +736,29 @@ def _dependencies(self):
import_map = self.scope.import_map
typedefs += tuple(typedef for type_name in type_names if (typedef := typedef_map.get(type_name)))
imports += tuple(imprt for type_name in type_names if (imprt := import_map.get(type_name)))
return imports + interfaces + typedefs + calls + inline_calls
deps = imports + interfaces + typedefs + calls + inline_calls
# potentially add dependencies due to transformations that added some
if 'additional_dependencies' in self.plan_data:
deps += self.plan_data['additional_dependencies']
# potentially remove dependencies due to transformations that removed some of those
if 'removed_dependencies' in self.plan_data:
new_deps = ()
for dep in deps:
if isinstance(dep, CallStatement):
if str(dep.name).lower() not in self.plan_data['removed_dependencies']:
new_deps += (dep,)
elif isinstance(dep, Import):
new_symbols = ()
for symbol in dep.symbols:
if str(symbol.name).lower() not in self.plan_data['removed_dependencies']:
new_symbols += (symbol,)
if new_symbols:
new_deps += (dep.clone(symbols=new_symbols),)
else:
# TODO: handle interfaces and inline calls as well ...
new_deps += (dep,)
return new_deps
return deps


class TypeDefItem(Item):
Expand Down Expand Up @@ -959,6 +1014,53 @@ def __contains__(self, key):
"""
return key in self.item_cache

def clone_procedure_item(self, item, suffix='', module_suffix=''):
"""
Clone and create a :any:`ProcedureItem` and additionally create a :any:`ModuleItem`
(if the passed :any:`ProcedureItem` lives within a module ) as well
as a :any:`FileItem`.
"""

path = Path(item.path)
new_path = Path(item.path).with_suffix(f'.{module_suffix}{item.path.suffix}')

local_routine_name = item.local_name
new_local_routine_name = f'{local_routine_name}_{suffix}'

mod_name = item.name.split('#')[0]
if mod_name:
new_mod_name = mod_name.replace('mod', f'{module_suffix}_mod')\
if 'mod' in mod_name else f'{mod_name}{module_suffix}'
else:
new_mod_name = ''
new_routine_name = f'{new_mod_name}#{new_local_routine_name}'

# create new source
orig_source = item.source
new_source = orig_source.clone(path=new_path)
if not mod_name:
new_source[local_routine_name].name = new_local_routine_name
else:
new_source[mod_name][local_routine_name].name = new_local_routine_name
new_source[mod_name].name = new_mod_name

# create new ModuleItem
if mod_name:
orig_mod = self.item_cache[mod_name]
self.item_cache[new_mod_name] = orig_mod.clone(name=new_mod_name, source=new_source)

# create new ProcedureItem
self.item_cache[new_routine_name] = item.clone(name=new_routine_name, source=new_source)

# create new FileItem
orig_file_item = self.item_cache[str(path)]
self.item_cache[str(new_path)] = orig_file_item.clone(name=str(new_path), source=new_source)

# return the newly created procedure/routine
if mod_name:
return new_source[new_mod_name][new_local_routine_name]
return new_source[new_local_routine_name]

def create_from_ir(self, node, scope_ir, config=None, ignore=None):
"""
Helper method to create items for definitions or dependency
Expand Down
3 changes: 2 additions & 1 deletion loki/batch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,8 @@ def _get_definition_items(_item, sgraph_items):
item=_item, targets=_item.targets, items=_get_definition_items(_item, sgraph_items),
successors=graph.successors(_item, item_filter=item_filter),
depths=graph.depths, build_args=self.build_args,
plan_mode=proc_strategy == ProcessingStrategy.PLAN
plan_mode=proc_strategy == ProcessingStrategy.PLAN,
item_factory=self.item_factory
)

if transformation.renames_items:
Expand Down
12 changes: 12 additions & 0 deletions loki/frontend/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def __init__(self, lines, string=None, file=None):
self.string = string
self.file = file

def clone(self, **kwargs):
"""
Replicate the object with the provided overrides.
"""
if 'lines' not in kwargs:
kwargs['lines'] = self.lines
if self.string is not None and 'string' not in kwargs:
kwargs['string'] = self.string
if self.file is not None and 'file' not in kwargs:
kwargs['file'] = self.file
return type(self)(**kwargs)

def __repr__(self):
line_end = f'-{self.lines[1]}' if self.lines[1] else ''
return f'Source<line {self.lines[0]}{line_end}>'
Expand Down
18 changes: 17 additions & 1 deletion loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
from loki.types import DataType, BasicType, DerivedType, SymbolAttributes


__all__ = [
# Abstract base classes
'Node', 'InternalNode', 'LeafNode', 'ScopedNode',
Expand Down Expand Up @@ -462,6 +461,23 @@ def prepend(self, node):
def __repr__(self):
return 'Section::'

def recursive_clone(self, **kwargs):
"""
Clone the object and recursively clone all the elements
of the object's body.
Parameters
----------
**kwargs :
Any parameters from the constructor of the class.
Returns
-------
Object of type ``self.__class__``
The cloned object.
"""
return self.clone(body=tuple(elem.clone(**kwargs) for elem in self.body), **kwargs)


@dataclass_strict(frozen=True)
class _AssociateBase():
Expand Down
12 changes: 12 additions & 0 deletions loki/sourcefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def __init__(self, path, ir=None, ast=None, source=None, incomplete=False, parse
self._incomplete = incomplete
self._parser_classes = parser_classes

def clone(self, **kwargs):
"""
Replicate the object with the provided overrides.
"""
if 'path' not in kwargs:
kwargs['path'] = self.path
if self.ir is not None and 'ir' not in kwargs:
kwargs['ir'] = self.ir.recursive_clone()
if self.source is not None and 'source' not in kwargs:
kwargs['source'] = self._source.clone(file=kwargs['path']) # .clone()
return type(self)(**kwargs)

@classmethod
def from_file(cls, filename, definitions=None, preprocess=False,
includes=None, defines=None, omni_includes=None,
Expand Down
1 change: 1 addition & 0 deletions loki/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@
from loki.transformations.loop_blocking import * # noqa
from loki.transformations.routine_signatures import * # noqa
from loki.transformations.parallel import * # noqa
from loki.transformations.dependency import * # noqa
111 changes: 111 additions & 0 deletions loki/transformations/dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.ir import nodes as ir, Transformer, FindNodes
from loki.tools.util import as_tuple

__all__ = ['DuplicateKernel', 'RemoveKernel']


class DuplicateKernel(Transformation):

creates_items = True

def __init__(self, kernels=None, duplicate_suffix='duplicated',
duplicate_module_suffix=None):
self.suffix = duplicate_suffix
self.module_suffix = duplicate_module_suffix or duplicate_suffix
print(f"suffix: {self.suffix}")
print(f"module_suffix: {self.module_suffix}")
self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels))

def transform_subroutine(self, routine, **kwargs):

item = kwargs.get('item', None)
item_factory = kwargs.get('item_factory', None)
if not item and 'items' in kwargs:
if kwargs['items']:
item = kwargs['items'][0]

successors = as_tuple(kwargs.get('successors'))
item.plan_data['additional_dependencies'] = ()
new_deps = {}
for child in successors:
if child.local_name.lower() in self.kernels:
new_dep = item_factory.clone_procedure_item(child, self.suffix, self.module_suffix)
new_deps[new_dep.name.lower()] = new_dep

imports = as_tuple(FindNodes(ir.Import).visit(routine.spec))
parent_imports = as_tuple(FindNodes(ir.Import).visit(routine.parent.ir)) if routine.parent is not None else ()
all_imports = imports + parent_imports
import_map = {}
for _imp in all_imports:
for symbol in _imp.symbols:
import_map[symbol] = _imp

calls = FindNodes(ir.CallStatement).visit(routine.body)
call_map = {}
for call in calls:
if str(call.name).lower() in self.kernels:
new_call_name = f'{str(call.name)}_{self.suffix}'.lower()
call_map[call] = (call, call.clone(name=new_deps[new_call_name].procedure_symbol))
if call.name in import_map:
new_import_module = \
import_map[call.name].module.upper().replace('MOD', f'{self.module_suffix.upper()}_MOD')
new_symbols = [symbol.clone(name=f"{symbol.name}_{self.suffix}")
for symbol in import_map[call.name].symbols]
new_import = ir.Import(module=new_import_module, symbols=as_tuple(new_symbols))
routine.spec.append(new_import)
routine.body = Transformer(call_map).visit(routine.body)

def plan_subroutine(self, routine, **kwargs):
item = kwargs.get('item', None)
item_factory = kwargs.get('item_factory', None)
if not item and 'items' in kwargs:
if kwargs['items']:
item = kwargs['items'][0]

successors = as_tuple(kwargs.get('successors'))
item.plan_data['additional_dependencies'] = ()
for child in successors:
if child.local_name.lower() in self.kernels:
new_dep = item_factory.clone_procedure_item(child, self.suffix, self.module_suffix)
item.plan_data['additional_dependencies'] += as_tuple(new_dep)

class RemoveKernel(Transformation):

creates_items = True

def __init__(self, kernels=None):
self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels))

def transform_subroutine(self, routine, **kwargs):
calls = FindNodes(ir.CallStatement).visit(routine.body)
call_map = {}
for call in calls:
if str(call.name).lower() in self.kernels:
call_map[call] = None
routine.body = Transformer(call_map).visit(routine.body)

def plan_subroutine(self, routine, **kwargs):
item = kwargs.get('item', None)
item_factory = kwargs.get('item_factory', None)
if not item and 'items' in kwargs:
if kwargs['items']:
item = kwargs['items'][0]

successors = as_tuple(kwargs.get('successors'))
item.plan_data['removed_dependencies'] = ()
for child in successors:
if child.local_name.lower() in self.kernels:
item.plan_data['removed_dependencies'] += (child.local_name.lower(),)
# propagate 'removed_dependencies' to corresponding module (if it exists)
module_name = item.name.split('#')[0]
if module_name:
module_item = item_factory.item_cache[item.name.split('#')[0]]
module_item.plan_data['removed_dependencies'] = item.plan_data['removed_dependencies']
Loading

0 comments on commit 3415e92

Please sign in to comment.