Skip to content

Commit

Permalink
feat: in-memory object flow (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
vijayvammi authored Jan 10, 2025
1 parent 5c78cb3 commit 955deef
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 35 deletions.
14 changes: 14 additions & 0 deletions examples/configs/in-memory.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Configuration for a fully in-memory run: task return objects are kept in
# memory instead of being pickled into the catalog (object_serialisation: false),
# the run log is buffered, and the catalog and secrets services are no-ops.
pipeline-executor:
  type: local # (1)
  config:
    object_serialisation: false


run-log-store:
  type: buffered # (2)

catalog:
  type: do-nothing # (3)

secrets:
  type: do-nothing # (4)
42 changes: 14 additions & 28 deletions extensions/pipeline_executor/local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from pydantic import Field, PrivateAttr

from extensions.pipeline_executor import GenericPipelineExecutor
from runnable import defaults
from runnable.defaults import TypeMapVariable
Expand All @@ -22,7 +24,18 @@ class LocalExecutor(GenericPipelineExecutor):
"""

service_name: str = "local"
_is_local: bool = True

object_serialisation: bool = Field(default=True)

_is_local: bool = PrivateAttr(default=True)

def execute_from_graph(
    self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
):
    """Propagate the executor's serialisation choice to the run context, then delegate.

    When this executor is configured with ``object_serialisation: false``,
    the shared run context is flagged so downstream code keeps returned
    objects in memory rather than pickling them into the catalog.
    """
    # Only the opt-out is propagated; when the flag is True the context's
    # own (default) setting is left untouched.
    if self.object_serialisation is False:
        self._context.object_serialisation = False

    super().execute_from_graph(node=node, map_variable=map_variable, **kwargs)

def trigger_node_execution(
self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
Expand All @@ -47,30 +60,3 @@ def execute_node(
map_variable (dict[str, str], optional): _description_. Defaults to None.
"""
self._execute_node(node=node, map_variable=map_variable, **kwargs)

# def execute_job(self, node: TaskNode):
# """
# Set up the step log and call the execute node

# Args:
# node (BaseNode): _description_
# """

# step_log = self._context.run_log_store.create_step_log(
# node.name, node._get_step_log_name(map_variable=None)
# )

# self.add_code_identities(node=node, step_log=step_log)

# step_log.step_type = node.node_type
# step_log.status = defaults.PROCESSING
# self._context.run_log_store.add_step_log(step_log, self._context.run_id)
# self.execute_node(node=node)

# # Update the run log status
# step_log = self._context.run_log_store.get_step_log(
# node._get_step_log_name(), self._context.run_id
# )
# self._context.run_log_store.update_run_log_status(
# run_id=self._context.run_id, status=step_log.status
# )
4 changes: 3 additions & 1 deletion runnable/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from rich.progress import Progress
Expand Down Expand Up @@ -29,6 +29,8 @@ class Context(BaseModel):
from_sdk: bool = False

run_id: str = ""
object_serialisation: bool = True
return_objects: Dict[str, Any] = {}

tag: str = ""
variables: Dict[str, str] = {}
Expand Down
17 changes: 14 additions & 3 deletions runnable/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,33 @@ class ObjectParameter(BaseModel):
@computed_field  # type: ignore
@property
def description(self) -> str:
    """Human-readable note on where the object backing this parameter lives."""
    # Inverted guard: the in-memory case is the exception, handle it first.
    if not context.run_context.object_serialisation:
        return f"Object stored in memory as: {self.value}"
    return f"Pickled object stored in catalog as: {self.value}"

@property
def file_name(self) -> str:
    """Name of the pickled file for this object, including the pickler's extension."""
    extension = context.run_context.pickler.extension
    return f"{self.value}{extension}"

def get_value(self) -> Any:
    """Return the parameter's object, either from memory or from the catalog.

    With ``object_serialisation`` disabled on the run context, the live
    object is handed back from ``return_objects``. Otherwise the pickled
    file is fetched from the catalog, unpickled, and the local copy removed.
    """
    run_context = context.run_context

    # In-memory flow: the object was never pickled, return the live reference.
    if not run_context.object_serialisation:
        return run_context.return_objects[self.value]

    # Serialised flow: pull the pickle from the catalog, load, then clean up.
    run_context.catalog_handler.get(name=self.file_name, run_id=run_context.run_id)
    loaded = run_context.pickler.load(path=self.file_name)
    os.remove(self.file_name)  # local pickle is no longer needed once loaded
    return loaded

def put_object(self, data: Any) -> None:
if not context.run_context.object_serialisation:
context.run_context.return_objects[self.value] = data
return

# If the object was serialised, put it in the catalog
context.run_context.pickler.dump(data=data, path=self.file_name)

catalog_handler = context.run_context.catalog_handler
Expand Down
4 changes: 1 addition & 3 deletions runnable/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ class BaseExecutor(ABC, BaseModel):
service_name: str = ""
service_type: str = "executor"

_is_local: bool = (
False # This is a flag to indicate whether the executor is local or not.
)
_is_local: bool = PrivateAttr(default=False)

model_config = ConfigDict(extra="forbid")

Expand Down

0 comments on commit 955deef

Please sign in to comment.