-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #81 from iomega/model_serialization
Model serialization
- Loading branch information
Showing
19 changed files
with
350 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ dependencies: | |
- numpy | ||
- pip | ||
- python >=3.7 | ||
- scipy | ||
- tqdm | ||
- pip: | ||
- -e ..[dev] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,5 @@ dependencies: | |
- numba >=0.51 | ||
- numpy | ||
- python >=3.7 | ||
- scipy | ||
- tqdm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = '0.6.0' | ||
__version__ = '0.7.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
Functions for exporting and importing a trained :class:`~gensim.models.Word2Vec` model to and from disk. | ||
########################################## | ||
The functions export and import a trained :class:`~gensim.models.Word2Vec` model to and from disk | ||
without pickling the model. The model is stored in two files: `.json` for metadata and `.npy` for weights. | ||
""" | ||
from .model_exporting import export_model | ||
from .model_importing import Word2VecLight, import_model | ||
|
||
|
||
# Names re-exported as the public API of this package.
__all__ = ["export_model", "import_model", "Word2VecLight"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import json | ||
import os | ||
from copy import deepcopy | ||
from typing import Union | ||
import numpy as np | ||
import scipy.sparse | ||
from gensim.models import Word2Vec | ||
|
||
|
||
def export_model(model: Word2Vec,
                 output_model_file: Union[str, os.PathLike],
                 output_weights_file: Union[str, os.PathLike]):
    """
    Export a trained :class:`~gensim.models.Word2Vec` model to disk as a lightweight pair of files.

    The exported model can be read back to calculate scores, but it cannot be trained any further.

    Parameters
    ----------
    model:
        Trained :class:`~gensim.models.Word2Vec` model. It is deep-copied before any processing.
    output_model_file:
        Path of the json file that receives the model's metadata.
    output_weights_file:
        Path of the `.npy` file that receives the model's weights.
    """
    # Work on a private copy so that stripping keys below never touches the caller's model.
    model_copy = deepcopy(model)
    metadata = extract_keyedvectors(model_copy)

    # The weights go to their own file; the metadata only remembers which array format they had.
    vectors = metadata.pop("vectors")
    metadata["__weights_format"] = get_weights_format(vectors)

    save_model(metadata, output_model_file)
    save_weights(vectors, output_weights_file)
|
||
|
||
def save_weights(weights: Union[np.ndarray, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix],
                 output_weights_file: Union[str, os.PathLike]):
    """
    Write model's weights to disk in `.npy` dense array format.

    Sparse weights are converted to a dense array prior to saving. Any scipy sparse container is
    recognised (not only CSR/CSC matrices), because an undensified sparse object cannot be written
    with pickling disabled.

    Parameters
    ----------
    weights:
        Model's weights, either a dense numpy array or a scipy sparse matrix/array.
    output_weights_file:
        Path of the `.npy` file to write.
    """
    if scipy.sparse.issparse(weights):
        weights = weights.toarray()

    # allow_pickle=False guarantees the file holds plain numeric data only.
    np.save(output_weights_file, weights, allow_pickle=False)
|
||
|
||
def save_model(keyedvectors: dict, output_model_file: Union[str, os.PathLike]):
    """Serialize the model's metadata dictionary to `output_model_file` as UTF-8 encoded json."""
    with open(output_model_file, mode="w", encoding="utf-8") as json_file:
        json.dump(keyedvectors, json_file)
|
||
|
||
def get_weights_format(weights: Union[np.ndarray, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix]) -> str:
    """
    Name the array format of the model's weights.

    Parameters
    ----------
    weights:
        Model's weights.

    Returns
    -------
    weights_format:
        One of ``"np.ndarray"``, ``"csr_matrix"`` or ``"csc_matrix"``.

    Raises
    ------
    NotImplementedError
        If the weights are of none of the supported types.
    """
    # Checked in order; the first matching type decides the format name.
    format_by_type = (
        (np.ndarray, "np.ndarray"),
        (scipy.sparse.csr_matrix, "csr_matrix"),
        (scipy.sparse.csc_matrix, "csc_matrix"),
    )
    for array_type, format_name in format_by_type:
        if isinstance(weights, array_type):
            return format_name
    raise NotImplementedError("The model's weights format is not supported.")
|
||
|
||
def extract_keyedvectors(model: Word2Vec) -> dict:
    """
    Extract :class:`~gensim.models.KeyedVectors` object from the model, convert it to a dictionary and
    leave out redundant keys.

    The returned dictionary is a new, shallow copy of the keyed vectors' attributes: the model passed in
    is not modified, while the attribute values themselves are shared with it rather than duplicated.

    Parameters
    ----------
    model:
        :class:`~gensim.models.Word2Vec` trained model.

    Returns
    -------
    keyedvectors:
        Dictionary representation of :class:`~gensim.models.KeyedVectors` without redundant keys.
    """
    redundant_keys = ("vectors_lockf", "expandos")
    # Build a filtered copy instead of popping from model.wv.__dict__, which would alter the model itself.
    return {key: value for key, value in model.wv.__dict__.items() if key not in redundant_keys}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import json | ||
