Skip to content

Commit

Permalink
Merge pull request #28 from trendmicro/fix/mypy_error
Browse files Browse the repository at this point in the history
Fixed mypy generic type error
  • Loading branch information
t0mz06 authored Jul 24, 2024
2 parents 310cccc + 1382701 commit c81073a
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/pytmv1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
ScriptType,
Severity,
Status,
TaskAction,
)
from .model.request import (
AccountRequest,
Expand Down Expand Up @@ -114,7 +115,6 @@
SandboxSubmissionStatusResp,
SandboxSubmitUrlTaskResp,
SubmitFileToSandboxResp,
TaskAction,
TerminateProcessTaskResp,
TextResp,
)
Expand Down
14 changes: 6 additions & 8 deletions src/pytmv1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import AnyHttpUrl, TypeAdapter
from requests import PreparedRequest, Request, Response

from . import utils
from .__about__ import __version__
from .adapter import HTTPAdapter
from .exception import (
Expand All @@ -20,7 +21,7 @@
ServerTextError,
)
from .model.common import Error, MsError, MsStatus
from .model.enum import Api, HttpMethod, Status
from .model.enum import Api, HttpMethod, Status, TaskAction
from .model.request import EndpointRequest
from .model.response import (
MR,
Expand All @@ -40,7 +41,6 @@
S,
SandboxSubmissionStatusResp,
T,
TaskAction,
TextResp,
)
from .result import multi_result, result
Expand Down Expand Up @@ -304,8 +304,10 @@ def _parse_data(raw_response: Response, class_: Type[R]) -> R:
etag=raw_response.headers.get("ETag", ""),
)
if class_ == BaseTaskResp:
resp_class = task_action(raw_response.json()["action"]).class_
class_ = resp_class if resp_class else class_
resp_class: Type[BaseTaskResp] = utils.task_action_resp_class(
TaskAction(raw_response.json()["action"])
)
class_ = resp_class if issubclass(resp_class, class_) else class_
return class_(**raw_response.json())
if "application" in content_type and class_ == BytesResp:
log.debug("Parsing binary response")
Expand All @@ -319,10 +321,6 @@ def _parse_data(raw_response: Response, class_: Type[R]) -> R:
raise ParseModelError(class_.__name__, raw_response)


def task_action(action_name: str) -> TaskAction:
return next(filter(lambda ta: action_name == ta.action, TaskAction))


def _parse_html(html: str) -> str:
log.info("Parsing html response [Html=%s]", html)
soup = BeautifulSoup(html, "html.parser")
Expand Down
28 changes: 28 additions & 0 deletions src/pytmv1/model/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,31 @@ class Status(str, Enum):
RUNNING = "running"
SUCCEEDED = "succeeded"
WAIT_FOR_APPROVAL = "waitForApproval"


class TaskAction(str, Enum):
COLLECT_FILE = "collectFile"
COLLECT_EVIDENCE = "collectEvidence"
COLLECT_NETWORK_ANALYSIS_PACKAGE = "collectNetworkAnalysisPackage"
ISOLATE_ENDPOINT = "isolate"
ISOLATE_ENDPOINT_MULTIPLE = "isolateForMultiple"
RESTORE_ENDPOINT = "restoreIsolate"
RESTORE_ENDPOINT_MULTIPLE = "restoreIsolateForMultiple"
TERMINATE_PROCESS = "terminateProcess"
DUMP_PROCESS_MEMORY = "dumpProcessMemory"
QUARANTINE_MESSAGE = "quarantineMessage"
DELETE_MESSAGE = "deleteMessage"
RESTORE_MESSAGE = "restoreMessage"
BLOCK_SUSPICIOUS = "block"
REMOVE_SUSPICIOUS = "restoreBlock"
RESET_PASSWORD = "resetPassword"
SUBMIT_SANDBOX = "submitSandbox"
ENABLE_ACCOUNT = "enableAccount"
DISABLE_ACCOUNT = "disableAccount"
FORCE_SIGN_OUT = "forceSignOut"
REMOTE_SHELL = "remoteShell"
RUN_INVESTIGATION_KIT = "runInvestigationKit"
RUN_CUSTOM_SCRIPT = "runCustomScript"
RUN_CUSTOM_SCRIPT_MULTIPLE = "runCustomScriptForMultiple"
RUN_OS_QUERY = "runOsquery"
RUN_YARA_RULES = "runYaraRules"
35 changes: 1 addition & 34 deletions src/pytmv1/model/response.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pydantic import Field, model_validator

Expand Down Expand Up @@ -282,35 +281,3 @@ class TerminateProcessTaskResp(BaseTaskResp):

class TextResp(BaseResponse):
text: str


