From 89ae083b932d12924be18236e414f64242a46bc4 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Thu, 29 Aug 2024 09:01:23 -0700 Subject: [PATCH] Rely on inheritted _make_metric as much as possible (#2718) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2718 Simplify code. Also references to super should help deobscure where this gets called/ Reviewed By: sdaulton Differential Revision: D61852855 fbshipit-source-id: 6dcef4e86e2d86ec9713ef56e9e29b7f72062aef --- ax/service/utils/instantiation.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 640b5edb953..ba5965a64df 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -8,6 +8,7 @@ import enum from collections.abc import Sequence +from copy import deepcopy from dataclasses import dataclass from logging import Logger @@ -139,11 +140,12 @@ def _get_deserialized_metric_kwargs( ) -> dict[str, Any]: """Get metric kwargs from metric_definitions if available and deserialize if so. Deserialization is necessary because they were serialized on creation""" - metric_kwargs = (metric_definitions or {}).get(name, {}) + # deepcopy is used because of subsequent modifications to the dict + metric_kwargs = deepcopy((metric_definitions or {}).get(name, {})) metric_class = metric_kwargs.pop("metric_class", metric_class) - metric_kwargs["name"] = name + # this is necessary before deserialization because name will be required + metric_kwargs["name"] = metric_kwargs.get("name", name) metric_kwargs = metric_class.deserialize_init_args(metric_kwargs) - metric_kwargs.pop("name") return metric_kwargs @classmethod @@ -160,15 +162,15 @@ def _make_metric( "Metric names cannot contain spaces when used with AxClient. Got " f"{name!r}." ) - - return metric_class( + kwargs = cls._get_deserialized_metric_kwargs( name=name, - lower_is_better=lower_is_better, - **cls._get_deserialized_metric_kwargs( - name=name, - metric_definitions=metric_definitions, - metric_class=metric_class, - ), + metric_definitions=metric_definitions, + metric_class=metric_class, + ) + # avoid conflict is lower_is_better is specified in kwargs + kwargs["lower_is_better"] = kwargs.get("lower_is_better", lower_is_better) + return metric_class( + **kwargs, ) @staticmethod