import os | ||
from typing import Union | ||
import numpy as np | ||
import scipy.sparse | ||
from gensim.models import KeyedVectors | ||
|
||
|
||
class Word2VecLight:
    """
    A lightweight version of :class:`~gensim.models.Word2Vec`. The objects of this class follow the interface of the
    original :class:`~gensim.models.Word2Vec` to the point necessary to calculate Spec2Vec scores. The model cannot be
    used for further training.
    """

    def __init__(self, model: dict, weights: Union[np.ndarray, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix]):
        """
        Parameters
        ----------
        model:
            A dictionary containing the model's metadata. It is copied, not modified.
        weights:
            A numpy array or a scipy sparse matrix containing the model's weights.

        Raises
        ------
        ValueError
            If the keys of `model` do not match the expected metadata keys.
        """
        self.wv: KeyedVectors = self._KeyedVectorsBuilder().from_dict(model).with_weights(weights).build()

    class _KeyedVectorsBuilder:
        """Assemble a :class:`~gensim.models.KeyedVectors` from a metadata dictionary and a weights array."""

        def __init__(self):
            self.vector_size = None
            self.weights = None

        def build(self) -> KeyedVectors:
            """Create the keyed vectors from the collected metadata and weights."""
            keyed_vectors = KeyedVectors(self.vector_size)
            # Replace the freshly initialised state with the collected one. A copy is handed over so the
            # builder and the built object do not share a single attribute dictionary.
            keyed_vectors.__dict__ = dict(self.__dict__)
            keyed_vectors.vectors = self.weights
            return keyed_vectors

        def from_dict(self, dictionary: dict):
            """
            Take over the model's metadata from `dictionary` after validating its keys.

            Raises
            ------
            ValueError
                If the keys of `dictionary` differ from the expected set of keys.
            """
            expected_keys = {"vector_size", "__numpys", "__scipys", "__ignoreds", "__recursive_saveloads",
                             "index_to_key", "norms", "key_to_index", "next_index", "__weights_format"}
            if dictionary.keys() != expected_keys:
                raise ValueError("The keys of model's dictionary representation do not match the expected keys.")

            # Copy the metadata so the caller's dictionary is never mutated or aliased, and keep weights
            # that were set before this call so the builder steps can be chained in any order.
            weights = self.weights
            self.__dict__ = dict(dictionary)
            self.weights = weights
            return self

        def with_weights(self, weights: Union[np.ndarray, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix]):
            """Remember the weights to be attached to the keyed vectors."""
            self.weights = weights
            return self
|
||
|
||
def import_model(model_file, weights_file) -> Word2VecLight:
    """
    Load a lightweight version of a :class:`~gensim.models.Word2Vec` model from disk.

    Parameters
    ----------
    model_file:
        Path of the json file holding the model's metadata.
    weights_file:
        Path of the `.npy` file holding the model's weights.

    Returns
    -------
    :class:`~spec2vec.serialization.model_importing.Word2VecLight` – a lightweight version of a
    :class:`~gensim.models.Word2Vec`
    """
    with open(model_file, mode="r", encoding="utf-8") as json_file:
        metadata: dict = json.load(json_file)

    # The metadata records which array format the weights had at export time.
    weights_format = metadata["__weights_format"]
    return Word2VecLight(metadata, load_weights(weights_file, weights_format))
|
||
|
||
def load_weights(weights_file: Union[str, os.PathLike],
                 weights_format: str) -> Union[np.ndarray, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix]:
    """
    Read the model's weights from a `.npy` file and restore their original array format.

    Parameters
    ----------
    weights_file:
        Path of the `.npy` file holding the dense weights.
    weights_format:
        Format to restore: ``"np.ndarray"``, ``"csr_matrix"`` or ``"csc_matrix"``.

    Returns
    -------
    weights:
        The weights as a dense numpy array or as a scipy sparse matrix of the requested format.

    Raises
    ------
    KeyError
        If `weights_format` is not one of the supported formats.
    """
    dense_weights: np.ndarray = np.load(weights_file, allow_pickle=False)

    # Dense weights are returned exactly as they were loaded.
    if weights_format == "np.ndarray":
        return dense_weights

    sparse_constructors = {"csr_matrix": scipy.sparse.csr_matrix,
                           "csc_matrix": scipy.sparse.csc_matrix}
    return sparse_constructors[weights_format](dense_weights)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from pathlib import Path | ||
import pytest | ||
|
||
|
||
@pytest.fixture(scope="module")
def test_dir(request):
    """Provide the directory that contains the test module requesting this fixture."""
    test_script = Path(request.fspath)
    return test_script.parent
Large diffs are not rendered by default.
Oops, something went wrong.
File renamed without changes.
Binary file not shown.
Oops, something went wrong.