Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve schemas loader to handle several classes in same file #1157

Merged
merged 9 commits into from
Oct 25, 2024
144 changes: 82 additions & 62 deletions core/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
import inspect
import logging
import re
from pathlib import Path

import aenum
Expand All @@ -19,30 +20,76 @@
)

logger = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.INFO)


def load_entities():
logger.info("Registering entities")
modules = dict()
for entity_file in Path(__file__).parent.glob("entities/**/*.py"):
if entity_file.stem == "__init__":
def register_schemas_types(
    schema_root_type: str, schema_enum: aenum, type_pattern: re.Pattern
) -> set:
    """
    Register the types of schemas from implementation files

    Every ``*.py`` file (``__init__.py`` excluded) found below the
    ``<schema_root_type>`` directory located next to this file is read, the
    type names captured by ``type_pattern`` are extracted, and the names not
    yet known to ``schema_enum`` are added to it.

    :param schema_root_type: The schema root type to work with. It is used
        both as the directory name to scan and as the sub-package name under
        ``core.schemas``
    :param schema_enum: The schema enum to extend
    :param type_pattern: The pattern to match the types in the schema
        implementation files. Whatever ``findall`` returns for it is used as
        the enum member name
    :return: The dotted names of the modules that added at least one new type
        to ``schema_enum``
    """
    logger.info(f"Loading {schema_root_type} types")
    modules = set()
    pattern_matcher = re.compile(type_pattern)
    schema_root = Path(__file__).parent / schema_root_type
    for schema_file in Path(__file__).parent.glob(f"{schema_root_type}/**/*.py"):
        if schema_file.stem == "__init__":
            continue
        # Build the dotted module name from the path relative to the schema
        # root. This yields the same names as before for files in the root
        # and in "private", and keeps the name bound (and correct) for files
        # in any other sub-directory instead of failing or reusing the name
        # computed for the previous file.
        relative_parts = schema_file.relative_to(schema_root).with_suffix("").parts
        module_name = ".".join(("core.schemas", schema_root_type, *relative_parts))
        # Python sources are UTF-8 by default; do not depend on the locale.
        content = schema_file.read_text(encoding="utf-8")
        for schema_type in pattern_matcher.findall(content):
            if schema_type in schema_enum.__members__:
                # NOTE(review): a module whose types are all already known is
                # not returned - confirm this is intended.
                logger.warning(
                    f"Type {schema_type} defined in <{module_name}> already exists in {schema_enum.__name__}"
                )
                continue
            logger.debug(f"Adding type <{schema_type}> to {schema_enum.__name__} enum")
            # Entity enum values use dashes where the member name has
            # underscores; the other schema enums use the name unchanged.
            if schema_root_type == "entities":
                enum_value = schema_type.replace("_", "-")
            else:
                enum_value = schema_type
            aenum.extend_enum(schema_enum, schema_type, enum_value)
            modules.add(module_name)
    return modules


def register_schema_classes(base_module, base_class, modules: set, type_mapping: dict):
    """
    Register the schema classes from the implementation modules

    Each module is imported and inspected for classes deriving from
    ``base_class``. A class declaring a ``type`` model field is stored in
    ``type_mapping`` under the value of that field's default, and the class
    is exposed as an attribute of ``base_module``.

    :param base_module: The module object the classes are attached to (its
        ``__name__`` is used in the log message)
    :param base_class: base class that the schema class should inherit from
    :param modules: The dotted names of the modules to import and inspect
    :param type_mapping: schema type mapping to update in place
    """
    logger.info(f"Registering {base_module.__name__} classes")
    for module_name in modules:
        imported = importlib.import_module(module_name)
        for _name, candidate in inspect.getmembers(imported, inspect.isclass):
            if not issubclass(candidate, base_class):
                continue
            if "type" in candidate.model_fields:
                # The default of the ``type`` field is expected to be an enum
                # member; its value is the key used in the mapping.
                obs_type = candidate.model_fields["type"].default.value
                logger.debug(
                    f"Registering class {candidate.__name__} defining type <{obs_type}>"
                )
                type_mapping[obs_type] = candidate
            # NOTE(review): assumes every subclass is exposed on the module,
            # typed or not - confirm against the original indentation.
            setattr(base_module, candidate.__name__, candidate)


