Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove checking for specific pytorch warnings that don't seem to appe…
Browse files Browse the repository at this point in the history
…ar anymore during DAgger tests.
ernestum committed Dec 18, 2023
1 parent a55ff9e commit b74ad0f
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions tests/scripts/test_scripts.py
Original file line number Diff line number Diff line change
@@ -266,22 +266,14 @@ def test_train_preference_comparisons_reward_named_config(tmpdir, named_configs)


def test_train_dagger_main(tmpdir):
with pytest.warns(None) as record:
run = train_imitation.train_imitation_ex.run(
command_name="dagger",
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
config_updates=dict(
logging=dict(log_root=tmpdir),
demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
),
)
for warning in record:
# PyTorch wants writeable arrays.
# See https://github.com/HumanCompatibleAI/imitation/issues/219
assert not (
warning.category == UserWarning
and "NumPy array is not writeable" in warning.message.args[0]
)
run = train_imitation.train_imitation_ex.run(
command_name="dagger",
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
config_updates=dict(
logging=dict(log_root=tmpdir),
demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
),
)
assert run.status == "COMPLETED"
assert isinstance(run.result, dict)

0 comments on commit b74ad0f

Please sign in to comment.