-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Loading and Saving LDA Models across Python 2 and 3. #913
Changes from all commits
a4d214f
04a4634
aaae5ff
8b2cc42
c4c1289
96f8a4a
63963c0
fbd5d6d
101fed1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,7 @@ | |
from scipy.special import polygamma | ||
from six.moves import xrange | ||
import six | ||
import json | ||
|
||
# log(sum(exp(x))) that tries to avoid overflow | ||
try: | ||
|
@@ -979,7 +980,7 @@ def __getitem__(self, bow, eps=None): | |
""" | ||
return self.get_document_topics(bow, eps) | ||
|
||
def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs): | ||
def save(self, fname, ignore=['state', 'dispatcher'], separately = None, *args, **kwargs): | ||
""" | ||
Save the model to file. | ||
|
||
|
@@ -1018,7 +1019,41 @@ def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs): | |
ignore = list(set(['state', 'dispatcher']) | set(ignore)) | ||
else: | ||
ignore = ['state', 'dispatcher'] | ||
super(LdaModel, self).save(fname, *args, ignore=ignore, **kwargs) | ||
|
||
# make sure 'expElogbeta' and 'sstats' are ignored from the pickled object, even if | ||
# someone sets the separately list themselves. | ||
separately_explicit = ['expElogbeta', 'sstats'] | ||
# Also add 'alpha' and 'eta' to separately list if they are set 'auto' or some | ||
# array manually. | ||
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1: | ||
separately_explicit.append('alpha') | ||
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1: | ||
separately_explicit.append('eta') | ||
# Merge separately_explicit with separately. | ||
if separately: | ||
if isinstance(separately, six.string_types): | ||
separately = [separately] | ||
separately = [e for e in separately if e] # make sure None and '' are not in the list | ||
separately = list(set(separately_explicit) | set(separately)) | ||
else: | ||
separately = separately_explicit | ||
|
||
# id2word needs to saved separately. | ||
# If id2word is not already in ignore, then saving it separately in json. | ||
id2word = None | ||
if self.id2word is not None and 'id2word' not in ignore: | ||
id2word = dict((k,v) for k,v in self.id2word.iteritems()) | ||
self.id2word = None # remove the dictionary from model | ||
super(LdaModel, self).save(fname, ignore=ignore, separately = separately, *args, **kwargs) | ||
self.id2word = id2word # restore the dictionary. | ||
|
||
# Save the dictionary separately in json. | ||
id2word_fname = utils.smart_extension(fname, '.json') | ||
try: | ||
with utils.smart_open(id2word_fname, 'w', encoding='utf-8') as fout: | ||
json.dump(id2word, fout) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better open the output as binary and write encoded utf8 to it. Actually, the |
||
except Exception as e: | ||
logging.warning("failed to save id2words dictionary in %s: %s", id2word_fname, e) | ||
|
||
@classmethod | ||
def load(cls, fname, *args, **kwargs): | ||
|
@@ -1032,6 +1067,18 @@ def load(cls, fname, *args, **kwargs): | |
""" | ||
kwargs['mmap'] = kwargs.get('mmap', None) | ||
result = super(LdaModel, cls).load(fname, *args, **kwargs) | ||
# Load the separately stored id2word dictionary saved in json. | ||
id2word_fname = utils.smart_extension(fname, '.json') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please put all files for one model in a dedicated folder, so it is easy to keep track of them. |
||
try: | ||
with utils.smart_open(id2word_fname, 'r') as fin: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Open file as binary, decode as necessary (if necessary). |
||
id2word = json.load(fin) | ||
if id2word is not None: | ||
result.id2word = utils.FakeDict(id2word) | ||
else: | ||
result.id2word = None | ||
except Exception as e: | ||
logging.warning("failed to load id2words from %s: %s", id2word_fname, e) | ||
|
||
state_fname = utils.smart_extension(fname, '.state') | ||
try: | ||
result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "interface", "1": "computer", "2": "human", "3": "response", "4": "time", "5": "survey", "6": "system", "7": "user", "8": "eps", "9": "trees", "10": "graph", "11": "minors"} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "interface", "1": "human", "2": "computer", "3": "response", "4": "system", "5": "user", "6": "time", "7": "survey", "8": "eps", "9": "trees", "10": "graph", "11": "minors"} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -233,6 +233,19 @@ def testPersistenceWord2VecFormatCombinationWithStandardPersistence(self): | |
binary_model_with_vocab.save(testfile()) | ||
binary_model_with_vocab = word2vec.Word2Vec.load(testfile()) | ||
self.assertEqual(model.vocab['human'].count, binary_model_with_vocab.vocab['human'].count) | ||
|
||
# def testSaveModelsForPythonVersion(self): | ||
# fname = os.path.join(os.path.dirname(__file__), 'word2vecmodel_python_3_5') | ||
# model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=1, negative=0) | ||
# model.save(fname) | ||
# logging.warning("Word2Vec model saved") | ||
|
||
def testModelCompatibilityWithPythonVersions(self): | ||
fname_model_2_7 = os.path.join(os.path.dirname(__file__), 'word2vecmodel_python_2_7') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
model_2_7 = word2vec.Word2Vec.load(fname_model_2_7) | ||
fname_model_3_5 = os.path.join(os.path.dirname(__file__), 'word2vecmodel_python_3_5') | ||
model_3_5 = word2vec.Word2Vec.load(fname_model_3_5) | ||
self.models_equal(model_2_7, model_3_5) | ||
|
||
def testLargeMmap(self): | ||
"""Test storing/loading the entire model.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -907,10 +907,12 @@ def pickle(obj, fname, protocol=2): | |
|
||
def unpickle(fname): | ||
"""Load pickled object from `fname`""" | ||
with smart_open(fname) as f: | ||
with smart_open(fname, 'rb') as f: | ||
# Because of loading from S3 load can't be used (missing readline in smart_open) | ||
return _pickle.loads(f.read()) | ||
|
||
if sys.version_info > (3,0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PEP8: space after comma. |
||
return _pickle.load(f, encoding='latin1') | ||
else: | ||
return _pickle.loads(f.read()) | ||
|
||
def revdict(d): | ||
""" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. PEP8: space after comma.