Skip to content

Commit

Permalink
feat: add a way to specify operations that are attached to modules
Browse files Browse the repository at this point in the history
  • Loading branch information
makkus committed Nov 30, 2023
1 parent ca0228c commit c367775
Show file tree
Hide file tree
Showing 25 changed files with 256 additions and 92 deletions.
3 changes: 2 additions & 1 deletion src/kiara/interfaces/cli/info/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# Mozilla Public License, version 2.0 (see LICENSE or https://www.mozilla.org/en-US/MPL/2.0/)
import rich_click as click

from kiara.interfaces.python_api import KiaraPluginInfos
from kiara.utils.cli import output_format_option, terminal_print_model
from kiara.utils.cli.exceptions import handle_exception

Expand Down Expand Up @@ -56,6 +55,8 @@ def plugin(ctx):
def list_plugins(ctx, filter_regex: str, format):
"""List installed kiara plugins."""

from kiara.interfaces.python_api import KiaraPluginInfos

api: KiaraAPI = ctx.obj.kiara_api

title = "All available plugins"
Expand Down
9 changes: 9 additions & 0 deletions src/kiara/interfaces/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@
required=False,
multiple=True,
)
@click.option(
"--print-properties",
"-p",
help="Also display the properties of the result values.",
required=False,
is_flag=True,
)
@click.option("--help", "-h", help="Show this message and exit.", is_flag=True)
@click.pass_context
@handle_exception()
Expand All @@ -62,6 +69,7 @@ def run(
output: Iterable[str],
explain: bool,
save: Iterable[str],
print_properties: bool,
help: bool,
):
"""Run a kiara operation."""
Expand Down Expand Up @@ -244,4 +252,5 @@ def run(
silent=silent,
save_results=bool(final_aliases),
aliases=final_aliases,
properties=print_properties,
)
2 changes: 1 addition & 1 deletion src/kiara/interfaces/python_api/models/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,7 @@ def extract_module_attributes(
def create_renderable(self, **config: Any) -> RenderableType:

include_config_schema = config.get("include_config_schema", True)
include_src = config.get("include_src", True)
include_src = config.get("include_src", False)
include_doc = config.get("include_doc", True)

table = Table(box=box.SIMPLE, show_header=False, padding=(0, 0, 0, 0))
Expand Down
5 changes: 3 additions & 2 deletions src/kiara/interfaces/python_api/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,9 @@ def _apply_inputs(self) -> Mapping[str, Mapping[str, Mapping[str, ChangedValue]]
break
elif step_details.status == StepStatus.INPUTS_READY:
job_config = JobConfig(
module_type=step_details.step.module_type,
module_config=step_details.step.module.config.model_dump(),
module_type=step_details.step.module.manifest.module_type,
module_config=step_details.step.module.manifest.module_config,
is_resolved=step_details.step.module.manifest.is_resolved,
inputs=step_details.inputs,
)
match = self._kiara.job_registry.find_matching_job_record(
Expand Down
7 changes: 6 additions & 1 deletion src/kiara/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
KIARA_MODEL_SCHEMA_KEY,
)
from kiara.registries.templates import TemplateRegistry
from kiara.utils import log_exception
from kiara.utils.class_loading import _default_id_func
from kiara.utils.develop import log_dev_message
from kiara.utils.hashing import KIARA_HASH_FUNCTION, compute_cid
Expand Down Expand Up @@ -133,7 +134,11 @@ def _compute_cid(self):
return

obj = self._retrieve_data_to_hash()
dag, cid = compute_cid(data=obj)
try:
dag, cid = compute_cid(data=obj)
except Exception as e:
log_exception(e)
raise e

self._cid_cache = cid
self._dag_cache = dag
Expand Down
13 changes: 10 additions & 3 deletions src/kiara/models/module/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from rich.console import RenderableType
from rich.table import Table

from kiara.exceptions import InvalidValuesException
from kiara.exceptions import InvalidValuesException, KiaraException
from kiara.models import KiaraModel
from kiara.models.module.manifest import InputsManifest

Expand Down Expand Up @@ -97,9 +97,16 @@ def create_from_module(
raise InvalidValuesException(invalid_values=invalid)

value_ids = values.get_all_value_ids()

if not module.manifest.is_resolved:
raise KiaraException(
msg="Cannot create job config from unresolved manifest."
)

return JobConfig(
module_type=module.module_type_name,
module_config=module.config.model_dump(),
module_type=module.manifest.module_type,
module_config=module.manifest.module_config,
is_resolved=module.manifest.is_resolved,
inputs=value_ids,
)

Expand Down
25 changes: 17 additions & 8 deletions src/kiara/models/module/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import orjson
from dag_cbor import IPLDKind
from multiformats import CID
from pydantic import ConfigDict, Field, PrivateAttr, field_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator
from rich.console import RenderableType
from rich.syntax import Syntax

from kiara.defaults import INVALID_HASH_MARKER, NONE_VALUE_ID
from kiara.exceptions import KiaraException
from kiara.models import KiaraModel
from kiara.utils.hashing import compute_cid
from kiara.utils.json import orjson_dumps
Expand Down Expand Up @@ -53,13 +54,22 @@ class Manifest(KiaraModel):
#
# return value

@field_validator("module_config")
@classmethod
def validate_module_config(cls, value):
if isinstance(value, BaseModel):
raise ValueError(f"Invalid module config type: {type(value)}")

return value

@property
def manifest_data(self):
"""The configuration data for this module instance."""
if self._manifest_data is not None:
return self._manifest_data

mc = extract_data_to_hash_from_pipeline_config(self.module_config)

self._manifest_data = {
"module_type": self.module_type,
"module_config": mc,
Expand All @@ -72,6 +82,11 @@ def manifest_cid(self) -> CID:
if self._manifest_cid is not None:
return self._manifest_cid

if not self.is_resolved:
raise KiaraException(
msg="Cannot calculate manifest CID for unresolved manifest."
)

_, self._manifest_cid = compute_cid(self.manifest_data)
return self._manifest_cid

Expand All @@ -85,13 +100,7 @@ def manifest_data_as_json(self):

def _retrieve_data_to_hash(self) -> Any:

module_config = extract_data_to_hash_from_pipeline_config(self.module_config)
result = {
"module_type": self.module_type,
"module_config": module_config,
}

return result
return self.manifest_data

def create_renderable(self, **config: Any) -> RenderableType:
"""Create a renderable for this module configuration."""
Expand Down
15 changes: 15 additions & 0 deletions src/kiara/models/module/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,27 @@ class ManifestOperationConfig(OperationConfig):
default_factory=dict, description="The configuration for the module."
)

_manifest_cache: Union[None, Manifest] = PrivateAttr(default=None)

@field_validator("doc", mode="before")
@classmethod
def validate_doc(cls, value):
return DocumentationMetadataModel.create(value)

def retrieve_module_type(self, kiara: "Kiara") -> str:
return self.module_type

def retrieve_module_config(self, kiara: "Kiara") -> Mapping[str, Any]:
return self.module_config

def get_manifest(self) -> Manifest:

if self._manifest_cache is None:
self._manifest_cache = Manifest(
module_type=self.module_type, module_config=self.module_config
)
return self._manifest_cache


class PipelineOperationConfig(OperationConfig):

Expand Down
6 changes: 5 additions & 1 deletion src/kiara/models/values/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,10 @@ def set_value(self, field_name: str, data: Any) -> None:

ValuePedigree.model_rebuild()
ORPHAN = ValuePedigree(
kiara_id=VOID_KIARA_ID, environments={}, module_type=NO_MODULE_TYPE, inputs={}
kiara_id=VOID_KIARA_ID,
environments={},
module_type=NO_MODULE_TYPE,
inputs={},
is_resolved=True,
)
# GENESIS_PEDIGREE = None
2 changes: 2 additions & 0 deletions src/kiara/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ def __init__(
@property
def manifest(self) -> "Manifest":
if self._manifest_cache is None:
from kiara.models.module.manifest import Manifest

self._manifest_cache = Manifest(
module_type=self.module_type_name,
module_config=self.config.model_dump(),
Expand Down
128 changes: 99 additions & 29 deletions src/kiara/operations/included_core_operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
#
# Mozilla Public License, version 2.0 (see LICENSE or https://www.mozilla.org/en-US/MPL/2.0/)

from typing import TYPE_CHECKING, Iterable, Mapping, Union
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Mapping, Type, Union

import structlog
from pydantic import Field, PrivateAttr

from kiara.exceptions import KiaraException
from kiara.models.documentation import DocumentationMetadataModel
from kiara.models.module.operation import (
ManifestOperationConfig,
Expand All @@ -22,8 +23,9 @@
from kiara.operations import OperationType

if TYPE_CHECKING:
pass
from multiformats import CID

from kiara.context import Kiara

logger = structlog.getLogger()

Expand Down Expand Up @@ -61,45 +63,113 @@ def get_operation_schema(self) -> OperationSchema:

class CustomModuleOperationType(OperationType[CustomModuleOperationDetails]):

_operation_type_name = "custom_module"
_operation_type_name: ClassVar[str] = "custom_module"

def __init__(self, kiara: "Kiara", op_type_name: str):
self._included_operations_cache: Dict[
type, List["ManifestOperationConfig"]
] = {}
self._included_operations_lookup_cache: Dict[type, Dict["CID", str]] = {}

super().__init__(kiara=kiara, op_type_name=op_type_name)

def _retrieve_included_operations(
self, module_cls: Type[KiaraModule]
) -> List["ManifestOperationConfig"]:

if self._included_operations_cache.get(module_cls, None) is not None:
return self._included_operations_cache.get(module_cls) # type: ignore

from kiara.models.module.operation import ManifestOperationConfig

this_module_type_name: str = module_cls._module_type_name # type: ignore

doc = None
cache: List[ManifestOperationConfig] = []
lookup_cache: Dict["CID", str] = {}

if not module_cls._config_cls.requires_config():
doc = DocumentationMetadataModel.from_class_doc(module_cls)
mopc = ManifestOperationConfig(module_type=this_module_type_name, doc=doc)
resolved = self._kiara.module_registry.resolve_manifest(mopc.get_manifest())
mopc._manifest_cache = resolved

cache.append(mopc)
lookup_cache[resolved.manifest_cid] = this_module_type_name

if hasattr(module_cls, "retrieve_included_operations"):
manifests = module_cls.retrieve_included_operations() # type: ignore
for op_id, op in manifests.items():
if isinstance(op, Mapping):
mtn = op.get("module_type", None)
if not mtn:
op = dict(op)
op["module_type"] = this_module_type_name
if "doc" not in op.keys():
if doc is None:
doc = DocumentationMetadataModel.from_class_doc(
module_cls
)
op["doc"] = doc
elif not mtn != this_module_type_name:
raise KiaraException(
msg=f"Included operation '{op_id}' invalid.",
reason=f"module_type must be empty or set to the name '{this_module_type_name}'.",
)
mopc = ManifestOperationConfig(**op)
elif not isinstance(op, ManifestOperationConfig):
raise KiaraException(
msg=f"Included operation '{op_id}' invalid.",
reason="Must be a Mapping or ManifestOperationConfig instance.",
)
else:
mopc = op

cache.append(mopc)
resolved = self._kiara.module_registry.resolve_manifest(
mopc.get_manifest()
)
mopc._manifest_cache = resolved
lookup_cache[resolved.manifest_cid] = op_id

self._included_operations_cache[module_cls] = cache
self._included_operations_lookup_cache[module_cls] = lookup_cache
return self._included_operations_cache[module_cls]

def retrieve_included_operation_configs(
self,
) -> Iterable[Union[Mapping, OperationConfig]]:

result = []
for name, module_cls in self._kiara.module_type_classes.items():
mod_conf = module_cls._config_cls
configs = self._retrieve_included_operations(module_cls=module_cls)
result.extend(configs)

if mod_conf.requires_config():
logger.debug(
"ignore.custom_operation",
module_type=name,
reason="config required",
)
continue
doc = DocumentationMetadataModel.from_class_doc(module_cls)
oc = ManifestOperationConfig(module_type=name, doc=doc)
result.append(oc)
return result

def check_matching_operation(
self, module: "KiaraModule"
) -> Union[CustomModuleOperationDetails, None]:
mod_conf = module.__class__._config_cls

if not mod_conf.requires_config():
is_internal = module.characteristics.is_internal
# inputs_map = {k: k for k in module.inputs_schema.keys()}
# outputs_map = {k: k for k in module.outputs_schema.keys()}
op_details: CustomModuleOperationDetails = (
CustomModuleOperationDetails.create_operation_details(
operation_id=module.module_type_name,
module_inputs_schema=module.inputs_schema,
module_outputs_schema=module.outputs_schema,
is_internal_operation=is_internal,
)

op_id: Union[str, None] = None
if not module.is_pipeline():
manifest_cid = module.manifest.manifest_cid
op_id = self._included_operations_lookup_cache[module.__class__].get(
manifest_cid
)
return op_details
else:

if not op_id:
return None

is_internal = module.characteristics.is_internal
# inputs_map = {k: k for k in module.inputs_schema.keys()}
# outputs_map = {k: k for k in module.outputs_schema.keys()}
op_details: CustomModuleOperationDetails = (
CustomModuleOperationDetails.create_operation_details(
operation_id=op_id,
module_inputs_schema=module.inputs_schema,
module_outputs_schema=module.outputs_schema,
is_internal_operation=is_internal,
)
)
return op_details
Loading

0 comments on commit c367775

Please sign in to comment.