From 138270144dbfb0a310b34174de43b911e6a0f118 Mon Sep 17 00:00:00 2001 From: Thomas Legros Date: Wed, 24 Jul 2024 16:23:01 +0200 Subject: [PATCH] Fixed mypy generic type error --- src/pytmv1/__init__.py | 2 +- src/pytmv1/core.py | 14 ++++++------- src/pytmv1/model/enum.py | 28 ++++++++++++++++++++++++++ src/pytmv1/model/response.py | 35 +------------------------------- src/pytmv1/utils.py | 39 ++++++++++++++++++++++++++++++++++-- 5 files changed, 73 insertions(+), 45 deletions(-) diff --git a/src/pytmv1/__init__.py b/src/pytmv1/__init__.py index f728028..47a0ac7 100755 --- a/src/pytmv1/__init__.py +++ b/src/pytmv1/__init__.py @@ -64,6 +64,7 @@ ScriptType, Severity, Status, + TaskAction, ) from .model.request import ( AccountRequest, @@ -114,7 +115,6 @@ SandboxSubmissionStatusResp, SandboxSubmitUrlTaskResp, SubmitFileToSandboxResp, - TaskAction, TerminateProcessTaskResp, TextResp, ) diff --git a/src/pytmv1/core.py b/src/pytmv1/core.py index ce807d0..e0a67ec 100755 --- a/src/pytmv1/core.py +++ b/src/pytmv1/core.py @@ -10,6 +10,7 @@ from pydantic import AnyHttpUrl, TypeAdapter from requests import PreparedRequest, Request, Response +from . import utils from .__about__ import __version__ from .adapter import HTTPAdapter from .exception import ( @@ -20,7 +21,7 @@ ServerTextError, ) from .model.common import Error, MsError, MsStatus -from .model.enum import Api, HttpMethod, Status +from .model.enum import Api, HttpMethod, Status, TaskAction from .model.request import EndpointRequest from .model.response import ( MR, @@ -40,7 +41,6 @@ S, SandboxSubmissionStatusResp, T, - TaskAction, TextResp, ) from .result import multi_result, result @@ -304,8 +304,10 @@ def _parse_data(raw_response: Response, class_: Type[R]) -> R: etag=raw_response.headers.get("ETag", ""), ) if class_ == BaseTaskResp: - resp_class = task_action(raw_response.json()["action"]).class_ - class_ = resp_class if resp_class else class_ + resp_class: Type[BaseTaskResp] = utils.task_action_resp_class( + TaskAction(raw_response.json()["action"]) + ) + class_ = resp_class if issubclass(resp_class, class_) else class_ return class_(**raw_response.json()) if "application" in content_type and class_ == BytesResp: log.debug("Parsing binary response") @@ -319,10 +321,6 @@ def _parse_data(raw_response: Response, class_: Type[R]) -> R: raise ParseModelError(class_.__name__, raw_response) -def task_action(action_name: str) -> TaskAction: - return next(filter(lambda ta: action_name == ta.action, TaskAction)) - - def _parse_html(html: str) -> str: log.info("Parsing html response [Html=%s]", html) soup = BeautifulSoup(html, "html.parser") diff --git a/src/pytmv1/model/enum.py b/src/pytmv1/model/enum.py index 979a619..1a13f95 100644 --- a/src/pytmv1/model/enum.py +++ b/src/pytmv1/model/enum.py @@ -299,3 +299,31 @@ class Status(str, Enum): RUNNING = "running" SUCCEEDED = "succeeded" WAIT_FOR_APPROVAL = "waitForApproval" + + +class TaskAction(str, Enum): + COLLECT_FILE = "collectFile" + COLLECT_EVIDENCE = "collectEvidence" + COLLECT_NETWORK_ANALYSIS_PACKAGE = "collectNetworkAnalysisPackage" + ISOLATE_ENDPOINT = "isolate" + ISOLATE_ENDPOINT_MULTIPLE = "isolateForMultiple" + RESTORE_ENDPOINT = "restoreIsolate" + RESTORE_ENDPOINT_MULTIPLE = "restoreIsolateForMultiple" + TERMINATE_PROCESS = "terminateProcess" + DUMP_PROCESS_MEMORY = "dumpProcessMemory" + QUARANTINE_MESSAGE = "quarantineMessage" + DELETE_MESSAGE = "deleteMessage" + RESTORE_MESSAGE = "restoreMessage" + BLOCK_SUSPICIOUS = "block" + REMOVE_SUSPICIOUS = "restoreBlock" + RESET_PASSWORD = "resetPassword" + SUBMIT_SANDBOX = "submitSandbox" + ENABLE_ACCOUNT = "enableAccount" + DISABLE_ACCOUNT = "disableAccount" + FORCE_SIGN_OUT = "forceSignOut" + REMOTE_SHELL = "remoteShell" + RUN_INVESTIGATION_KIT = "runInvestigationKit" + RUN_CUSTOM_SCRIPT = "runCustomScript" + RUN_CUSTOM_SCRIPT_MULTIPLE = "runCustomScriptForMultiple" + RUN_OS_QUERY = "runOsquery" + RUN_YARA_RULES = "runYaraRules" diff --git a/src/pytmv1/model/response.py b/src/pytmv1/model/response.py index 9207cbc..7a95694 100644 --- a/src/pytmv1/model/response.py +++ b/src/pytmv1/model/response.py @@ -1,7 +1,6 @@ from __future__ import annotations -from enum import Enum -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union from pydantic import Field, model_validator @@ -282,35 +281,3 @@ class TerminateProcessTaskResp(BaseTaskResp): class TextResp(BaseResponse): text: str - - -class TaskAction(Enum): - COLLECT_FILE = ("collectFile", CollectFileTaskResp) - COLLECT_EVIDENCE = ("collectEvidence", None) - COLLECT_NETWORK_ANALYSIS_PACKAGE = ("collectNetworkAnalysisPackage", None) - ISOLATE_ENDPOINT = ("isolate", EndpointTaskResp) - ISOLATE_ENDPOINT_MULTIPLE = ("isolateForMultiple", None) - RESTORE_ENDPOINT = ("restoreIsolate", EndpointTaskResp) - RESTORE_ENDPOINT_MULTIPLE = ("restoreIsolateForMultiple", None) - TERMINATE_PROCESS = ("terminateProcess", TerminateProcessTaskResp) - DUMP_PROCESS_MEMORY = ("dumpProcessMemory", None) - QUARANTINE_MESSAGE = ("quarantineMessage", EmailMessageTaskResp) - DELETE_MESSAGE = ("deleteMessage", EmailMessageTaskResp) - RESTORE_MESSAGE = ("restoreMessage", EmailMessageTaskResp) - BLOCK_SUSPICIOUS = ("block", BlockListTaskResp) - REMOVE_SUSPICIOUS = ("restoreBlock", BlockListTaskResp) - RESET_PASSWORD = ("resetPassword", AccountTaskResp) - SUBMIT_SANDBOX = ("submitSandbox", SandboxSubmitUrlTaskResp) - ENABLE_ACCOUNT = ("enableAccount", AccountTaskResp) - DISABLE_ACCOUNT = ("disableAccount", AccountTaskResp) - FORCE_SIGN_OUT = ("forceSignOut", AccountTaskResp) - REMOTE_SHELL = ("remoteShell", None) - RUN_INVESTIGATION_KIT = ("runInvestigationKit", None) - RUN_CUSTOM_SCRIPT = ("runCustomScript", CustomScriptTaskResp) - RUN_CUSTOM_SCRIPT_MULTIPLE = ("runCustomScriptForMultiple", None) - RUN_OS_QUERY = ("runOsquery", None) - RUN_YARA_RULES = ("runYaraRules", None) - - def __init__(self, action: str, class_: Optional[Type[T]]): - self.action = action - self.class_ = class_ diff --git a/src/pytmv1/utils.py b/src/pytmv1/utils.py index 3ef18b0..158af6a 100755 --- a/src/pytmv1/utils.py +++ b/src/pytmv1/utils.py @@ -1,15 +1,44 @@ import base64 import re -from typing import Any, Dict, List, Optional, Pattern +from typing import Any, Dict, List, Optional, Pattern, Type -from .model.enum import QueryOp, SearchMode +from .model.enum import QueryOp, SearchMode, TaskAction from .model.request import ObjectRequest, SuspiciousObjectRequest +from .model.response import ( + AccountTaskResp, + BaseTaskResp, + BlockListTaskResp, + CollectFileTaskResp, + CustomScriptTaskResp, + EmailMessageTaskResp, + EndpointTaskResp, + SandboxSubmitUrlTaskResp, + TerminateProcessTaskResp, +) MAC_ADDRESS_PATTERN: Pattern[str] = re.compile( "^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$" ) GUID_PATTERN: Pattern[str] = re.compile("^(\\w+-+){1,5}\\w+$") +TASK_ACTION_MAP: Dict[TaskAction, Type[BaseTaskResp]] = { + TaskAction.COLLECT_FILE: CollectFileTaskResp, + TaskAction.ISOLATE_ENDPOINT: EndpointTaskResp, + TaskAction.RESTORE_ENDPOINT: EndpointTaskResp, + TaskAction.TERMINATE_PROCESS: TerminateProcessTaskResp, + TaskAction.QUARANTINE_MESSAGE: EmailMessageTaskResp, + TaskAction.DELETE_MESSAGE: EmailMessageTaskResp, + TaskAction.RESTORE_MESSAGE: EmailMessageTaskResp, + TaskAction.BLOCK_SUSPICIOUS: BlockListTaskResp, + TaskAction.REMOVE_SUSPICIOUS: BlockListTaskResp, + TaskAction.RESET_PASSWORD: AccountTaskResp, + TaskAction.SUBMIT_SANDBOX: SandboxSubmitUrlTaskResp, + TaskAction.ENABLE_ACCOUNT: AccountTaskResp, + TaskAction.DISABLE_ACCOUNT: AccountTaskResp, + TaskAction.FORCE_SIGN_OUT: AccountTaskResp, + TaskAction.RUN_CUSTOM_SCRIPT: CustomScriptTaskResp, +} + def _build_query( op: QueryOp, header: str, fields: Dict[str, str] @@ -122,3 +151,9 @@ def tmv1_activity_query(op: QueryOp, fields: Dict[str, str]) -> Dict[str, str]: def filter_query(op: QueryOp, fields: Dict[str, str]) -> Dict[str, str]: return _build_query(op, "filter", fields) + + +def task_action_resp_class( + task_action: TaskAction, +) -> Type[BaseTaskResp]: + return TASK_ACTION_MAP.get(task_action, BaseTaskResp)