diff --git a/tests/test_utils.py b/tests/test_utils.py index 0b674ea6a85c1..8203b5d2f960d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,7 +7,8 @@ import pytest -from vllm.utils import deprecate_kwargs, get_open_port, merge_async_iterators +from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs, + get_open_port, merge_async_iterators) from .utils import error_on_warning @@ -130,3 +131,61 @@ def test_get_open_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: s3.bind(("localhost", get_open_port())) os.environ.pop("VLLM_PORT") + + +# Tests for FlexibleArgumentParser +@pytest.fixture +def parser(): + parser = FlexibleArgumentParser() + parser.add_argument('--image-input-type', + choices=['pixel_values', 'image_features']) + parser.add_argument('--model-name') + parser.add_argument('--batch-size', type=int) + parser.add_argument('--enable-feature', action='store_true') + return parser + + +def test_underscore_to_dash(parser): + args = parser.parse_args(['--image_input_type', 'pixel_values']) + assert args.image_input_type == 'pixel_values' + + +def test_mixed_usage(parser): + args = parser.parse_args([ + '--image_input_type', 'image_features', '--model-name', + 'facebook/opt-125m' + ]) + assert args.image_input_type == 'image_features' + assert args.model_name == 'facebook/opt-125m' + + +def test_with_equals_sign(parser): + args = parser.parse_args( + ['--image_input_type=pixel_values', '--model-name=facebook/opt-125m']) + assert args.image_input_type == 'pixel_values' + assert args.model_name == 'facebook/opt-125m' + + +def test_with_int_value(parser): + args = parser.parse_args(['--batch_size', '32']) + assert args.batch_size == 32 + args = parser.parse_args(['--batch-size', '32']) + assert args.batch_size == 32 + + +def test_with_bool_flag(parser): + args = parser.parse_args(['--enable_feature']) + assert args.enable_feature is True + args = parser.parse_args(['--enable-feature']) + assert args.enable_feature is True + + +def test_invalid_choice(parser): + with pytest.raises(SystemExit): + parser.parse_args(['--image_input_type', 'invalid_choice']) + + +def test_missing_required_argument(parser): + parser.add_argument('--required-arg', required=True) + with pytest.raises(SystemExit): + parser.parse_args([])