def load_entities():
entity.TYPE_MAPPING = {"entity": entity.Entity, "entities": entity.Entity}
types_pattern = r"Literal\[entity.EntityType.(.+?(?=\]))"
modules = register_schemas_types("entities", entity.EntityType, types_pattern)
register_schema_classes(entity, entity.Entity, modules, entity.TYPE_MAPPING)
for key in entity.TYPE_MAPPING:
if key in ["entity", "entities"]:
continue
Expand All @@ -54,30 +101,17 @@ def load_entities():


def load_indicators():
logger.info("Registering indicators")
modules = dict()
for indicator_file in Path(__file__).parent.glob("indicators/**/*.py"):
if indicator_file.stem == "__init__":
continue
logger.info(f"Registering indicator type {indicator_file.stem}")
if indicator_file.parent.stem == "indicators":
module_name = f"core.schemas.indicators.{indicator_file.stem}"
elif indicator_file.parent.stem == "private":
module_name = f"core.schemas.indicators.private.{indicator_file.stem}"
enum_value = indicator_file.stem
if indicator_file.stem not in indicator.IndicatorType.__members__:
aenum.extend_enum(indicator.IndicatorType, indicator_file.stem, enum_value)
modules[module_name] = enum_value
indicator.TYPE_MAPPING = {
"indicator": indicator.Indicator,
"indicators": indicator.Indicator,
}
for module_name, enum_value in modules.items():
module = importlib.import_module(module_name)
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, indicator.Indicator):
indicator.TYPE_MAPPING[enum_value] = obj
setattr(indicator, obj.__name__, obj)
types_pattern = r"Literal\[indicator.IndicatorType.(.+?(?=\]))"
modules = register_schemas_types(
"indicators", indicator.IndicatorType, types_pattern
)
register_schema_classes(
indicator, indicator.Indicator, modules, indicator.TYPE_MAPPING
)
for key in indicator.TYPE_MAPPING:
if key in ["indicator", "indicators"]:
continue
Expand All @@ -89,33 +123,19 @@ def load_indicators():


def load_observables():
logger.info("Registering observables")
modules = dict()
for observable_file in Path(__file__).parent.glob("observables/**/*.py"):
if observable_file.stem == "__init__":
continue
logger.info(f"Registering observable type {observable_file.stem}")
if observable_file.parent.stem == "observables":
module_name = f"core.schemas.observables.{observable_file.stem}"
elif observable_file.parent.stem == "private":
module_name = f"core.schemas.observables.private.{observable_file.stem}"
if observable_file.stem not in observable.ObservableType.__members__:
aenum.extend_enum(
observable.ObservableType, observable_file.stem, observable_file.stem
)
modules[module_name] = observable_file.stem
if "guess" not in observable.ObservableType.__members__:
aenum.extend_enum(observable.ObservableType, "guess", "guess")
observable.TYPE_MAPPING = {
"observable": observable.Observable,
"observables": observable.Observable,
}
for module_name, enum_value in modules.items():
module = importlib.import_module(module_name)
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, observable.Observable):
observable.TYPE_MAPPING[enum_value] = obj
setattr(observable, obj.__name__, obj)
if "guess" not in observable.ObservableType.__members__:
aenum.extend_enum(observable.ObservableType, "guess", "guess")
type_pattern = r"Literal\[observable.ObservableType.(.+?(?=\]))"
modules = register_schemas_types(
"observables", observable.ObservableType, type_pattern
)
register_schema_classes(
observable, observable.Observable, modules, observable.TYPE_MAPPING
)
for key in observable.TYPE_MAPPING:
if key in ["observable", "observables"]:
continue
Expand Down
Loading