From 9285ebdc08a48da92aa1c2b522c28b0db21ba790 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Wed, 24 Jul 2024 14:34:24 +0200 Subject: [PATCH] openlineage: make value of slots in attrs.define consistent across all OL usages Signed-off-by: Kacper Muda --- .../providers/openlineage/plugins/facets.py | 14 +++++----- airflow/providers/openlineage/utils/utils.py | 14 +++++++++- .../guides/developer.rst | 2 +- .../openlineage/plugins/test_utils.py | 28 +++++++++++++++++-- .../openlineage/utils/custom_facet_fixture.py | 2 +- 5 files changed, 47 insertions(+), 13 deletions(-) diff --git a/airflow/providers/openlineage/plugins/facets.py b/airflow/providers/openlineage/plugins/facets.py index 4c0b99d39cc48..57d6cb1a1507f 100644 --- a/airflow/providers/openlineage/plugins/facets.py +++ b/airflow/providers/openlineage/plugins/facets.py @@ -28,7 +28,7 @@ reason="To be removed in the next release. Make sure to use information from AirflowRunFacet instead.", category=AirflowProviderDeprecationWarning, ) -@define(slots=False) +@define class AirflowMappedTaskRunFacet(RunFacet): """Run facet containing information about mapped tasks.""" @@ -47,7 +47,7 @@ def from_task_instance(cls, task_instance): ) -@define(slots=True) +@define class AirflowJobFacet(JobFacet): """ Composite Airflow job facet. @@ -70,7 +70,7 @@ class AirflowJobFacet(JobFacet): tasks: dict -@define(slots=True) +@define class AirflowStateRunFacet(RunFacet): """ Airflow facet providing state information. @@ -89,7 +89,7 @@ class AirflowStateRunFacet(RunFacet): tasksState: dict[str, str] -@define(slots=False) +@define class AirflowRunFacet(RunFacet): """Composite Airflow run facet.""" @@ -100,7 +100,7 @@ class AirflowRunFacet(RunFacet): taskUuid: str -@define(slots=True) +@define class AirflowDagRunFacet(RunFacet): """Composite Airflow DAG run facet.""" @@ -108,7 +108,7 @@ class AirflowDagRunFacet(RunFacet): dagRun: dict -@define(slots=False) +@define class UnknownOperatorInstance(RedactMixin): """ Describes an unknown operator. @@ -127,7 +127,7 @@ class UnknownOperatorInstance(RedactMixin): reason="To be removed in the next release. Make sure to use information from AirflowRunFacet instead.", category=AirflowProviderDeprecationWarning, ) -@define(slots=False) +@define class UnknownOperatorAttributeRunFacet(RunFacet): """RunFacet that describes unknown operators in an Airflow DAG.""" diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index 2e995fb7a84eb..17eb522a6d1ba 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -219,7 +219,19 @@ def _include_fields(self): setattr(self, field, getattr(self.obj, field)) self._fields.append(field) else: - for field, val in self.obj.__dict__.items(): + if hasattr(self.obj, "__dict__"): + obj_fields = self.obj.__dict__ + elif attrs.has(self.obj.__class__): # e.g. attrs.define class with slots=True has no __dict__ + obj_fields = { + field.name: getattr(self.obj, field.name) for field in attrs.fields(self.obj.__class__) + } + else: + raise ValueError( + "Cannot iterate over fields: " + f"The object of type {type(self.obj).__name__} neither has a __dict__ attribute " + "nor is defined as an attrs class." + ) + for field, val in obj_fields.items(): if field not in self._fields and field not in self.excludes and field not in self.renames: setattr(self, field, val) self._fields.append(field) diff --git a/docs/apache-airflow-providers-openlineage/guides/developer.rst b/docs/apache-airflow-providers-openlineage/guides/developer.rst index 4e9ada44c2ec4..9b56de3977cc3 100644 --- a/docs/apache-airflow-providers-openlineage/guides/developer.rst +++ b/docs/apache-airflow-providers-openlineage/guides/developer.rst @@ -481,7 +481,7 @@ Writing a custom facet function from airflow.providers.common.compat.openlineage.facet import RunFacet - @attrs.define(slots=False) + @attrs.define class MyCustomRunFacet(RunFacet): """Define a custom facet.""" diff --git a/tests/providers/openlineage/plugins/test_utils.py b/tests/providers/openlineage/plugins/test_utils.py index 8ca245d1f3d01..962429e30eb9e 100644 --- a/tests/providers/openlineage/plugins/test_utils.py +++ b/tests/providers/openlineage/plugins/test_utils.py @@ -102,6 +102,28 @@ class TestInfo(InfoJsonEncodable): casts = {"iwanttobeint": lambda x: int(x.imastring)} renames = {"_faulty_name": "goody_name"} + @define + class Test: + exclude_1: str + imastring: str + _faulty_name: str + donotcare: str + + obj = Test("val", "123", "not_funny", "abc") + + assert json.loads(json.dumps(TestInfo(obj))) == { + "iwanttobeint": 123, + "goody_name": "not_funny", + "donotcare": "abc", + } + + +def test_info_json_encodable_without_slots(): + class TestInfo(InfoJsonEncodable): + excludes = ["exclude_1", "exclude_2", "imastring"] + casts = {"iwanttobeint": lambda x: int(x.imastring)} + renames = {"_faulty_name": "goody_name"} + @define(slots=False) class Test: exclude_1: str @@ -122,7 +144,7 @@ def test_info_json_encodable_list_does_not_flatten(): class TestInfo(InfoJsonEncodable): includes = ["alist"] - @define(slots=False) + @define class Test: alist: list[str] @@ -135,7 +157,7 @@ def test_info_json_encodable_list_does_include_nonexisting(): class TestInfo(InfoJsonEncodable): includes = ["exists", "doesnotexist"] - @define(slots=False) + @define class Test: exists: str @@ -191,7 +213,7 @@ def __init__(self): self.password = "passwd" self.transparent = "123" - @define(slots=False) + @define class NestedMixined(RedactMixin): _skip_redact = ["nested_field"] password: str diff --git a/tests/providers/openlineage/utils/custom_facet_fixture.py b/tests/providers/openlineage/utils/custom_facet_fixture.py index 6b9d0edcce732..040c8c774c31f 100644 --- a/tests/providers/openlineage/utils/custom_facet_fixture.py +++ b/tests/providers/openlineage/utils/custom_facet_fixture.py @@ -26,7 +26,7 @@ from airflow.models.taskinstance import TaskInstance, TaskInstanceState -@attrs.define(slots=False) +@attrs.define class MyCustomRunFacet(RunFacet): """Define a custom run facet."""