Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve schemas loader to handle several classes in same file #1157

Merged
merged 9 commits into from
Oct 25, 2024
143 changes: 56 additions & 87 deletions core/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
import inspect
import logging
import re
from pathlib import Path

import aenum
Expand All @@ -19,111 +20,79 @@
)

logger = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.INFO)


def load_entities():
logger.info("Registering entities")
modules = dict()
for entity_file in Path(__file__).parent.glob("entities/**/*.py"):
if entity_file.stem == "__init__":
def register_module(module_name, base_module):
    """
    Register the classes for the schema implementation files

    Imports ``module_name`` and inspects every class found in it. Each class
    that subclasses the base schema class and declares a ``type`` model field
    is registered in ``base_module``: its type is added to the type enum,
    the class is stored in ``TYPE_MAPPING`` under that type, the class is
    exposed as an attribute of ``base_module`` and it is OR-ed into the
    ``<Base>Types`` union.

    module_name: The module name to load
    base_module: The base module to register the classes in (entity, indicator, observable)
    """
    module = importlib.import_module(module_name)
    # "core.schemas.entity" -> "Entity"; this prefix names the base class,
    # the "<Base>Type" enum and the "<Base>Types" union in base_module.
    base_name = base_module.__name__.split(".")[-1].capitalize()
    types_attr = f"{base_name}Types"
    schema_base_class = getattr(base_module, base_name)
    schema_type_mapping = base_module.TYPE_MAPPING
    schema_types = getattr(base_module, types_attr, None)
    schema_enum = getattr(base_module, f"{base_name}Type")
    for _, obj in inspect.getmembers(module, inspect.isclass):
        if not issubclass(obj, schema_base_class):
            continue
        if "type" not in obj.model_fields:
            continue
        obs_type = obj.model_fields["type"].default
        logger.info(f"Registering class {obj.__name__} defining type <{obs_type}>")
        # getmembers() also returns classes imported into the module, so the
        # same type can be seen more than once. Extending the enum twice
        # raises and would abort registration of the remaining classes.
        if obs_type not in schema_enum.__members__:
            aenum.extend_enum(schema_enum, obs_type, obs_type)
        schema_type_mapping[obs_type] = obj
        setattr(base_module, obj.__name__, obj)
        if not schema_types:
            schema_types = obj
        else:
            schema_types |= obj
    setattr(base_module, types_attr, schema_types)


def register_classes(schema_root_type, base_module):
    """
    Register the classes for the schema root type

    Walks every ``*.py`` file below the ``schema_root_type`` directory next
    to this file (``__init__.py`` files are skipped), turns each path into a
    dotted module name under ``core.schemas`` and hands it to
    ``register_module``. A module that fails to register is logged and does
    not stop the remaining modules from being processed.

    schema_root_type: The schema root type to work with (entities, indicators, observables)
    base_module: The base module to register the classes in (entity, indicator, observable)
    """
    module_base_name = base_module.__name__.split(".")[-1]
    logger.info(f"Registering {module_base_name} classes")
    schemas_dir = Path(__file__).parent
    for schema_file in schemas_dir.glob(f"{schema_root_type}/**/*.py"):
        if schema_file.stem == "__init__":
            continue
        # Build the module name from the path relative to the schemas
        # directory so it is always bound and always matches the file being
        # visited, whatever the nesting depth (the glob is recursive).
        # "<root>/x.py" -> "core.schemas.<root>.x" and
        # "<root>/private/x.py" -> "core.schemas.<root>.private.x".
        relative_parts = schema_file.relative_to(schemas_dir).with_suffix("").parts
        module_name = ".".join(("core.schemas", *relative_parts))
        try:
            register_module(module_name, base_module)
        except Exception:
            # Best effort: one broken schema file must not prevent the
            # others from being registered.
            logger.exception(f"Failed to register classes from {module_name}")
udgover marked this conversation as resolved.
Show resolved Hide resolved


def load_entities():
entity.TYPE_MAPPING = {"entity": entity.Entity, "entities": entity.Entity}
for module_name, enum_value in modules.items():
module = importlib.import_module(module_name)
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, entity.Entity):
entity.TYPE_MAPPING[enum_value] = obj
setattr(entity, obj.__name__, obj)
for key in entity.TYPE_MAPPING:
if key in ["entity", "entities"]:
continue
cls = entity.TYPE_MAPPING[key]
if not entity.EntityTypes:
entity.EntityTypes = cls
else:
entity.EntityTypes |= cls
register_classes("entities", entity)


