Skip to content

Commit

Permalink
Added more hydra resolvers (#1829)
Browse files Browse the repository at this point in the history
* Added more hydra resolvers. Need mul and div to easily scale lr and other stuff based on batch

* we <3 unittest
  • Loading branch information
NatanBagrov authored Feb 12, 2024
1 parent 56de963 commit ecad472
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/super_gradients/common/environment/omegaconf_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib
import operator
import sys
from functools import reduce
from typing import Any

from omegaconf import OmegaConf, DictConfig
Expand Down Expand Up @@ -84,6 +86,8 @@ def register_hydra_resolvers():
OmegaConf.register_new_resolver("hydra_output_dir", hydra_output_dir_resolver, replace=True)
OmegaConf.register_new_resolver("class", lambda *args: get_cls(*args), replace=True)
OmegaConf.register_new_resolver("add", lambda *args: sum(args), replace=True)
OmegaConf.register_new_resolver("div", lambda x, y: x / y, replace=True)
OmegaConf.register_new_resolver("mul", lambda *args: reduce(operator.mul, args[1:], args[0]), replace=True)
OmegaConf.register_new_resolver("cond", lambda boolean, x, y: x if boolean else y, replace=True)
OmegaConf.register_new_resolver("getitem", lambda container, key: container[key], replace=True) # get item from a container (list, dict...)
OmegaConf.register_new_resolver("first", lambda lst: lst[0], replace=True) # get the first item from a list
Expand Down
13 changes: 11 additions & 2 deletions tests/unit_tests/hydra_resolvers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,17 @@ def setUp(self) -> None:

def test_add(self):
conf = OmegaConf.create({"a": 1, "b": 2, "c": 3, "a_plus_b": "${add: ${a},${b}}", "a_plus_b_plus_c": "${add: ${a}, ${b}, ${c}}"})
assert conf["a_plus_b"] == 3
assert conf["a_plus_b_plus_c"] == 6
self.assertEqual(conf["a_plus_b"], 3)
self.assertEqual(conf["a_plus_b_plus_c"], 6)

def test_div(self):
conf = OmegaConf.create({"a": 1, "b": 2, "a_div_b": "${div: ${a},${b}}"})
self.assertAlmostEqual(conf["a_div_b"], 0.5)

def test_mul(self):
conf = OmegaConf.create({"a": 1, "b": 2, "c": 4, "a_mul_b": "${mul: ${a},${b}}", "a_mul_b_mul_c": "${mul: ${a}, ${b}, ${c}}"})
self.assertEqual(conf["a_mul_b"], 2)
self.assertEqual(conf["a_mul_b_mul_c"], 8)

def test_list(self):
conf = OmegaConf.create(
Expand Down

0 comments on commit ecad472

Please sign in to comment.