class TaskAction(Enum):
COLLECT_FILE = ("collectFile", CollectFileTaskResp)
COLLECT_EVIDENCE = ("collectEvidence", None)
COLLECT_NETWORK_ANALYSIS_PACKAGE = ("collectNetworkAnalysisPackage", None)
ISOLATE_ENDPOINT = ("isolate", EndpointTaskResp)
ISOLATE_ENDPOINT_MULTIPLE = ("isolateForMultiple", None)
RESTORE_ENDPOINT = ("restoreIsolate", EndpointTaskResp)
RESTORE_ENDPOINT_MULTIPLE = ("restoreIsolateForMultiple", None)
TERMINATE_PROCESS = ("terminateProcess", TerminateProcessTaskResp)
DUMP_PROCESS_MEMORY = ("dumpProcessMemory", None)
QUARANTINE_MESSAGE = ("quarantineMessage", EmailMessageTaskResp)
DELETE_MESSAGE = ("deleteMessage", EmailMessageTaskResp)
RESTORE_MESSAGE = ("restoreMessage", EmailMessageTaskResp)
BLOCK_SUSPICIOUS = ("block", BlockListTaskResp)
REMOVE_SUSPICIOUS = ("restoreBlock", BlockListTaskResp)
RESET_PASSWORD = ("resetPassword", AccountTaskResp)
SUBMIT_SANDBOX = ("submitSandbox", SandboxSubmitUrlTaskResp)
ENABLE_ACCOUNT = ("enableAccount", AccountTaskResp)
DISABLE_ACCOUNT = ("disableAccount", AccountTaskResp)
FORCE_SIGN_OUT = ("forceSignOut", AccountTaskResp)
REMOTE_SHELL = ("remoteShell", None)
RUN_INVESTIGATION_KIT = ("runInvestigationKit", None)
RUN_CUSTOM_SCRIPT = ("runCustomScript", CustomScriptTaskResp)
RUN_CUSTOM_SCRIPT_MULTIPLE = ("runCustomScriptForMultiple", None)
RUN_OS_QUERY = ("runOsquery", None)
RUN_YARA_RULES = ("runYaraRules", None)

def __init__(self, action: str, class_: Optional[Type[T]]):
self.action = action
self.class_ = class_
39 changes: 37 additions & 2 deletions src/pytmv1/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,44 @@
import base64
import re
from typing import Any, Dict, List, Optional, Pattern
from typing import Any, Dict, List, Optional, Pattern, Type

from .model.enum import QueryOp, SearchMode
from .model.enum import QueryOp, SearchMode, TaskAction
from .model.request import ObjectRequest, SuspiciousObjectRequest
from .model.response import (
AccountTaskResp,
BaseTaskResp,
BlockListTaskResp,
CollectFileTaskResp,
CustomScriptTaskResp,
EmailMessageTaskResp,
EndpointTaskResp,
SandboxSubmitUrlTaskResp,
TerminateProcessTaskResp,
)

MAC_ADDRESS_PATTERN: Pattern[str] = re.compile(
"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$"
)
GUID_PATTERN: Pattern[str] = re.compile("^(\\w+-+){1,5}\\w+$")

TASK_ACTION_MAP: Dict[TaskAction, Type[BaseTaskResp]] = {
TaskAction.COLLECT_FILE: CollectFileTaskResp,
TaskAction.ISOLATE_ENDPOINT: EndpointTaskResp,
TaskAction.RESTORE_ENDPOINT: EndpointTaskResp,
TaskAction.TERMINATE_PROCESS: TerminateProcessTaskResp,
TaskAction.QUARANTINE_MESSAGE: EmailMessageTaskResp,
TaskAction.DELETE_MESSAGE: EmailMessageTaskResp,
TaskAction.RESTORE_MESSAGE: EmailMessageTaskResp,
TaskAction.BLOCK_SUSPICIOUS: BlockListTaskResp,
TaskAction.REMOVE_SUSPICIOUS: BlockListTaskResp,
TaskAction.RESET_PASSWORD: AccountTaskResp,
TaskAction.SUBMIT_SANDBOX: SandboxSubmitUrlTaskResp,
TaskAction.ENABLE_ACCOUNT: AccountTaskResp,
TaskAction.DISABLE_ACCOUNT: AccountTaskResp,
TaskAction.FORCE_SIGN_OUT: AccountTaskResp,
TaskAction.RUN_CUSTOM_SCRIPT: CustomScriptTaskResp,
}


def _build_query(
op: QueryOp, header: str, fields: Dict[str, str]
Expand Down Expand Up @@ -122,3 +151,9 @@ def tmv1_activity_query(op: QueryOp, fields: Dict[str, str]) -> Dict[str, str]:

def filter_query(op: QueryOp, fields: Dict[str, str]) -> Dict[str, str]:
return _build_query(op, "filter", fields)


def task_action_resp_class(
task_action: TaskAction,
) -> Type[BaseTaskResp]:
return TASK_ACTION_MAP.get(task_action, BaseTaskResp)

0 comments on commit c81073a

Please sign in to comment.