def load_indicators():
logger.info("Registering indicators")
modules = dict()
for indicator_file in Path(__file__).parent.glob("indicators/**/*.py"):
if indicator_file.stem == "__init__":
continue
logger.info(f"Registering indicator type {indicator_file.stem}")
if indicator_file.parent.stem == "indicators":
module_name = f"core.schemas.indicators.{indicator_file.stem}"
elif indicator_file.parent.stem == "private":
module_name = f"core.schemas.indicators.private.{indicator_file.stem}"
enum_value = indicator_file.stem
if indicator_file.stem not in indicator.IndicatorType.__members__:
aenum.extend_enum(indicator.IndicatorType, indicator_file.stem, enum_value)
modules[module_name] = enum_value
indicator.TYPE_MAPPING = {
"indicator": indicator.Indicator,
"indicators": indicator.Indicator,
}
for module_name, enum_value in modules.items():
module = importlib.import_module(module_name)
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, indicator.Indicator):
indicator.TYPE_MAPPING[enum_value] = obj
setattr(indicator, obj.__name__, obj)
for key in indicator.TYPE_MAPPING:
if key in ["indicator", "indicators"]:
continue
cls = indicator.TYPE_MAPPING[key]
if not indicator.IndicatorTypes:
indicator.IndicatorTypes = cls
else:
indicator.IndicatorTypes |= cls
register_classes("indicators", indicator)


def load_observables():
logger.info("Registering observables")
modules = dict()
for observable_file in Path(__file__).parent.glob("observables/**/*.py"):
if observable_file.stem == "__init__":
continue
logger.info(f"Registering observable type {observable_file.stem}")
if observable_file.parent.stem == "observables":
module_name = f"core.schemas.observables.{observable_file.stem}"
elif observable_file.parent.stem == "private":
module_name = f"core.schemas.observables.private.{observable_file.stem}"
if observable_file.stem not in observable.ObservableType.__members__:
aenum.extend_enum(
observable.ObservableType, observable_file.stem, observable_file.stem
)
modules[module_name] = observable_file.stem
if "guess" not in observable.ObservableType.__members__:
aenum.extend_enum(observable.ObservableType, "guess", "guess")
observable.TYPE_MAPPING = {
"observable": observable.Observable,
"observables": observable.Observable,
}
for module_name, enum_value in modules.items():
module = importlib.import_module(module_name)
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, observable.Observable):
observable.TYPE_MAPPING[enum_value] = obj
setattr(observable, obj.__name__, obj)
for key in observable.TYPE_MAPPING:
if key in ["observable", "observables"]:
continue
cls = observable.TYPE_MAPPING[key]
if not observable.ObservableTypes:
observable.ObservableTypes = cls
else:
observable.ObservableTypes |= cls
if "guess" not in observable.ObservableType.__members__:
aenum.extend_enum(observable.ObservableType, "guess", "guess")
register_classes("observables", observable)


