diff --git a/gensim/models/poincare.py b/gensim/models/poincare.py index e278dae9cb..42a3a60d48 100644 --- a/gensim/models/poincare.py +++ b/gensim/models/poincare.py @@ -153,6 +153,10 @@ def __init__(self, train_data, size=50, alpha=0.1, negative=10, workers=1, epsil """ self.train_data = train_data self.kv = PoincareKeyedVectors(size) + self.all_relations = [] + self.node_relations = defaultdict(set) + self._negatives_buffer = NegativesBuffer([]) + self._negatives_buffer_size = 2000 self.size = size self.train_alpha = alpha # Learning rate for training self.burn_in_alpha = burn_in_alpha # Learning rate for burn-in @@ -168,47 +172,81 @@ def __init__(self, train_data, size=50, alpha=0.1, negative=10, workers=1, epsil self._np_random = np_random.RandomState(seed) self.init_range = init_range self._loss_grad = None - self._load_relations() - self._init_embeddings() + self.build_vocab(train_data) + + def build_vocab(self, relations, update=False): + """Build the model's vocabulary from known relations. - def _load_relations(self): - """Load relations from the train data and build vocab.""" - vocab = {} - index2word = [] - all_relations = [] # List of all relation pairs - node_relations = defaultdict(set) # Mapping from node index to its related node indices + Parameters + ---------- + relations : {iterable of (str, str), :class:`gensim.models.poincare.PoincareRelations`} + Iterable of relations, e.g. a list of tuples, or a :class:`gensim.models.poincare.PoincareRelations` + instance streaming from a file. Note that the relations are treated as ordered pairs, + i.e. a relation (a, b) does not imply the opposite relation (b, a). In case the relations are symmetric, + the data should contain both relations (a, b) and (b, a). + update : bool, optional + If true, only new nodes's embeddings are initialized. + Use this when the model already has an existing vocabulary and you want to update it. + If false, all node's embeddings are initialized. + Use this when you're creating a new vocabulary from scratch. + + Examples + -------- + Train a model and update vocab for online training: + + .. sourcecode:: pycon + + >>> from gensim.models.poincare import PoincareModel + >>> + >>> # train a new model from initial data + >>> initial_relations = [('kangaroo', 'marsupial'), ('kangaroo', 'mammal')] + >>> model = PoincareModel(initial_relations, negative=1) + >>> model.train(epochs=50) + >>> + >>> # online training: update the vocabulary and continue training + >>> online_relations = [('striped_skunk', 'mammal')] + >>> model.build_vocab(online_relations, update=True) + >>> model.train(epochs=50) + + """ + old_index2word_len = len(self.kv.index2word) logger.info("loading relations from train data..") - for relation in self.train_data: + for relation in relations: if len(relation) != 2: raise ValueError('Relation pair "%s" should have exactly two items' % repr(relation)) for item in relation: - if item in vocab: - vocab[item].count += 1 + if item in self.kv.vocab: + self.kv.vocab[item].count += 1 else: - vocab[item] = Vocab(count=1, index=len(index2word)) - index2word.append(item) + self.kv.vocab[item] = Vocab(count=1, index=len(self.kv.index2word)) + self.kv.index2word.append(item) node_1, node_2 = relation - node_1_index, node_2_index = vocab[node_1].index, vocab[node_2].index - node_relations[node_1_index].add(node_2_index) + node_1_index, node_2_index = self.kv.vocab[node_1].index, self.kv.vocab[node_2].index + self.node_relations[node_1_index].add(node_2_index) relation = (node_1_index, node_2_index) - all_relations.append(relation) - logger.info("loaded %d relations from train data, %d nodes", len(all_relations), len(vocab)) - self.kv.vocab = vocab - self.kv.index2word = index2word - self.indices_set = set(range(len(index2word))) # Set of all node indices - self.indices_array = np.fromiter(range(len(index2word)), dtype=int) # Numpy array of all node indices - self.all_relations = all_relations - self.node_relations = node_relations + self.all_relations.append(relation) + logger.info("loaded %d relations from train data, %d nodes", len(self.all_relations), len(self.kv.vocab)) + self.indices_set = set(range(len(self.kv.index2word))) # Set of all node indices + self.indices_array = np.fromiter(range(len(self.kv.index2word)), dtype=int) # Numpy array of all node indices self._init_node_probabilities() - self._negatives_buffer = NegativesBuffer([]) # Buffer for negative samples, to reduce calls to sampling method - self._negatives_buffer_size = 2000 + + if not update: + self._init_embeddings() + else: + self._update_embeddings(old_index2word_len) def _init_embeddings(self): """Randomly initialize vectors for the items in the vocab.""" shape = (len(self.kv.index2word), self.size) self.kv.syn0 = self._np_random.uniform(self.init_range[0], self.init_range[1], shape).astype(self.dtype) + def _update_embeddings(self, old_index2word_len): + """Randomly initialize vectors for the items in the additional vocab.""" + shape = (len(self.kv.index2word) - old_index2word_len, self.size) + v = self._np_random.uniform(self.init_range[0], self.init_range[1], shape).astype(self.dtype) + self.kv.syn0 = np.concatenate([self.kv.syn0, v]) + def _init_node_probabilities(self): """Initialize a-priori probabilities.""" counts = np.fromiter(( @@ -830,6 +868,7 @@ def __init__(self, vector_size): super(PoincareKeyedVectors, self).__init__(vector_size) self.max_distance = 0 self.index2word = [] + self.vocab = {} @property def vectors(self): diff --git a/gensim/test/test_poincare.py b/gensim/test/test_poincare.py index c057c81bf0..9ea020da51 100644 --- a/gensim/test/test_poincare.py +++ b/gensim/test/test_poincare.py @@ -93,6 +93,16 @@ def test_persistence_separate_file(self): loaded = PoincareModel.load(testfile()) self.models_equal(model, loaded) + def test_online_learning(self): + """Tests whether additional input data is loaded correctly and completely.""" + model = PoincareModel(self.data, burn_in=0, negative=3) + self.assertEqual(len(model.kv.vocab), 7) + self.assertEqual(model.kv.vocab['kangaroo.n.01'].count, 3) + self.assertEqual(model.kv.vocab['cat.n.01'].count, 1) + model.build_vocab([('kangaroo.n.01', 'cat.n.01')], update=True) # update vocab + self.assertEqual(model.kv.vocab['kangaroo.n.01'].count, 4) + self.assertEqual(model.kv.vocab['cat.n.01'].count, 2) + def test_train_after_load(self): """Tests whether the model can be trained correctly after loading from disk.""" model = PoincareModel(self.data, burn_in=0, negative=3)