Skip to content

Commit

Permalink
Rework the whole thing
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Aug 2, 2024
1 parent 6a179e4 commit 57097ea
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 110 deletions.
10 changes: 4 additions & 6 deletions src/zenml/config/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,15 @@ class NotebookSource(Source):
"""Source representing an object defined in a notebook.
Attributes:
cell_id: ID of the notebook cell in which the object is stored. This
notebook_path: Path of the notebook (relative to the source root) in
which the object is defined.
cell_id: ID of the notebook cell in which the object is defined. This
will only be set for objects which explicitly store this by calling
`zenml.utils.notebook_utils.save_notebook_cell_id()`.
replacement_module: Replacement module from which to load the source
when not running in a notebook. This will only be set for special
object for which we extract the cell code in which they're defined
into python files before running pipeline remotely.
"""

notebook_path: Optional[str] = None
cell_id: Optional[str] = None
replacement_module: Optional[str] = None
type: SourceType = SourceType.NOTEBOOK

@field_validator("type")
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/new/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,10 @@ def _run(
):
code_path = build_utils.upload_code_if_necessary()

# TODO: if we run remotely and there are steps defined in notebook
# cells, verify that we will be able to run inside the remote
# environments

deployment_request = PipelineDeploymentRequest(
user=Client().active_user.id,
workspace=Client().active_workspace.id,
Expand Down
151 changes: 56 additions & 95 deletions src/zenml/utils/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,136 +15,87 @@

import json
import os
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import Any, Dict, Optional

from IPython import get_ipython

from zenml.config.source import NotebookSource, SourceType
from zenml.environment import Environment
from zenml.logger import get_logger
from zenml.utils import source_utils

if TYPE_CHECKING:
from zenml.config.step_configurations import Step
from zenml.models import PipelineDeploymentBase
from zenml.stack import Stack


ZENML_NOTEBOOK_CELL_ID_ATTRIBUTE_NAME = "__zenml_notebook_cell_id__"

logger = get_logger(__name__)


def get_notebook_extra_files(
deployment: "PipelineDeploymentBase", stack: "Stack"
) -> Dict[str, str]:
"""Get extra required files for running notebook code remotely.
Args:
deployment: The deployment for which to get the files.
stack: The stack on which the deployment will run.
Raises:
RuntimeError: If the cell ID for a remote step of the deployment is
not stored.
def get_active_notebook_path() -> Optional[str]:
"""Get path of the active notebook.
Returns:
A dict (file_path, file_content) of the required extra files.
Path of the active notebook.
"""
if not Environment.in_notebook():
return {}

files = {}

for step in deployment.step_configurations.values():
if step.spec.source.type == SourceType.NOTEBOOK:
assert isinstance(step.spec.source, NotebookSource)

if not step_will_run_remotely(step=step, stack=stack):
continue

cell_id = step.spec.source.cell_id
if not cell_id:
raise RuntimeError(
"Failed to extract notebook cell code because no cell ID"
"was saved for this step."
)

module_name = (
f"zenml_extracted_notebook_code_{cell_id.replace('-', '_')}"
)
filename = f"{module_name}.py"
file_content = extract_notebook_cell_code(cell_id=cell_id)

step.spec.source.replacement_module = module_name
files[filename] = file_content
return None

return files
return "test.ipynb"


def step_will_run_remotely(step: "Step", stack: "Stack") -> bool:
"""Check whether a step will run remotely.
def load_notebook(notebook_path: str) -> Dict[str, Any]:
"""Load a notebook.
Args:
step: The step to check.
stack: The stack on which the step will run.
notebook_path: Path to the notebook.
Raises:
FileNotFoundError: If no notebook exist at the path.
Returns:
Whether the step will run remotely.
Dictionary of the notebook.
"""
if step.config.step_operator:
return True
if not os.path.exists(notebook_path):
raise FileNotFoundError(
f"Notebook at path {notebook_path} does not exist."
)

if stack.orchestrator.config.is_remote:
return True
with open(notebook_path) as f:
notebook_json = json.loads(f.read())

return False
return notebook_json


def get_active_notebook_cell_id() -> str:
"""Get the ID of the currently active notebook cell.
# def load_active_notebook() -> Dict[str, Any]:
# """Load the active notebook.

Returns:
The ID of the currently active notebook cell.
"""
cell_id = get_ipython().get_parent()["metadata"]["cellId"]
return cell_id
# Raises:
# RuntimeError: If the active notebook can't be loaded.

# Returns:
# Dictionary of the notebook.
# """
# if not Environment.in_notebook():
# raise RuntimeError(
# "Can't load active notebook as you're currently not running in a "
# "notebook."
# )
# notebook_path = os.path.join(source_utils.get_source_root(), "test.ipynb")

def load_active_notebook() -> Dict[str, Any]:
"""Load the active notebook.
# if not os.path.exists(notebook_path):
# raise RuntimeError(f"Notebook at path {notebook_path} does not exist.")

Raises:
RuntimeError: If the active notebook can't be loaded.
# with open(notebook_path) as f:
# notebook_json = json.loads(f.read())

Returns:
Dictionary of the notebook.
"""
if not Environment.in_notebook():
raise RuntimeError(
"Can't load active notebook as you're currently not running in a "
"notebook."
)
notebook_path = os.path.join(source_utils.get_source_root(), "test.ipynb")
# cell_id = get_active_notebook_cell_id()

if not os.path.exists(notebook_path):
raise RuntimeError(f"Notebook at path {notebook_path} does not exist.")

with open(notebook_path) as f:
notebook_json = json.loads(f.read())
# for cell in notebook_json["cells"]:
# if cell["id"] == cell_id:
# return notebook_json

cell_id = get_active_notebook_cell_id()
# raise RuntimeError(
# f"Notebook at path {notebook_path} is not the active notebook."
# )

for cell in notebook_json["cells"]:
if cell["id"] == cell_id:
return notebook_json

raise RuntimeError(
f"Notebook at path {notebook_path} is not the active notebook."
)


def extract_notebook_cell_code(cell_id: str) -> str:
def extract_notebook_cell_code(notebook_path: str, cell_id: str) -> str:
"""Extract code from a notebook cell.
Args:
Expand All @@ -157,7 +108,7 @@ def extract_notebook_cell_code(cell_id: str) -> str:
Returns:
The cell content.
"""
notebook_json = load_active_notebook()
notebook_json = load_notebook(notebook_path)

for cell in notebook_json["cells"]:
if cell["id"] == cell_id:
Expand Down Expand Up @@ -189,6 +140,16 @@ def is_defined_in_notebook_cell(obj: Any) -> bool:
return module_name == "__main__"


def get_active_notebook_cell_id() -> str:
"""Get the ID of the currently active notebook cell.
Returns:
The ID of the currently active notebook cell.
"""
cell_id = get_ipython().get_parent()["metadata"]["cellId"]
return cell_id


def save_notebook_cell_id(obj: Any) -> None:
"""Save the notebook cell ID for an object.
Expand Down
36 changes: 27 additions & 9 deletions src/zenml/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from zenml.constants import ENV_ZENML_CUSTOM_SOURCE_ROOT
from zenml.environment import Environment
from zenml.logger import get_logger
from zenml.utils import notebook_utils

logger = get_logger(__name__)

Expand Down Expand Up @@ -117,7 +118,7 @@ def load(source: Union[Source, str]) -> Any:
pass
else:
notebook_source = NotebookSource.model_validate(dict(source))
return _try_to_load_notebook_replacement_source(notebook_source)
return _try_to_load_notebook_source(notebook_source)
elif source.type in {SourceType.USER, SourceType.UNKNOWN}:
# Unknown source might also refer to a user file, include source
# root in python path just to be sure
Expand Down Expand Up @@ -233,12 +234,11 @@ def resolve(
# Fallback to an unknown source if we can't find the package
source_type = SourceType.UNKNOWN
elif source_type == SourceType.NOTEBOOK:
from zenml.utils import notebook_utils

return NotebookSource(
module=module_name,
attribute=attribute_name,
cell_id=notebook_utils.load_notebook_cell_id(obj),
notebook_path=notebook_utils.get_active_notebook_path(),
type=source_type,
)

Expand Down Expand Up @@ -554,8 +554,8 @@ def _load_module(
return importlib.import_module(module_name)


def _try_to_load_notebook_replacement_source(source: NotebookSource) -> Any:
"""Helper function to load a notebook source from its replacement module.
def _try_to_load_notebook_source(source: NotebookSource) -> Any:
"""Helper function to load a notebook source outside of a notebook.
Args:
source: The source to load.
Expand All @@ -566,18 +566,36 @@ def _try_to_load_notebook_replacement_source(source: NotebookSource) -> Any:
Returns:
The loaded object.
"""
if not source.replacement_module:
if not source.notebook_path or not source.cell_id:
raise RuntimeError(
f"Failed to load {source.import_path}. This object was defined in "
"a notebook and you're trying to load it outside of a notebook. "
"This is currently only supported for pipeline steps."
)

module_name = (
f"zenml_extracted_notebook_code_{source.cell_id.replace('-', '_')}"
)
# TODO: this would probably be better if we do it in a temp dir to not
# write stuff to the user dir in case they run locally
filepath = os.path.join(get_source_root(), f"{module_name}.py")

if not os.path.exists(filepath):
logger.info(
"Extracting notebook cell content to load `%s`.",
source.import_path,
)
notebook_path = os.path.join(get_source_root(), source.notebook_path)
cell_content = notebook_utils.extract_notebook_cell_code(
notebook_path=notebook_path, cell_id=source.cell_id
)

with open(filepath, "w") as f:
f.write(cell_content)

import_root = get_source_root()
try:
module = _load_module(
module_name=source.replacement_module, import_root=import_root
)
module = _load_module(module_name=module_name, import_root=import_root)
except ImportError:
raise RuntimeError(
f"Unable to load {source.import_path}. This object was defined in "
Expand Down

0 comments on commit 57097ea

Please sign in to comment.