diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index c8ba344..000dc67 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -169,6 +169,7 @@ def __init__( def _input_transform( self, context: torch.Tensor, scale: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + context = context.to(dtype=torch.float32) attention_mask = ~torch.isnan(context) if scale is None: @@ -373,7 +374,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: size=(max_len - len(c),), fill_value=torch.nan, device=c.device ) padded.append(torch.concat((padding, c), dim=-1)) - return torch.stack(padded) + return torch.stack(padded).to(tensors[0]) @dataclass @@ -506,6 +507,9 @@ def predict( raise ValueError(msg) warnings.warn(msg) + input_dtype = context_tensor.dtype + input_device = context_tensor.device + predictions = [] remaining = prediction_length @@ -536,7 +540,7 @@ def predict( [context_tensor, prediction.median(dim=1).values], dim=-1 ) - return torch.cat(predictions, dim=-1) + return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device) @classmethod def from_pretrained(cls, *args, **kwargs): diff --git a/test/test_chronos.py b/test/test_chronos.py index a84c57d..e2b71dc 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -157,24 +157,26 @@ def test_tokenizer_random_data(use_eos_token: bool): assert samples.shape == (2, 10, 4) -def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None: - assert isinstance(samples, torch.Tensor) - assert samples.shape == shape +def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None: + assert isinstance(a, torch.Tensor) + assert a.shape == shape + assert a.dtype == dtype -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_predict(torch_dtype: str): +@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", device_map="cpu", - torch_dtype=torch_dtype, + torch_dtype=model_dtype, ) - context = 10 * torch.rand(size=(4, 16)) + 10 + context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 # input: tensor of shape (batch_size, context_length) samples = pipeline.predict(context, num_samples=12, prediction_length=3) - validate_tensor(samples, (4, 12, 3)) + validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) with pytest.raises(ValueError): samples = pipeline.predict(context, num_samples=7, prediction_length=65) @@ -182,12 +184,12 @@ def test_pipeline_predict(torch_dtype: str): samples = pipeline.predict( context, num_samples=7, prediction_length=65, limit_prediction_length=False ) - validate_tensor(samples, (4, 7, 65)) + validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) # input: batch_size-long list of tensors of shape (context_length,) samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) - validate_tensor(samples, (4, 12, 3)) + validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) with pytest.raises(ValueError): samples = pipeline.predict(list(context), num_samples=7, prediction_length=65) @@ -198,12 +200,12 @@ def test_pipeline_predict(torch_dtype: str): prediction_length=65, limit_prediction_length=False, ) - validate_tensor(samples, (4, 7, 65)) + validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) # input: tensor of shape (context_length,) samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) - validate_tensor(samples, (1, 12, 3)) + validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype) with pytest.raises(ValueError): samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65) @@ -214,36 +216,43 @@ def test_pipeline_predict(torch_dtype: str): prediction_length=65, limit_prediction_length=False, ) - validate_tensor(samples, (1, 7, 65)) + validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype) -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_embed(torch_dtype: str): +@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", device_map="cpu", - torch_dtype=torch_dtype, + torch_dtype=model_dtype, ) d_model = pipeline.model.model.config.d_model - context = 10 * torch.rand(size=(4, 16)) + 10 + context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) # input: tensor of shape (batch_size, context_length) embedding, scale = pipeline.embed(context) - validate_tensor(embedding, (4, expected_embed_length, d_model)) - validate_tensor(scale, (4,)) + validate_tensor( + embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(scale, shape=(4,), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) embedding, scale = pipeline.embed(list(context)) - validate_tensor(embedding, (4, expected_embed_length, d_model)) - validate_tensor(scale, (4,)) + validate_tensor( + embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(scale, shape=(4,), dtype=torch.float32) # input: tensor of shape (context_length,) embedding, scale = pipeline.embed(context[0, ...]) - validate_tensor(embedding, (1, expected_embed_length, d_model)) - validate_tensor(scale, (1,)) + validate_tensor( + embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(scale, shape=(1,), dtype=torch.float32) @pytest.mark.parametrize("n_tokens", [10, 1000, 10000])