From 763540a4cb393e67a39971d667d803786e8237c3 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sun, 24 Mar 2024 13:52:23 +0000 Subject: [PATCH] Add test --- test/test_chronos.py | 48 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/test/test_chronos.py b/test/test_chronos.py index 85b8669..ea4f562 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -119,13 +119,13 @@ def test_tokenizer_random_data(use_eos_token: bool): assert samples.shape == (2, 10, 4) -def validate_samples(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None: +def validate_tensor(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None: assert isinstance(samples, torch.Tensor) assert samples.shape == shape @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline(torch_dtype: str): +def test_pipeline_predict(torch_dtype: str): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", device_map="cpu", @@ -136,7 +136,7 @@ def test_pipeline(torch_dtype: str): # input: tensor of shape (batch_size, context_length) samples = pipeline.predict(context, num_samples=12, prediction_length=3) - validate_samples(samples, (4, 12, 3)) + validate_tensor(samples, (4, 12, 3)) with pytest.raises(ValueError): samples = pipeline.predict(context, num_samples=7, prediction_length=65) @@ -144,12 +144,12 @@ def test_pipeline(torch_dtype: str): samples = pipeline.predict( context, num_samples=7, prediction_length=65, limit_prediction_length=False ) - validate_samples(samples, (4, 7, 65)) + validate_tensor(samples, (4, 7, 65)) # input: batch_size-long list of tensors of shape (context_length,) samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) - validate_samples(samples, (4, 12, 3)) + validate_tensor(samples, (4, 12, 3)) with pytest.raises(ValueError): samples = pipeline.predict(list(context), num_samples=7, prediction_length=65) @@ -160,12 +160,12 @@ def test_pipeline(torch_dtype: str): prediction_length=65, limit_prediction_length=False, ) - validate_samples(samples, (4, 7, 65)) + validate_tensor(samples, (4, 7, 65)) # input: tensor of shape (context_length,) samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) - validate_samples(samples, (1, 12, 3)) + validate_tensor(samples, (1, 12, 3)) with pytest.raises(ValueError): samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65) @@ -176,4 +176,36 @@ def test_pipeline(torch_dtype: str): prediction_length=65, limit_prediction_length=False, ) - validate_samples(samples, (1, 7, 65)) + validate_tensor(samples, (1, 7, 65)) + + +@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) +def test_pipeline_embed(torch_dtype: str): + pipeline = ChronosPipeline.from_pretrained( + Path(__file__).parent / "dummy-chronos-model", + device_map="cpu", + torch_dtype=torch_dtype, + ) + model_context_length = pipeline.model.config.context_length + expected_embed_length = model_context_length + ( + 1 if pipeline.model.config.use_eos_token else 0 + ) + d_model = pipeline.model.model.config.d_model + context = 10 * torch.rand(size=(4, 16)) + 10 + + # input: tensor of shape (batch_size, context_length) + + embedding, scale = pipeline.embed(context) + validate_tensor(embedding, (4, expected_embed_length, d_model)) + validate_tensor(scale, (4,)) + + # input: batch_size-long list of tensors of shape (context_length,) + + embedding, scale = pipeline.embed(list(context)) + validate_tensor(embedding, (4, expected_embed_length, d_model)) + validate_tensor(scale, (4,)) + + # input: tensor of shape (context_length,) + embedding, scale = pipeline.embed(context[0, ...]) + validate_tensor(embedding, (1, expected_embed_length, d_model)) + validate_tensor(scale, (1,))