diff --git a/model2vec/model.py b/model2vec/model.py index ae89274..4cadf82 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -72,6 +72,11 @@ def __init__( else: self.normalize = self.config.get("normalize", False) + @property + def dim(self) -> int: + """Get the dimension of the model.""" + return self.embedding.weight.shape[1] + @property def device(self) -> torch.device: """Get the device of the model.""" diff --git a/tests/test_model.py b/tests/test_model.py index e371a92..f6d8810 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -160,3 +160,10 @@ def test_set_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> N assert model.config == {"normalize": False} model.normalize = True assert model.config == {"normalize": True} + + +def test_dim(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None: + """Tests the dimensionality of the model.""" + model = StaticModel(mock_vectors, mock_tokenizer, mock_config) + assert model.dim == 2 + assert model.dim == model.embedding.weight.shape[1]