From c31885e9da04dbda5a5d341c81dddceb4ff06237 Mon Sep 17 00:00:00 2001
From: Juho Inkinen
Date: Wed, 22 Feb 2023 13:38:59 +0200
Subject: [PATCH 1/3] Add test for training NN ensemble with 2 base projects

---
 tests/test_backend_nn_ensemble.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/test_backend_nn_ensemble.py b/tests/test_backend_nn_ensemble.py
index 1941e8665..1606339ad 100644
--- a/tests/test_backend_nn_ensemble.py
+++ b/tests/test_backend_nn_ensemble.py
@@ -144,6 +144,30 @@ def test_nn_ensemble_train_cached(registry):
     assert datadir.join("nn-model.h5").size() > 0
 
 
+def test_nn_ensemble_train_two_sources(registry, tmpdir):
+    project = registry.get_project("dummy-en")
+    nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
+    nn_ensemble = nn_ensemble_type(
+        backend_id="nn_ensemble",
+        config_params={"sources": "dummy-en,dummy-fi", "epochs": 1},
+        project=project,
+    )
+
+    tmpfile = tmpdir.join("document.tsv")
+    tmpfile.write(
+        "dummy\thttp://example.org/dummy\n"
+        + "another\thttp://example.org/dummy\n"
+        + "none\thttp://example.org/none\n" * 40
+    )
+    document_corpus = annif.corpus.DocumentFile(str(tmpfile), project.subjects)
+
+    nn_ensemble.train(document_corpus)
+
+    datadir = py.path.local(project.datadir)
+    assert datadir.join("nn-model.h5").exists()
+    assert datadir.join("nn-model.h5").size() > 0
+
+
 def test_nn_ensemble_train_and_learn_params(registry, tmpdir, capfd):
     project = registry.get_project("dummy-en")
     nn_ensemble_type = annif.backend.get_backend("nn_ensemble")

From 7f762c2abe701f049461d7881e14dd6982b8e354 Mon Sep 17 00:00:00 2001
From: Juho Inkinen
Date: Wed, 22 Feb 2023 10:41:02 +0200
Subject: [PATCH 2/3] Utilize batching in suggestions from base projects

---
 annif/backend/nn_ensemble.py | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/annif/backend/nn_ensemble.py b/annif/backend/nn_ensemble.py
index c3caf8fcf..dcab3735c 100644
--- a/annif/backend/nn_ensemble.py
+++ b/annif/backend/nn_ensemble.py
@@ -4,6 +4,7 @@
 import os.path
 import shutil
+from collections import defaultdict
 from io import BytesIO
 
 import joblib
@@ -207,18 +208,22 @@ def _corpus_to_vectors(self, corpus, seq, n_jobs):
 
         self.info("Processing training documents...")
         with pool_class(jobs) as pool:
-            for hits, subject_set in pool.imap_unordered(
-                psmap.suggest, corpus.documents
+            for hit_sets, subject_sets in pool.imap_unordered(
+                psmap.suggest_batch, corpus.doc_batches
             ):
-                doc_scores = []
-                for project_id, p_hits in hits.items():
-                    vector = p_hits.as_vector(len(self.project.subjects))
-                    doc_scores.append(
-                        np.sqrt(vector) * sources[project_id] * len(sources)
-                    )
-                score_vector = np.array(doc_scores, dtype=np.float32).transpose()
-                true_vector = subject_set.as_vector(len(self.project.subjects))
-                seq.add_sample(score_vector, true_vector)
+                score_vectors = defaultdict(list)
+                for project_id, p_hit_sets in hit_sets.items():
+                    for doc_ind, p_hits in enumerate(p_hit_sets):
+                        vector = p_hits.as_vector(len(self.project.subjects))
+                        scaled_vector = (
+                            np.sqrt(vector) * sources[project_id] * len(sources)
+                        )
+                        score_vectors[doc_ind].append(scaled_vector)
+                true_vectors = [
+                    ss.as_vector(len(self.project.subjects)) for ss in subject_sets
+                ]
+                for sv, tv in zip(score_vectors.values(), true_vectors):
+                    seq.add_sample(np.array(sv, dtype=np.float32).transpose(), tv)
 
     def _open_lmdb(self, cached, lmdb_map_size):
         lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
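A note on the regrouping that PATCH 2/3 introduces in `_corpus_to_vectors`: each worker now returns, for every source project, one hit set per document in the batch, so the scores must be transposed from per-project lists back into one (subjects x sources) matrix per document before they can be fed to the Keras model. Below is a minimal standalone sketch of that transpose, using plain NumPy vectors in place of Annif's suggestion objects; the `hit_sets`, `weights`, and four-subject vocabulary are illustrative assumptions, not Annif API.

    from collections import defaultdict

    import numpy as np

    # Illustrative stand-ins: two source projects, a batch of three
    # documents, a four-subject vocabulary; each "hit set" is already
    # a dense score vector here.
    hit_sets = {
        "proj-a": [np.array([0.9, 0.0, 0.1, 0.0]) for _ in range(3)],
        "proj-b": [np.array([0.2, 0.5, 0.0, 0.3]) for _ in range(3)],
    }
    weights = {"proj-a": 0.5, "proj-b": 0.5}

    # Regroup per-project score lists into per-document lists, applying
    # the same scaling as the patch: sqrt-dampened scores, weighted per
    # source.
    score_vectors = defaultdict(list)
    for project_id, vectors in hit_sets.items():
        for doc_ind, vector in enumerate(vectors):
            scaled = np.sqrt(vector) * weights[project_id] * len(weights)
            score_vectors[doc_ind].append(scaled)

    # One (n_subjects, n_sources) matrix per document, the shape that
    # seq.add_sample() receives in the patch.
    samples = [
        np.array(vecs, dtype=np.float32).transpose()
        for vecs in score_vectors.values()
    ]
    assert samples[0].shape == (4, 2)

One subtlety: zipping `score_vectors.values()` with the true-label vectors relies on dict insertion order, which holds because the inner `enumerate` inserts the keys 0, 1, 2, ... in document order already on the first project's pass.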
From 38c67842da04dbda5a5d341c81dddceb4ff06237 Mon Sep 17 00:00:00 2001
From: Juho Inkinen
Date: Wed, 8 Mar 2023 14:21:53 +0200
Subject: [PATCH 3/3] Remove now unused single-doc suggest method in
 parallel.py

---
 annif/parallel.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/annif/parallel.py b/annif/parallel.py
index 8f8fe07f7..8623d186f 100644
--- a/annif/parallel.py
+++ b/annif/parallel.py
@@ -39,16 +39,6 @@ def __init__(self, registry, project_ids, backend_params, limit, threshold):
         self.limit = limit
         self.threshold = threshold
 
-    def suggest(self, doc):
-        filtered_hits = {}
-        for project_id in self.project_ids:
-            project = self.registry.get_project(project_id)
-            hits = project.suggest([doc.text], self.backend_params)[0]
-            filtered_hits[project_id] = hits.filter(
-                project.subjects, self.limit, self.threshold
-            )
-        return (filtered_hits, doc.subject_set)
-
     def suggest_batch(self, batch):
         filtered_hit_sets = defaultdict(list)
         texts, subject_sets = zip(*[(doc.text, doc.subject_set) for doc in batch])
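With the single-document `suggest` method gone, `suggest_batch` is the sole entry point for getting suggestions from the base projects, and any caller still holding a single document can wrap it in a one-element batch. A hypothetical adapter sketching that equivalence (it assumes, as the consuming loop in PATCH 2/3 does, that `suggest_batch` returns a `(hit_sets, subject_sets)` pair; `suggest_single` itself is not part of Annif):

    def suggest_single(psmap, doc):
        # Run one document through the batch API and unwrap the result:
        # a mapping of project_id -> filtered hits for this document,
        # plus the document's subject set, mirroring the removed method.
        hit_sets, subject_sets = psmap.suggest_batch([doc])
        filtered_hits = {
            project_id: hits[0] for project_id, hits in hit_sets.items()
        }
        return filtered_hits, subject_sets[0]

Dropping the single-document path leaves one code path to maintain and test, and every caller benefits from batched backend suggestions.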