From f32a5bf6ce6bfeacdf462c03114949196bc9cce4 Mon Sep 17 00:00:00 2001 From: ksbrar Date: Fri, 16 Dec 2022 12:11:16 -0500 Subject: [PATCH 1/3] fix --- ludwig/schema/split.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ludwig/schema/split.py b/ludwig/schema/split.py index 2d5804b813c..9629e8c14ad 100644 --- a/ludwig/schema/split.py +++ b/ludwig/schema/split.py @@ -149,7 +149,6 @@ def get_split_conds(): for splitter in split_config_registry.data: splitter_cls = split_config_registry.data[splitter] other_props = schema_utils.unload_jsonschema_from_marshmallow_class(splitter_cls)["properties"] - schema_utils.remove_duplicate_fields(other_props) splitter_cond = schema_utils.create_cond( {"type": splitter}, other_props, From b137fa0486a776d970171b46991a10855ef649d9 Mon Sep 17 00:00:00 2001 From: ksbrar Date: Fri, 16 Dec 2022 12:25:13 -0500 Subject: [PATCH 2/3] add optional to remove_duplicate_fields --- ludwig/schema/split.py | 1 + ludwig/schema/utils.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ludwig/schema/split.py b/ludwig/schema/split.py index 9629e8c14ad..e7d1eb888ff 100644 --- a/ludwig/schema/split.py +++ b/ludwig/schema/split.py @@ -149,6 +149,7 @@ def get_split_conds(): for splitter in split_config_registry.data: splitter_cls = split_config_registry.data[splitter] other_props = schema_utils.unload_jsonschema_from_marshmallow_class(splitter_cls)["properties"] + schema_utils.remove_duplicate_fields(other_props, TYPE) splitter_cond = schema_utils.create_cond( {"type": splitter}, other_props, diff --git a/ludwig/schema/utils.py b/ludwig/schema/utils.py index 29b44d71607..14b20505227 100644 --- a/ludwig/schema/utils.py +++ b/ludwig/schema/utils.py @@ -4,7 +4,7 @@ from typing import Any from typing import Dict as TDict from typing import List as TList -from typing import Tuple, Type, Union +from typing import Optional, Tuple, Type, Union import yaml from marshmallow import EXCLUDE, fields, schema, validate, ValidationError @@ -103,7 +103,7 @@ def create_cond(if_pred: TDict, then_pred: TDict): @DeveloperAPI -def remove_duplicate_fields(properties: dict) -> None: +def remove_duplicate_fields(properties: dict, fields: Optional[TList[str]] = None) -> None: """Util function for removing duplicated schema elements. For example, input feature json schema mapping has a type param defined directly on the json schema, but also has a parameter defined on the schema class. We need both - @@ -114,7 +114,8 @@ def remove_duplicate_fields(properties: dict) -> None: Args: properties: Dictionary of properties generated from a Ludwig schema class """ - for key in [NAME, TYPE, COLUMN, PROC_COLUMN, ACTIVE]: # TODO: Remove col/proc_col once train metadata decoupled + duplicate_fields = [NAME, TYPE, COLUMN, PROC_COLUMN, ACTIVE] if fields is None else fields + for key in duplicate_fields: # TODO: Remove col/proc_col once train metadata decoupled if key in properties: del properties[key] From e063a4b36ca2deeca257ea77dedb2484e71921c8 Mon Sep 17 00:00:00 2001 From: ksbrar Date: Fri, 16 Dec 2022 13:05:29 -0500 Subject: [PATCH 3/3] fix test --- tests/ludwig/schema/test_validate_config_misc.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/ludwig/schema/test_validate_config_misc.py b/tests/ludwig/schema/test_validate_config_misc.py index 78fda5e8948..88b5c581eaa 100644 --- a/tests/ludwig/schema/test_validate_config_misc.py +++ b/tests/ludwig/schema/test_validate_config_misc.py @@ -364,9 +364,6 @@ def test_schema_no_duplicates(): assert field not in schema["properties"]["output_features"]["items"]["allOf"][0]["then"]["properties"] assert field not in schema["properties"]["combiner"]["allOf"][0]["then"]["properties"] assert field not in schema["properties"]["trainer"]["properties"]["optimizer"]["allOf"][0]["then"]["properties"] - assert ( - field not in schema["properties"]["preprocessing"]["properties"]["split"]["allOf"][0]["then"]["properties"] - ) assert ( field not in schema["properties"]["input_features"]["items"]["allOf"][0]["then"]["properties"]["encoder"][