Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline run not tracked in cached artifact version #2713

Merged
30 changes: 30 additions & 0 deletions src/zenml/models/v2/core/artifact_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ class ArtifactVersionResponseBody(WorkspaceScopedResponseBody):
title="The ID of the pipeline run that generated this artifact version."
)

pipeline_run_id: Optional[UUID] = Field(
title="The ID of the pipeline run in current context. "
"For non-cached artifact versions should be equal to "
"`producer_pipeline_run_id`, but for cached artifact "
"versions it can differ from `producer_pipeline_run_id`"
)

_convert_source = convert_source_validator("materializer", "data_type")


Expand All @@ -167,6 +174,11 @@ class ArtifactVersionResponseMetadata(WorkspaceScopedResponseMetadata):
class ArtifactVersionResponseResources(WorkspaceScopedResponseResources):
"""Class for all resource models associated with the artifact version entity."""

pipeline_run_ids: Optional[List[UUID]] = Field(
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
title="List of all pipeline run IDs, which attached "
"this artifact version to its' outputs."
)


class ArtifactVersionResponse(
WorkspaceScopedResponse[
Expand Down Expand Up @@ -242,6 +254,15 @@ def producer_pipeline_run_id(self) -> Optional[UUID]:
"""
return self.get_body().producer_pipeline_run_id

@property
def pipeline_run_id(self) -> Optional[UUID]:
"""The `pipeline_run_id` property.

Returns:
the value of the property.
"""
return self.get_body().pipeline_run_id

@property
def artifact_store_id(self) -> Optional[UUID]:
"""The `artifact_store_id` property.
Expand Down Expand Up @@ -298,6 +319,15 @@ def data_type(self) -> Source:
"""
return self.get_body().data_type

@property
def pipeline_run_ids(self) -> Optional[List[UUID]]:
"""The `all_pipeline_run_ids` property.

Returns:
the value of the property.
"""
return self.get_resources().pipeline_run_ids

# Helper methods
@property
def name(self) -> str:
Expand Down
27 changes: 25 additions & 2 deletions src/zenml/zen_stores/schemas/artifact_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
ArtifactVersionUpdate,
)
from zenml.models.v2.core.artifact import ArtifactRequest
from zenml.models.v2.core.artifact_version import (
ArtifactVersionResponseResources,
)
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
Expand Down Expand Up @@ -287,13 +290,15 @@ def to_model(
self,
include_metadata: bool = False,
include_resources: bool = False,
pipeline_run_id_in_context: Optional[UUID] = None,
**kwargs: Any,
) -> ArtifactVersionResponse:
"""Convert an `ArtifactVersionSchema` to an `ArtifactVersionResponse`.

Args:
include_metadata: Whether the metadata will be filled.
include_resources: Whether the resources will be filled.
pipeline_run_id_in_context: The pipeline run id in context (e.g. StepRun context).
**kwargs: Keyword arguments to allow schema specific logic


Expand All @@ -315,11 +320,17 @@ def to_model(

producer_step_run_id, producer_pipeline_run_id = None, None
if self.output_of_step_runs:
step_run = self.output_of_step_runs[0].step_run
if step_run.status == ExecutionStatus.COMPLETED:
original_step_runs = [
sr
for sr in self.output_of_step_runs
if sr.step_run.status == ExecutionStatus.COMPLETED
]
if len(original_step_runs) == 1:
step_run = original_step_runs[0].step_run
producer_step_run_id = step_run.id
producer_pipeline_run_id = step_run.pipeline_run_id
else:
step_run = self.output_of_step_runs[0].step_run
producer_step_run_id = step_run.original_step_run_id

# Create the body of the model
Expand All @@ -335,6 +346,8 @@ def to_model(
updated=self.updated,
tags=[t.tag.to_model() for t in self.tags],
producer_pipeline_run_id=producer_pipeline_run_id,
pipeline_run_id=pipeline_run_id_in_context
or producer_pipeline_run_id,
)

# Create the metadata of the model
Expand All @@ -348,10 +361,20 @@ def to_model(
run_metadata={m.key: m.to_model() for m in self.run_metadata},
)

resources = None
if include_resources:
resources = ArtifactVersionResponseResources(
pipeline_run_ids=[
output_.step_run.pipeline_run_id
for output_ in self.output_of_step_runs
]
)

return ArtifactVersionResponse(
id=self.id,
body=body,
metadata=metadata,
resources=resources,
)

def update(
Expand Down
4 changes: 3 additions & 1 deletion src/zenml/zen_stores/schemas/step_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def to_model(
}

output_artifacts = {
artifact.name: artifact.artifact_version.to_model()
artifact.name: artifact.artifact_version.to_model(
pipeline_run_id_in_context=self.pipeline_run_id
)
for artifact in self.output_artifacts
}

Expand Down
4 changes: 3 additions & 1 deletion src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,7 +2462,9 @@ def get_artifact_version(
f"{artifact_version_id}: No artifact version with this ID "
f"found."
)
return artifact_version.to_model(include_metadata=hydrate)
return artifact_version.to_model(
include_metadata=hydrate, include_resources=hydrate
)

def list_artifact_versions(
self,
Expand Down
191 changes: 191 additions & 0 deletions tests/integration/functional/artifacts/test_artifacts_linage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Callable
from uuid import UUID

import pytest
from typing_extensions import Annotated

from zenml import pipeline, step
from zenml.client import Client
from zenml.enums import ModelStages
from zenml.model.model import Model


@step(enable_cache=True)
def simple_producer_step() -> Annotated[int, "trackable_artifact"]:
return 42


@step(enable_cache=False)
def keep_pipeline_alive() -> None:
pass


@pipeline
def cacheable_pipeline_which_always_run():
simple_producer_step()
keep_pipeline_alive()


@pipeline
def cacheable_pipeline_which_can_be_fully_cached():
simple_producer_step()


@pipeline
def cacheable_pipeline_where_second_step_is_cached():
simple_producer_step(id="simple_producer_step_1")
simple_producer_step(id="simple_producer_step_2")


def _validate_artifacts_state(
clean_client: Client,
pr_id: UUID,
producer_pr_id: UUID,
expected_version: int,
step_name: str = "simple_producer_step",
artifact_name: str = "trackable_artifact",
):
pr = clean_client.get_pipeline_run(pr_id)
outputs_1 = pr.steps[step_name].outputs
step = clean_client.get_run_step(pr.steps[step_name].id)
outputs_2 = step.outputs
for outputs in [outputs_1, outputs_2]:
assert len(outputs) == 1
assert int(outputs[artifact_name].version) == expected_version
# producer ID is always the original PR
assert (
outputs[artifact_name].producer_pipeline_run_id == producer_pr_id
)
# if derived from the pipeline run context - we can point to exact run
assert outputs[artifact_name].pipeline_run_id == pr_id

artifact = clean_client.get_artifact_version(artifact_name)
assert artifact.name == artifact_name
assert int(artifact.version) == expected_version
# producer ID is always the original PR
assert artifact.producer_pipeline_run_id == producer_pr_id
# cannot be derived from the context, if called just fro artifact interface
assert artifact.pipeline_run_id == producer_pr_id
# but should be listed in all runs
assert pr_id in artifact.pipeline_run_ids
assert producer_pr_id in artifact.pipeline_run_ids


# TODO: remove clean client, ones clean env for REST is available
@pytest.mark.parametrize(
"pipeline",
[
cacheable_pipeline_which_always_run,
cacheable_pipeline_which_can_be_fully_cached,
],
)
def test_that_cached_artifact_versions_are_created_properly(
pipeline: Callable, clean_client: Client
):
pr_orig = pipeline()

_validate_artifacts_state(
clean_client=clean_client,
pr_id=pr_orig.id,
producer_pr_id=pr_orig.id,
expected_version=1,
)

pr = pipeline()

pr = clean_client.get_pipeline_run(pr.id)
_validate_artifacts_state(
clean_client=clean_client,
pr_id=pr.id,
producer_pr_id=pr_orig.id,
expected_version=1, # cached artifact doesn't produce new version
)


# TODO: remove clean client, ones clean env for REST is available
def test_that_cached_artifact_versions_are_created_properly_for_second_step(
clean_client: Client,
):
pr_orig = cacheable_pipeline_where_second_step_is_cached()

_validate_artifacts_state(
clean_client=clean_client,
pr_id=pr_orig.id,
producer_pr_id=pr_orig.id,
step_name="simple_producer_step_1",
expected_version=1,
)
_validate_artifacts_state(
clean_client=clean_client,
pr_id=pr_orig.id,
producer_pr_id=pr_orig.id,
step_name="simple_producer_step_2",
expected_version=1,
)

pr = cacheable_pipeline_where_second_step_is_cached()

pr = clean_client.get_pipeline_run(pr.id)
_validate_artifacts_state(
clean_client=clean_client,
pr_id=pr.id,
producer_pr_id=pr_orig.id,
step_name="simple_producer_step_1",
expected_version=1, # cached artifact doesn't produce new version
)
_validate_artifacts_state(
clean_client=clean_client,
pr_id=pr.id,
producer_pr_id=pr_orig.id,
step_name="simple_producer_step_2",
expected_version=1, # cached artifact doesn't produce new version
)


def test_that_cached_artifact_versions_are_created_properly_for_model_version(
clean_client: Client,
):
pr_orig = cacheable_pipeline_which_always_run.with_options(
model=Model(name="foo")
)()

mv = clean_client.get_model_version("foo", ModelStages.LATEST)
assert (
mv.data_artifacts["trackable_artifact"]["1"].producer_pipeline_run_id
== pr_orig.id
)
assert (
pr_orig.id
in mv.data_artifacts["trackable_artifact"]["1"].pipeline_run_ids
)

pr = cacheable_pipeline_which_always_run.with_options(
model=Model(name="foo")
)()

mv = clean_client.get_model_version("foo", ModelStages.LATEST)
assert (
mv.data_artifacts["trackable_artifact"]["1"].producer_pipeline_run_id
== pr_orig.id
)
assert (
pr_orig.id
in mv.data_artifacts["trackable_artifact"]["1"].pipeline_run_ids
)
assert (
pr.id in mv.data_artifacts["trackable_artifact"]["1"].pipeline_run_ids
)
Loading