load_observables()
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/entities/attack_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class AttackPattern(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.attack_pattern
type: Literal[entity.EntityType.attack_pattern] = entity.EntityType.attack_pattern
_type_filter: ClassVar[str] = "attack-pattern"
type: Literal["attack-pattern"] = "attack-pattern"
aliases: list[str] = []
kill_chain_phases: list[str] = []
4 changes: 2 additions & 2 deletions core/schemas/entities/campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class Campaign(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.campaign
type: Literal[entity.EntityType.campaign] = entity.EntityType.campaign
_type_filter: ClassVar[str] = "campaign"
type: Literal["campaign"] = "campaign"

aliases: list[str] = []
first_seen: datetime.datetime = Field(default_factory=now)
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/entities/company.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@


class Company(entity.Entity):
type: Literal[entity.EntityType.company] = entity.EntityType.company
_type_filter: ClassVar[str] = entity.EntityType.company
type: Literal["company"] = "company"
_type_filter: ClassVar[str] = "company"
6 changes: 2 additions & 4 deletions core/schemas/entities/course_of_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,5 @@


class CourseOfAction(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.course_of_action
type: Literal[entity.EntityType.course_of_action] = (
entity.EntityType.course_of_action
)
_type_filter: ClassVar[str] = "course-of-action"
type: Literal["course-of-action"] = "course-of-action"
4 changes: 2 additions & 2 deletions core/schemas/entities/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class Identity(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.identity
type: Literal[entity.EntityType.identity] = entity.EntityType.identity
_type_filter: ClassVar[str] = "identity"
type: Literal["identity"] = "identity"

identity_class: str = ""
sectors: list[str] = []
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/entities/intrusion_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class IntrusionSet(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.intrusion_set
type: Literal[entity.EntityType.intrusion_set] = entity.EntityType.intrusion_set
_type_filter: ClassVar[str] = "intrusion-set"
type: Literal["intrusion-set"] = "intrusion-set"

aliases: list[str] = []
first_seen: datetime.datetime = Field(default_factory=now)
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/entities/investigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Investigation(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.investigation
type: Literal[entity.EntityType.investigation] = entity.EntityType.investigation
_type_filter: ClassVar[str] = "investigation"
type: Literal["investigation"] = "investigation"

reference: str = ""
4 changes: 2 additions & 2 deletions core/schemas/entities/malware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class Malware(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.malware
type: Literal[entity.EntityType.malware] = entity.EntityType.malware
_type_filter: ClassVar[str] = "malware"
type: Literal["malware"] = "malware"

kill_chain_phases: list[str] = []
aliases: list[str] = []
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/entities/note.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@


class Note(entity.Entity):
type: Literal[entity.EntityType.note] = entity.EntityType.note
_type_filter: ClassVar[str] = entity.EntityType.note
type: Literal["note"] = "note"
_type_filter: ClassVar[str] = "note"
4 changes: 2 additions & 2 deletions core/schemas/entities/phone.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@


class Phone(entity.Entity):
type: Literal[entity.EntityType.phone] = entity.EntityType.phone
_type_filter: ClassVar[str] = entity.EntityType.phone
type: Literal["phone"] = "phone"
_type_filter: ClassVar[str] = "phone"
4 changes: 2 additions & 2 deletions core/schemas/entities/threat_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class ThreatActor(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.threat_actor
type: Literal[entity.EntityType.threat_actor] = entity.EntityType.threat_actor
_type_filter: ClassVar[str] = "threat-actor"
type: Literal["threat-actor"] = "threat-actor"

threat_actor_types: list[str] = []
aliases: list[str] = []
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/entities/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class Tool(entity.Entity):
_type_filter: ClassVar[str] = entity.EntityType.tool
type: Literal[entity.EntityType.tool] = entity.EntityType.tool
_type_filter: ClassVar[str] = "tool"
type: Literal["tool"] = "tool"

aliases: list[str] = []
kill_chain_phases: list[str] = []
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/entities/vulnerability.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class Vulnerability(entity.Entity):
medium, high, critical.
"""

_type_filter: ClassVar[str] = entity.EntityType.vulnerability
type: Literal[entity.EntityType.vulnerability] = entity.EntityType.vulnerability
_type_filter: ClassVar[str] = "vulnerability"
type: Literal["vulnerability"] = "vulnerability"

title: str = ""
base_score: float = Field(ge=0.0, le=10.0, default=0.0)
Expand Down
6 changes: 2 additions & 4 deletions core/schemas/indicators/forensicartifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ class ForensicArtifact(indicator.Indicator):
As defined in https://github.com/ForensicArtifacts/artifacts
"""

_type_filter: ClassVar[str] = indicator.IndicatorType.forensicartifact
type: Literal[indicator.IndicatorType.forensicartifact] = (
indicator.IndicatorType.forensicartifact
)
_type_filter: ClassVar[str] = "forensicartifact"
type: Literal["forensicartifact"] = "forensicartifact"

sources: list[dict] = []
aliases: list[str] = []
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/indicators/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
class Query(indicator.Indicator):
"""Represents a query that can be sent to another system."""

_type_filter: ClassVar[str] = indicator.IndicatorType.query
type: Literal[indicator.IndicatorType.query] = indicator.IndicatorType.query
_type_filter: ClassVar[str] = "query"
type: Literal["query"] = "query"

query_type: str
target_systems: list[str] = []
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/indicators/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@


class Regex(indicator.Indicator):
_type_filter: ClassVar[str] = indicator.IndicatorType.regex
_type_filter: ClassVar[str] = "regex"
_compiled_pattern: re.Pattern | None = PrivateAttr(None)
type: Literal[indicator.IndicatorType.regex] = indicator.IndicatorType.regex
type: Literal["regex"] = "regex"

@property
def compiled_pattern(self):
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/indicators/sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class Sigma(indicator.Indicator):
Parsing and matching is yet TODO.
"""

_type_filter: ClassVar[str] = indicator.IndicatorType.sigma
type: Literal[indicator.IndicatorType.sigma] = indicator.IndicatorType.sigma
_type_filter: ClassVar[str] = "sigma"
type: Literal["sigma"] = "sigma"

def match(self, value: str) -> indicator.IndicatorMatch | None:
raise NotImplementedError
4 changes: 2 additions & 2 deletions core/schemas/indicators/suricata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class Suricata(indicator.Indicator):
Parsing and matching is yet TODO.
"""

_type_filter: ClassVar[str] = indicator.IndicatorType.suricata
type: Literal[indicator.IndicatorType.suricata] = indicator.IndicatorType.suricata
_type_filter: ClassVar[str] = "suricata"
type: Literal["suricata"] = "suricata"
sid: int = 0
metadata: List[str] = []
references: List[str] = []
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/indicators/yara.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class Yara(indicator.Indicator):
Parsing and matching is yet TODO.
"""

_type_filter: ClassVar[str] = indicator.IndicatorType.yara
type: Literal[indicator.IndicatorType.yara] = indicator.IndicatorType.yara
_type_filter: ClassVar[str] = "yara"
type: Literal["yara"] = "yara"

def match(self, value: str) -> indicator.IndicatorMatch | None:
raise NotImplementedError
2 changes: 1 addition & 1 deletion core/schemas/observables/asn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@


class ASN(observable.Observable):
type: Literal[observable.ObservableType.asn] = observable.ObservableType.asn
type: Literal["asn"] = "asn"
country: str | None = None
description: str | None = None
2 changes: 1 addition & 1 deletion core/schemas/observables/bic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class BIC(observable.Observable):
type: Literal[observable.ObservableType.bic] = observable.ObservableType.bic
type: Literal["bic"] = "bic"

@staticmethod
def is_valid(value: str) -> bool:
Expand Down
Loading
Loading