diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
index b18ce3a894..ee00553756 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
@@ -36,6 +36,29 @@ def test_min_visual_size():
         enc.forward(vis_input)
 
 
+# Each visual encoder type must reject observations smaller than its declared
+# minimum resolution. Parametrized over every encoder with an entry in
+# ModelUtils.MIN_RESOLUTION_FOR_ENCODER (was: SIMPLE listed twice, RESNET
+# never exercised — fixed to cover all four distinct encoder types).
+@pytest.mark.parametrize(
+    "encoder_type",
+    [
+        EncoderType.SIMPLE,
+        EncoderType.NATURE_CNN,
+        EncoderType.RESNET,
+        EncoderType.MATCH3,
+    ],
+)
+def test_invalid_visual_input_size(encoder_type):
+    with pytest.raises(UnityTrainerException):
+        obs_spec = create_observation_specs_with_shapes(
+            [
+                (
+                    ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1,
+                    ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type],
+                    1,
+                )
+            ]
+        )
+        ModelUtils.create_input_processors(obs_spec, 20, encoder_type, 20, False)
+
+
 @pytest.mark.parametrize("num_visual", [0, 1, 2])
 @pytest.mark.parametrize("num_vector", [0, 1, 2])
 @pytest.mark.parametrize("normalize", [True, False])
diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py
index 3d7ecd2836..b73bfa0bf9 100644
--- a/ml-agents/mlagents/trainers/torch/utils.py
+++ b/ml-agents/mlagents/trainers/torch/utils.py
@@ -159,6 +159,9 @@ def get_encoder_for_obs(
         # VISUAL
         if dim_prop in ModelUtils.VALID_VISUAL_PROP:
             visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
+            # Validate (height, width) against the encoder's minimum resolution
+            # before constructing it; raises UnityTrainerException when too small.
+            ModelUtils._check_resolution_for_encoder(
+                shape[0], shape[1], vis_encode_type
+            )
             return (visual_encoder_class(shape[0], shape[1], shape[2], h_size), h_size)
         # VECTOR
         if dim_prop in ModelUtils.VALID_VECTOR_PROP: