Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Abdul Fatir Ansari committed Mar 24, 2024
1 parent 1f27cc3 commit 763540a
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions test/test_chronos.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ def test_tokenizer_random_data(use_eos_token: bool):
assert samples.shape == (2, 10, 4)


def validate_samples(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None:
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
    """Assert that ``samples`` is a ``torch.Tensor`` with the given shape.

    Args:
        samples: Object to validate.
        shape: Expected shape. Annotated as ``Tuple[int, ...]`` (any rank)
            because callers pass both 3-tuples for prediction samples and
            1-tuples such as ``(4,)`` for per-series scale vectors; the
            previous ``Tuple[int, int, int]`` annotation was too narrow.

    Raises:
        AssertionError: If ``samples`` is not a tensor or its shape differs.
    """
    assert isinstance(
        samples, torch.Tensor
    ), f"expected torch.Tensor, got {type(samples).__name__}"
    assert samples.shape == shape, (
        f"expected shape {shape}, got {tuple(samples.shape)}"
    )


@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline(torch_dtype: str):
def test_pipeline_predict(torch_dtype: str):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
device_map="cpu",
Expand All @@ -136,20 +136,20 @@ def test_pipeline(torch_dtype: str):
# input: tensor of shape (batch_size, context_length)

samples = pipeline.predict(context, num_samples=12, prediction_length=3)
validate_samples(samples, (4, 12, 3))
validate_tensor(samples, (4, 12, 3))

with pytest.raises(ValueError):
samples = pipeline.predict(context, num_samples=7, prediction_length=65)

samples = pipeline.predict(
context, num_samples=7, prediction_length=65, limit_prediction_length=False
)
validate_samples(samples, (4, 7, 65))
validate_tensor(samples, (4, 7, 65))

# input: batch_size-long list of tensors of shape (context_length,)

samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
validate_samples(samples, (4, 12, 3))
validate_tensor(samples, (4, 12, 3))

with pytest.raises(ValueError):
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
Expand All @@ -160,12 +160,12 @@ def test_pipeline(torch_dtype: str):
prediction_length=65,
limit_prediction_length=False,
)
validate_samples(samples, (4, 7, 65))
validate_tensor(samples, (4, 7, 65))

# input: tensor of shape (context_length,)

samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
validate_samples(samples, (1, 12, 3))
validate_tensor(samples, (1, 12, 3))

with pytest.raises(ValueError):
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
Expand All @@ -176,4 +176,36 @@ def test_pipeline(torch_dtype: str):
prediction_length=65,
limit_prediction_length=False,
)
validate_samples(samples, (1, 7, 65))
validate_tensor(samples, (1, 7, 65))


@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_embed(torch_dtype: torch.dtype):
    """Check that ``ChronosPipeline.embed`` returns correctly shaped
    embeddings and scales for every supported input form: a batched 2-D
    tensor, a list of 1-D tensors, and a single 1-D tensor.

    ``torch_dtype`` is a ``torch.dtype`` (the parametrize values are
    ``torch.float32`` / ``torch.bfloat16``), not a string as previously
    annotated.
    """
    pipeline = ChronosPipeline.from_pretrained(
        Path(__file__).parent / "dummy-chronos-model",
        device_map="cpu",
        torch_dtype=torch_dtype,
    )
    # Embedding sequence length: model context length, plus one extra slot
    # when the tokenizer appends an EOS token.
    model_context_length = pipeline.model.config.context_length
    expected_embed_length = model_context_length + (
        1 if pipeline.model.config.use_eos_token else 0
    )
    d_model = pipeline.model.model.config.d_model
    # Random context values in [10, 20); the exact values are irrelevant,
    # only the shapes matter for this test.
    context = 10 * torch.rand(size=(4, 16)) + 10

    # input: tensor of shape (batch_size, context_length)
    embedding, scale = pipeline.embed(context)
    validate_tensor(embedding, (4, expected_embed_length, d_model))
    validate_tensor(scale, (4,))

    # input: batch_size-long list of tensors of shape (context_length,)
    embedding, scale = pipeline.embed(list(context))
    validate_tensor(embedding, (4, expected_embed_length, d_model))
    validate_tensor(scale, (4,))

    # input: tensor of shape (context_length,)
    embedding, scale = pipeline.embed(context[0, ...])
    validate_tensor(embedding, (1, expected_embed_length, d_model))
    validate_tensor(scale, (1,))

0 comments on commit 763540a

Please sign in to comment.