From 35c9f500381385e6e3d7469b6f2b241aa288579a Mon Sep 17 00:00:00 2001 From: harrisonpim Date: Wed, 5 Mar 2025 16:48:14 +0000 Subject: [PATCH] fetch concepts from wikibase when training, with all of the optional extras --- scripts/train.py | 33 ++++++--------------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 9fc844cd..0a342390 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -10,10 +10,9 @@ from wandb.sdk.wandb_run import Run from scripts.cloud import AwsEnv, Namespace, get_s3_client, is_logged_in -from scripts.config import classifier_dir, concept_dir +from scripts.config import classifier_dir from scripts.utils import get_local_classifier_path from src.classifier import Classifier, ClassifierFactory -from src.concept import Concept from src.identifiers import WikibaseID from src.version import Version from src.wikibase import WikibaseSession @@ -241,34 +240,14 @@ def main( classifier_dir.mkdir(parents=True, exist_ok=True) - console.log(f"Loading concept {wikibase_id} from {concept_dir}") - try: - concept = Concept.load(concept_dir / f"{wikibase_id}.json") - except FileNotFoundError as e: - raise typer.BadParameter( - f"Data for {wikibase_id} not found. \n" - "If you haven't already, you should run:\n" - f" just get-concept {wikibase_id}\n" - ) from e - wikibase = WikibaseSession() - # Fetch all of its subconcepts recursively - recursive_subconcept_ids = wikibase.get_recursive_subconcept_of_relationships( - wikibase_id + concept = wikibase.get_concept( + wikibase_id, + include_recursive_subconcept_of=True, + include_recursive_has_subconcept=True, + include_labels_from_subconcepts=True, ) - subconcepts = wikibase.get_concepts(wikibase_ids=recursive_subconcept_ids) - - # fetch all of the labels and negative_labels for all of the subconcepts - # and the concept itself - all_positive_labels = set(concept.all_labels) - all_negative_labels = set(concept.negative_labels) - for subconcept in subconcepts: - all_positive_labels.update(subconcept.all_labels) - all_negative_labels.update(subconcept.negative_labels) - - concept.alternative_labels = list(all_positive_labels) - concept.negative_labels = list(all_negative_labels) # Create a classifier instance classifier = ClassifierFactory.create(concept=concept)