Skip to content

Commit

Permalink
Update models for pydantic v2 migration
Browse files Browse the repository at this point in the history
  • Loading branch information
t0mz06 committed Jan 15, 2024
1 parent b60a433 commit 4cb8155
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 125 deletions.
2 changes: 1 addition & 1 deletion src/pytmv1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def _parse_data(raw_response: Response, class_: Type[R]) -> R:
etag=raw_response.headers.get("ETag", ""),
)
if class_ == BaseTaskResp:
resp_class = task_action(raw_response.json()["action"]).resp_class
resp_class = task_action(raw_response.json()["action"]).class_
class_ = resp_class if resp_class else class_
return class_(**raw_response.json())
if "application" in content_type and class_ == BytesResp:
Expand Down
116 changes: 56 additions & 60 deletions src/pytmv1/model/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
from pydantic import ConfigDict, Field
from pydantic import RootModel as PydanticRootModel
from pydantic import model_validator
from pydantic.alias_generators import to_camel

from .enums import (
Expand All @@ -25,37 +26,19 @@
Status,
)

CFG = ConfigDict(alias_generator=to_camel, populate_by_name=True)

class BaseModel(PydanticBaseModel):
def __init__(self, **data: Any):
super().__init__(**data)

class Config:
alias_generator = to_camel
populate_by_name = True
class BaseModel(PydanticBaseModel):
model_config = CFG


class RootModel(PydanticRootModel[List[int]]):
class Config:
alias_generator = to_camel
populate_by_name = True
model_config = CFG


class BaseConsumable(BaseModel):
def __init__(self, **data: Any):
super().__init__(**data)


def _get_task_id(headers: List[Dict[str, str]]) -> Optional[str]:
task_id: str = next(
(
h.get("value", "")
for h in headers
if "Operation-Location" == h.get("name", "")
),
"",
).split("/")[-1]
return task_id if task_id != "" else None
...


class Account(BaseModel):
Expand Down Expand Up @@ -204,25 +187,18 @@ class Error(BaseModel):
message: Optional[str] = None
number: Optional[int] = None

def __init__(self, **data: Any):
super().__init__(**data)


class ExceptionObject(BaseConsumable):
value: str
type: ObjectType
last_modified_date_time: str
description: Optional[str] = None

def __init__(self, **data: str) -> None:
super().__init__(value=self._obj_value(data), **data)

@staticmethod
def _obj_value(args: Dict[str, str]) -> str:
obj_value: Optional[str] = args.get(args.get("type", ""))
if obj_value is None:
raise ValueError("Object value not found")
return obj_value
@model_validator(mode="before")
@classmethod
def _map_data(cls, data: Dict[str, str]) -> Dict[str, str]:
data["value"] = data[data["type"]]
return data


class ImpactScope(BaseModel):
Expand Down Expand Up @@ -272,35 +248,39 @@ class MsData(BaseModel):
status: int
task_id: Optional[str] = None

def __init__(self, **data: Any):
super().__init__(
taskId=_get_task_id(data.pop("headers", {})),
**data,
)
@model_validator(mode="before")
@classmethod
def map_task_id(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data["task_id"] = _get_task_id(data)
return data


class MsDataUrl(MsData):
url: str
id: Optional[str] = None
digest: Optional[Digest] = None

def __init__(self, **data: Any):
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data.update(data.pop("body", {}))
super().__init__(**data)
return data


class MsError(Error):
extra: Dict[str, str] = {}
task_id: Optional[str] = None

def __init__(self, **data: Any):
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data.update(data.pop("body", {}))
data.update(data.pop("error", {}))
super().__init__(
extra={"url": data.pop("url", "")},
taskId=_get_task_id(data.pop("headers", {})),
**data,
)
url = data.pop("url", None)
data["task_id"] = _get_task_id(data)
if url:
data["extra"] = {"url": url}
return data


class MsStatus(RootModel):
Expand Down Expand Up @@ -328,17 +308,13 @@ class SandboxSuspiciousObject(BaseModel):
type: ObjectType
value: str

def __init__(self, **data: Any):
obj: Tuple[str, str] = self._map(data)
super().__init__(type=obj[0], value=obj[1], **data)

@staticmethod
def _map(args: Dict[str, str]) -> Tuple[str, str]:
return {
(k, v)
for k, v in args.items()
if k in map(lambda ot: ot.value, ObjectType)
}.pop()
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, str]) -> Dict[str, str]:
obj = get_object(data)
data["type"] = obj[0]
data["value"] = obj[1]
return data


class Script(BaseConsumable):
Expand Down Expand Up @@ -371,3 +347,23 @@ class TiIndicator(Indicator):
matched_indicator_pattern_ids: List[str]
first_seen_date_times: List[str]
last_seen_date_times: List[str]


def get_object(data: Dict[str, str]) -> Tuple[str, str]:
for k, v in data.items():
if k in map(lambda ot: ot.value, ObjectType):
return k, v
raise ValueError("Could not find ObjectType")


def _get_task_id(data: Dict[str, Any]) -> Optional[str]:
return next(
map(
lambda header: header.get("value", "").split("/")[-1],
filter(
lambda header: "Operation-Location" == header.get("name", ""),
data.pop("headers", []),
),
),
None,
)
40 changes: 12 additions & 28 deletions src/pytmv1/model/responses.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
from __future__ import annotations

from enum import Enum
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union

from pydantic import Field
from pydantic import Field, model_validator

from .commons import (
Account,
Expand All @@ -32,6 +22,7 @@
Script,
SuspiciousObject,
TiAlert,
get_object,
)
from .enums import (
ObjectType,
Expand Down Expand Up @@ -71,9 +62,6 @@ class BaseTaskResp(BaseStatusResponse):
description: Optional[str] = None
account: Optional[str] = None

def __init__(self, **data: Any):
super().__init__(**data)


MR = TypeVar("MR", bound=BaseMultiResponse[Any])
R = TypeVar("R", bound=BaseResponse)
Expand Down Expand Up @@ -103,17 +91,13 @@ class BlockListTaskResp(BaseTaskResp):
type: ObjectType
value: str

def __init__(self, **data: Any):
obj: Tuple[str, str] = self._map(data)
super().__init__(type=obj[0], value=obj[1], **data)

@staticmethod
def _map(args: Dict[str, str]) -> Tuple[str, str]:
return {
(k, v)
for k, v in args.items()
if k in map(lambda ot: ot.value, ObjectType)
}.pop()
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, str]) -> Dict[str, str]:
obj = get_object(data)
data["type"] = obj[0]
data["value"] = obj[1]
return data


class BytesResp(BaseResponse):
Expand Down Expand Up @@ -286,6 +270,6 @@ class TaskAction(Enum):
RUN_OS_QUERY = ("runOsquery", None)
RUN_YARA_RULES = ("runYaraRules", None)

def __init__(self, action: str, resp_class: Optional[Type[T]]):
def __init__(self, action: str, class_: Optional[Type[T]]):
self.action = action
self.resp_class = resp_class
self.class_ = class_
34 changes: 17 additions & 17 deletions tests/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def text(self) -> str:


def sae_alert():
return SaeAlert.construct(
return SaeAlert.model_construct(
id="1",
investigationStatus=InvestigationStatus.NEW,
model="Possible Credential Dumping via Registry",
Expand All @@ -40,32 +40,32 @@ def sae_alert():
description="description",
workbenchLink="https://THE_WORKBENCH_URL",
score=64,
impactScope=ImpactScope.construct(
impactScope=ImpactScope.model_construct(
desktopCount=1,
serverCount=0,
accountCount=1,
emailAddressCount=0,
entities=[
Entity.construct(
entity_value=HostInfo.construct(
Entity.model_construct(
entity_value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
)
)
],
),
indicators=[
Indicator.construct(
Indicator.model_construct(
provenance=["Alert"],
value=HostInfo.construct(
value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
),
)
],
matchedRules=[
MatchedRule.construct(
MatchedRule.model_construct(
name="Potential Credential Dumping via Registry",
matchedFilters=[
MatchedFilter.construct(
MatchedFilter.model_construct(
name="Possible Credential Dumping via Registry Hive",
mitreTechniqueIds=[
"V9.T1003.004",
Expand All @@ -80,7 +80,7 @@ def sae_alert():


def ti_alert():
return TiAlert.construct(
return TiAlert.model_construct(
id="1",
investigationStatus=InvestigationStatus.NEW,
model="Threat Intelligence Sweeping",
Expand All @@ -94,38 +94,38 @@ def ti_alert():
reportLink="https://THE_TI_REPORT_URL",
createdBy="n/a",
score=42,
impactScope=ImpactScope.construct(
impactScope=ImpactScope.model_construct(
desktopCount=1,
serverCount=0,
accountCount=1,
emailAddressCount=0,
entities=[
Entity.construct(
entity_value=HostInfo.construct(
Entity.model_construct(
entity_value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
)
)
],
),
indicators=[
Indicator.construct(
Indicator.model_construct(
provenance=["Alert"],
value=HostInfo.construct(
value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
),
)
],
matchedIndicatorPatterns=[
MatchedIndicatorPattern.construct(
MatchedIndicatorPattern.model_construct(
tags=["STIX2.malicious-activity"],
pattern="[file:name = 'goog-phish-proto-1.vlpset']",
)
],
matchedRules=[
MatchedRule.construct(
MatchedRule.model_construct(
name="Potential Credential Dumping via Registry",
matchedFilters=[
MatchedFilter.construct(
MatchedFilter.model_construct(
name="Possible Credential Dumping via Registry Hive",
mitreTechniqueIds=[
"V9.T1003.004",
Expand Down
Loading

0 comments on commit 4cb8155

Please sign in to comment.