Skip to content

Commit

Permalink
Merge branch 'develop' into owlapy-1.3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr authored Nov 6, 2024
2 parents 37b89d1 + 169810f commit 8141fb4
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions tests/test_nces.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def seed_everything():
class TestNCES(unittest.TestCase):

def test_prediction_quality_family(self):
knowledge_base_path = "./NCESData/family/family.owl"
path_of_embeddings = "./NCESData/family/embeddings/ConEx_entity_embeddings.csv"
knowledge_base_path="./NCESData/family/family.owl"
path_of_embeddings="./NCESData/family/embeddings/ConEx_entity_embeddings.csv"
if os.path.exists(knowledge_base_path) and os.path.exists(path_of_embeddings):
nces = NCES(knowledge_base_path=knowledge_base_path, quality_func=F1(), num_predictions=100,
path_of_embeddings=path_of_embeddings,
Expand All @@ -47,15 +47,15 @@ def test_prediction_quality_family(self):
brother = dl_parser.parse('Brother')
daughter = dl_parser.parse('Daughter')
pos = set(KB.individuals(brother)).union(set(KB.individuals(daughter)))
neg = set(KB.individuals()) - set(pos)
neg = set(KB.individuals())-set(pos)
learning_problem = PosNegLPStandard(pos=pos, neg=neg)
node = list(nces.fit(learning_problem).best_predictions)[0]
print("Quality:", node.quality)
assert node.quality > 0.95

def test_prediction_quality_mutagenesis(self):
knowledge_base_path = "./NCESData/mutagenesis/mutagenesis.owl"
path_of_embeddings = "./NCESData/mutagenesis/embeddings/ConEx_entity_embeddings.csv"
knowledge_base_path="./NCESData/mutagenesis/mutagenesis.owl"
path_of_embeddings="./NCESData/mutagenesis/embeddings/ConEx_entity_embeddings.csv"
if os.path.exists(knowledge_base_path) and os.path.exists(path_of_embeddings):
nces = NCES(knowledge_base_path=knowledge_base_path, quality_func=F1(), num_predictions=100,
path_of_embeddings=path_of_embeddings,
Expand All @@ -74,7 +74,6 @@ def test_prediction_quality_mutagenesis(self):
print("Quality:", node.quality)
assert node.quality > 0.95


if __name__ == "__main__":
test = TestNCES()
test.test_prediction_quality_family()
Expand Down

0 comments on commit 8141fb4

Please sign in to comment.