diff --git a/src/super_gradients/common/environment/omegaconf_utils.py b/src/super_gradients/common/environment/omegaconf_utils.py index 6eb6f93d37..f92873eb2b 100644 --- a/src/super_gradients/common/environment/omegaconf_utils.py +++ b/src/super_gradients/common/environment/omegaconf_utils.py @@ -1,5 +1,7 @@ import importlib +import operator import sys +from functools import reduce from typing import Any from omegaconf import OmegaConf, DictConfig @@ -84,6 +86,8 @@ def register_hydra_resolvers(): OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True) OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True) OmegaConf.register_new_resolver("add", lambda *args: sum(args), replace=True) + OmegaConf.register_new_resolver("div", lambda x, y: x / y, replace=True) + OmegaConf.register_new_resolver("mul", lambda *args: reduce(operator.mul, args[1:], args[0]), replace=True) OmegaConf.register_new_resolver("cond", lambda boolean, x, y: x if boolean else y, replace=True) OmegaConf.register_new_resolver("getitem", lambda container, key: container[key], replace=True) # get item from a container (list, dict...) OmegaConf.register_new_resolver("first", lambda lst: lst[0], replace=True) # get the first item from a list diff --git a/tests/unit_tests/hydra_resolvers_test.py b/tests/unit_tests/hydra_resolvers_test.py index 6a377f1a0d..923471fb90 100644 --- a/tests/unit_tests/hydra_resolvers_test.py +++ b/tests/unit_tests/hydra_resolvers_test.py @@ -11,8 +11,17 @@ def setUp(self) -> None: def test_add(self): conf = OmegaConf.create({"a": 1, "b": 2, "c": 3, "a_plus_b": "${add: ${a},${b}}", "a_plus_b_plus_c": "${add: ${a}, ${b}, ${c}}"}) - assert conf["a_plus_b"] == 3 - assert conf["a_plus_b_plus_c"] == 6 + self.assertEqual(conf["a_plus_b"], 3) + self.assertEqual(conf["a_plus_b_plus_c"], 6) + + def test_div(self): + conf = OmegaConf.create({"a": 1, "b": 2, "a_div_b": "${div: ${a},${b}}"}) + self.assertAlmostEqual(conf["a_div_b"], 0.5) + + def test_mul(self): + conf = OmegaConf.create({"a": 1, "b": 2, "c": 4, "a_mul_b": "${mul: ${a},${b}}", "a_mul_b_mul_c": "${mul: ${a}, ${b}, ${c}}"}) + self.assertEqual(conf["a_mul_b"], 2) + self.assertEqual(conf["a_mul_b_mul_c"], 8) def test_list(self): conf = OmegaConf.create(