diff --git a/examples/loop_optimizations_service/env_tests.py b/examples/loop_optimizations_service/env_tests.py index b217713f19..2329caa4ac 100644 --- a/examples/loop_optimizations_service/env_tests.py +++ b/examples/loop_optimizations_service/env_tests.py @@ -88,6 +88,7 @@ def test_observation_spaces(env: CompilerEnv): "ir", "Inst2vec", "Autophase", + "AutophaseDict", "Programl", "runtime", "size", @@ -102,8 +103,13 @@ def test_observation_spaces(env: CompilerEnv): size_range=(0, np.iinfo(int).max), dtype=int, ) - assert env.observation.spaces["Autophase"].space == Dict( + assert env.observation.spaces["Autophase"].space == Sequence( name="Autophase", + size_range=(len(AUTOPHASE_FEATURE_NAMES), len(AUTOPHASE_FEATURE_NAMES)), + dtype=int, + ) + assert env.observation.spaces["AutophaseDict"].space == Dict( + name="AutophaseDict", spaces={ name: Scalar(name="", min=0, max=np.iinfo(np.int64).max, dtype=np.int64) for name in AUTOPHASE_FEATURE_NAMES @@ -182,7 +188,7 @@ def test_Step_out_of_range(env: CompilerEnv): def test_default_ir_observation(env: CompilerEnv): - """Test default observation space.""" + """Test default IR observation space.""" env.observation_space = "ir" observation = env.reset() assert len(observation) > 0 @@ -194,7 +200,7 @@ def test_default_ir_observation(env: CompilerEnv): def test_default_inst2vec_observation(env: CompilerEnv): - """Test default observation space.""" + """Test default inst2vec observation space.""" env.observation_space = "Inst2vec" observation = env.reset() assert isinstance(observation, np.ndarray) @@ -204,7 +210,7 @@ def test_default_inst2vec_observation(env: CompilerEnv): def test_default_autophase_observation(env: CompilerEnv): - """Test default observation space.""" + """Test default autophase observation space.""" env.observation_space = "Autophase" observation = env.reset() assert isinstance(observation, np.ndarray) @@ -213,6 +219,16 @@ def test_default_autophase_observation(env: CompilerEnv): assert all(obs >= 0 for obs in observation.tolist()) +def test_default_autophase_dict_observation(env: CompilerEnv): + """Test default autophase dict observation space.""" + env.observation_space = "AutophaseDict" + observation = env.reset() + assert isinstance(observation, dict) + assert observation.keys() == AUTOPHASE_FEATURE_NAMES + assert len(observation.values()) == len(AUTOPHASE_FEATURE_NAMES) + assert all(obs >= 0 for obs in observation.values()) + + def test_default_programl_observation(env: CompilerEnv): """Test default observation space.""" env.observation_space = "Programl"