From 169810f4efb1a67c8ec8d06e54385c80155c1b82 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Wed, 6 Nov 2024 10:29:30 +0100 Subject: [PATCH] path has been checked for nces tests --- tests/test_nces.py | 64 +++++++++++++++++++++----------------- tests/test_nces_trainer.py | 15 +++++---- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/tests/test_nces.py b/tests/test_nces.py index 6fd8c2cc..a81a78aa 100644 --- a/tests/test_nces.py +++ b/tests/test_nces.py @@ -32,37 +32,43 @@ def seed_everything(): class TestNCES(unittest.TestCase): def test_prediction_quality_family(self): - nces = NCES(knowledge_base_path="./NCESData/family/family.owl", quality_func=F1(), num_predictions=100, - path_of_embeddings="./NCESData/family/embeddings/ConEx_entity_embeddings.csv", - learner_names=["LSTM", "GRU", "SetTransformer"]) - KB = KnowledgeBase(path=nces.knowledge_base_path) - dl_parser = DLSyntaxParser(nces.kb_namespace) - brother = dl_parser.parse('Brother') - daughter = dl_parser.parse('Daughter') - pos = set(KB.individuals(brother)).union(set(KB.individuals(daughter))) - neg = set(KB.individuals())-set(pos) - learning_problem = PosNegLPStandard(pos=pos, neg=neg) - node = list(nces.fit(learning_problem).best_predictions)[0] - print("Quality:", node.quality) - assert node.quality > 0.95 + knowledge_base_path="./NCESData/family/family.owl" + path_of_embeddings="./NCESData/family/embeddings/ConEx_entity_embeddings.csv" + if os.path.exists(knowledge_base_path) and os.path.exists(path_of_embeddings): + nces = NCES(knowledge_base_path=knowledge_base_path, quality_func=F1(), num_predictions=100, + path_of_embeddings=path_of_embeddings, + learner_names=["LSTM", "GRU", "SetTransformer"]) + KB = KnowledgeBase(path=nces.knowledge_base_path) + dl_parser = DLSyntaxParser(nces.kb_namespace) + brother = dl_parser.parse('Brother') + daughter = dl_parser.parse('Daughter') + pos = set(KB.individuals(brother)).union(set(KB.individuals(daughter))) + neg = set(KB.individuals())-set(pos) + learning_problem = PosNegLPStandard(pos=pos, neg=neg) + node = list(nces.fit(learning_problem).best_predictions)[0] + print("Quality:", node.quality) + assert node.quality > 0.95 def test_prediction_quality_mutagenesis(self): - nces = NCES(knowledge_base_path="./NCESData/mutagenesis/mutagenesis.owl", quality_func=F1(), num_predictions=100, - path_of_embeddings="./NCESData/mutagenesis/embeddings/ConEx_entity_embeddings.csv", - learner_names=["LSTM", "GRU", "SetTransformer"]) - KB = KnowledgeBase(path=nces.knowledge_base_path) - dl_parser = DLSyntaxParser(nces.kb_namespace) - exists_inbond = dl_parser.parse('∃ hasStructure.Benzene') - not_bond7 = dl_parser.parse('¬Bond-7') - pos = set(KB.individuals(exists_inbond)).intersection(set(KB.individuals(not_bond7))) - neg = sorted(set(KB.individuals())-pos) - if len(pos) > 500: - pos = set(np.random.choice(list(pos), size=min(500, len(pos)), replace=False)) - neg = set(neg[:min(1000-len(pos), len(neg))]) - learning_problem = PosNegLPStandard(pos=pos, neg=neg) - node = list(nces.fit(learning_problem).best_predictions)[0] - print("Quality:", node.quality) - assert node.quality > 0.95 + knowledge_base_path="./NCESData/mutagenesis/mutagenesis.owl" + path_of_embeddings="./NCESData/mutagenesis/embeddings/ConEx_entity_embeddings.csv" + if os.path.exists(knowledge_base_path) and os.path.exists(path_of_embeddings): + nces = NCES(knowledge_base_path=knowledge_base_path, quality_func=F1(), num_predictions=100, + path_of_embeddings=path_of_embeddings, + learner_names=["LSTM", "GRU", "SetTransformer"]) + KB = KnowledgeBase(path=nces.knowledge_base_path) + dl_parser = DLSyntaxParser(nces.kb_namespace) + exists_inbond = dl_parser.parse('∃ hasStructure.Benzene') + not_bond7 = dl_parser.parse('¬Bond-7') + pos = set(KB.individuals(exists_inbond)).intersection(set(KB.individuals(not_bond7))) + neg = sorted(set(KB.individuals())-pos) + if len(pos) > 500: + pos = set(np.random.choice(list(pos), size=min(500, len(pos)), replace=False)) + neg = set(neg[:min(1000-len(pos), len(neg))]) + learning_problem = PosNegLPStandard(pos=pos, neg=neg) + node = list(nces.fit(learning_problem).best_predictions)[0] + print("Quality:", node.quality) + assert node.quality > 0.95 if __name__ == "__main__": test = TestNCES() diff --git a/tests/test_nces_trainer.py b/tests/test_nces_trainer.py index c3a35dd4..23661a8a 100644 --- a/tests/test_nces_trainer.py +++ b/tests/test_nces_trainer.py @@ -28,12 +28,15 @@ def seed_everything(): class TestNCESTrainer(unittest.TestCase): def test_trainer_family(self): - nces = NCES(knowledge_base_path="./NCESData/family/family.owl", num_predictions=100, - path_of_embeddings="./NCESData/family/embeddings/ConEx_entity_embeddings.csv", - load_pretrained=False) - with open("./NCESData/family/training_data/Data.json") as f: - data = json.load(f) - nces.train(list(data.items())[-100:], epochs=5, learning_rate=0.001, save_model=False, record_runtime=False, storage_path=f"./NCES-{time.time()}/") + knowledge_base_path="./NCESData/family/family.owl" + path_of_embeddings="./NCESData/family/embeddings/ConEx_entity_embeddings.csv" + if os.path.exists(knowledge_base_path) and os.path.exists(path_of_embeddings): + nces = NCES(knowledge_base_path=knowledge_base_path, num_predictions=100, + path_of_embeddings=path_of_embeddings, + load_pretrained=False) + with open("./NCESData/family/training_data/Data.json") as f: + data = json.load(f) + nces.train(list(data.items())[-100:], epochs=5, learning_rate=0.001, save_model=False, record_runtime=False, storage_path=f"./NCES-{time.time()}/") if __name__ == "__main__": test = TestNCESTrainer() test.test_trainer_family()