Skip to content

Commit

Permalink
add testcase for AutoPhaseDict
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafaelhoushi committed Feb 26, 2022
1 parent 7ff28d4 commit bca30ef
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions examples/loop_optimizations_service/env_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_observation_spaces(env: CompilerEnv):
"ir",
"Inst2vec",
"Autophase",
"AutophaseDict",
"Programl",
"runtime",
"size",
Expand All @@ -102,8 +103,13 @@ def test_observation_spaces(env: CompilerEnv):
size_range=(0, np.iinfo(int).max),
dtype=int,
)
assert env.observation.spaces["Autophase"].space == Dict(
assert env.observation.spaces["Autophase"].space == Sequence(
name="Autophase",
size_range=(len(AUTOPHASE_FEATURE_NAMES), len(AUTOPHASE_FEATURE_NAMES)),
dtype=int,
)
assert env.observation.spaces["AutophaseDict"].space == Dict(
name="AutophaseDict",
spaces={
name: Scalar(name="", min=0, max=np.iinfo(np.int64).max, dtype=np.int64)
for name in AUTOPHASE_FEATURE_NAMES
Expand Down Expand Up @@ -182,7 +188,7 @@ def test_Step_out_of_range(env: CompilerEnv):


def test_default_ir_observation(env: CompilerEnv):
"""Test default observation space."""
"""Test default IR observation space."""
env.observation_space = "ir"
observation = env.reset()
assert len(observation) > 0
Expand All @@ -194,7 +200,7 @@ def test_default_ir_observation(env: CompilerEnv):


def test_default_inst2vec_observation(env: CompilerEnv):
"""Test default observation space."""
"""Test default inst2vec observation space."""
env.observation_space = "Inst2vec"
observation = env.reset()
assert isinstance(observation, np.ndarray)
Expand All @@ -204,7 +210,7 @@ def test_default_inst2vec_observation(env: CompilerEnv):


def test_default_autophase_observation(env: CompilerEnv):
"""Test default observation space."""
"""Test default autophase observation space."""
env.observation_space = "Autophase"
observation = env.reset()
assert isinstance(observation, np.ndarray)
Expand All @@ -213,6 +219,16 @@ def test_default_autophase_observation(env: CompilerEnv):
assert all(obs >= 0 for obs in observation.tolist())


def test_default_autophase_dict_observation(env: CompilerEnv):
"""Test default autophase dict observation space."""
env.observation_space = "AutophaseDict"
observation = env.reset()
assert isinstance(observation, dict)
assert observation.keys() == AUTOPHASE_FEATURE_NAMES
assert len(observation.values()) == len(AUTOPHASE_FEATURE_NAMES)
assert all(obs >= 0 for obs in observation.values())


def test_default_programl_observation(env: CompilerEnv):
"""Test default observation space."""
env.observation_space = "Programl"
Expand Down

0 comments on commit bca30ef

Please sign in to comment.