diff --git a/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/__init__.py b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/__init__.py
index 92f58c4f5d5..bccc9e730d7 100644
--- a/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/__init__.py
+++ b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/__init__.py
@@ -11,10 +11,19 @@
 import os
 import re
+
 import azure.durable_functions as df
 from .soar_connector_async import AbnormalSoarConnectorAsync
 from .sentinel_connector_async import AzureSentinelConnectorAsync
+from .soar_connector_async_v2 import get_cases, get_threats
+from .utils import (
+    get_context,
+    should_use_v2_logic,
+    set_date_on_entity,
+    TIME_FORMAT,
+    Resource,
+)
 
 RESET_ORCHESTRATION = os.environ.get("RESET_OPERATION", "false")
 PERSIST_TO_SENTINEL = os.environ.get("PERSIST_TO_SENTINEL", "true")
@@ -44,13 +53,29 @@ def orchestrator_function(context: df.DurableOrchestrationContext):
     logging.info(f"Retrieved stored cases datetime: {stored_cases_datetime}")
 
     current_datetime = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
-
+
+    if should_use_v2_logic():
+        logging.info("Using v2 fetching logic which accounts for eventual consistency")
+        asyncio.run(
+            fetch_and_store_abnormal_data_v2(
+                context,
+                stored_threats_datetime,
+                stored_cases_datetime,
+            )
+        )
+        logging.info("V2 orchestration finished")
+        return
+    else:
+        logging.info("Going with legacy logic")
+        asyncio.run(transfer_abnormal_data_to_sentinel(stored_threats_datetime, stored_cases_datetime, current_datetime, context))
     logging.info("Orchestrator execution finished")
 
 def should_reset_date_params():
     return RESET_ORCHESTRATION == "true"
 
+
 async def transfer_abnormal_data_to_sentinel(stored_threats_datetime,stored_cases_datetime, current_datetime, context):
     threats_date_filter = {"gte_datetime": stored_threats_datetime, "lte_datetime": current_datetime}
     cases_date_filter = {"gte_datetime": stored_cases_datetime, "lte_datetime": current_datetime}
@@ -83,4 +108,60 @@ async def consume(sentinel_connector, queue):
             logging.error(f"Sentinel send request Failed. Err: {e}")
         queue.task_done()
 
+async def fetch_and_store_abnormal_data_v2(
+    context: df.DurableOrchestrationContext,
+    stored_threats_datetime: str,
+    stored_cases_datetime: str,
+):
+    queue = asyncio.Queue()
+    try:
+        threats_ctx = get_context(stored_date_time=stored_threats_datetime)
+        cases_ctx = get_context(stored_date_time=stored_cases_datetime)
+
+        logging.info(
+            f"Timestamps (stored, current) \
+            threats: ({stored_threats_datetime}, {threats_ctx.CURRENT_TIME}); \
+            cases: ({stored_cases_datetime}, {cases_ctx.CURRENT_TIME})",
+        )
+
+        # Execute threats first and then cases, as cases can error out with a 403.
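+        # Both fetchers push (log_type, record) tuples onto the shared queue,
+        # which the Sentinel consumer tasks created further down will drain.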
+        await get_threats(ctx=threats_ctx, output_queue=queue)
+
+        logging.info("Fetching threats completed")
+
+        await get_cases(ctx=cases_ctx, output_queue=queue)
+
+        logging.info("Fetching cases completed")
+    except Exception as e:
+        logging.error("Failed to process", exc_info=e)
+    finally:
+        set_date_on_entity(
+            context=context,
+            time=threats_ctx.CURRENT_TIME.strftime(TIME_FORMAT),
+            resource=Resource.threats,
+        )
+        logging.info("Stored new threats date")
+
+        set_date_on_entity(
+            context=context,
+            time=cases_ctx.CURRENT_TIME.strftime(TIME_FORMAT),
+            resource=Resource.cases,
+        )
+        logging.info("Stored new cases date")
+
+    if should_persist_data_to_sentinel():
+        logging.info("Persisting to sentinel")
+        sentinel_connector = AzureSentinelConnectorAsync(
+            LOG_ANALYTICS_URI, SENTINEL_WORKSPACE_ID, SENTINEL_SHARED_KEY
+        )
+        consumers = [
+            asyncio.create_task(consume(sentinel_connector, queue)) for _ in range(3)
+        ]
+        await queue.join()  # Wait until every queued record has been processed
+        for c in consumers:
+            c.cancel()
+        await asyncio.gather(*consumers, return_exceptions=True)
+        await sentinel_connector.flushall()
+
+
 main = df.Orchestrator.create(orchestrator_function)
\ No newline at end of file
diff --git a/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/soar_connector_async_v2.py b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/soar_connector_async_v2.py
new file mode 100644
index 00000000000..2bb7154c790
--- /dev/null
+++ b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/soar_connector_async_v2.py
@@ -0,0 +1,261 @@
+import json
+from urllib.parse import urlencode, urljoin
+import aiohttp
+import logging
+import asyncio
+import itertools
+from typing import Dict, List
+from utils import (
+    OptionalEndTimeRange,
+    FilterParam,
+    MAP_RESOURCE_TO_LOGTYPE,
+    Resource,
+    TIME_FORMAT,
+    compute_intervals,
+    Context,
+    try_str_to_datetime,
+)
+
+
+def get_query_params(
+    filter_param: FilterParam, interval: OptionalEndTimeRange
+) -> Dict[str, str]:
+    filter = filter_param.name
+    filter += f" gte {interval.start.strftime(TIME_FORMAT)}"
+    if interval.end is not None:
+        filter += f" lte {interval.end.strftime(TIME_FORMAT)}"
+
+    return {"filter": filter}
+
+
+def get_headers(ctx: Context) -> Dict[str, str]:
+    return {
+        "X-Abnormal-Trace-Id": str(ctx.TRACE_ID),
+        "Authorization": f"Bearer {ctx.API_TOKEN}",
+        "Soar-Integration-Origin": "AZURE SENTINEL",
+        "Azure-Sentinel-Version": "2024-09-15",
+    }
+
+
+def compute_url(base_url: str, pathname: str, params: Dict[str, str]) -> str:
+    endpoint = urljoin(base_url, pathname)
+
+    params_str = urlencode(params)
+    if params_str:
+        endpoint += f"?{params_str}"
+
+    return endpoint
+
+
+async def fetch_with_retries(url, retries=3, backoff=1, timeout=10, headers=None):
+    async def fetch(session, url):
+        async with session.get(url, headers=headers, timeout=timeout) as response:
+            if 500 <= response.status < 600:
+                raise aiohttp.ClientResponseError(
+                    request_info=response.request_info,
+                    history=response.history,
+                    code=response.status,
+                    message=response.reason,
+                    headers=response.headers,
+                )
+            # response.raise_for_status()
+            return json.loads(await response.text())
+
+    async with aiohttp.ClientSession() as session:
+        for attempt in range(1, retries + 1):
+            try:
+                response = await fetch(session, url)
+                return response
+            except aiohttp.ClientResponseError as e:
+                if 500 <= e.status < 600:
+                    print(f"Attempt {attempt} failed with error: {e}")
+                    if attempt == retries:
+                        raise
+                    else:
+                        await asyncio.sleep(backoff**attempt)
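+                # Non-5xx response errors are not retryable and are re-raised immediately.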
+                else:
+                    raise
+            except aiohttp.ClientError as e:
+                print(f"Request failed with non-retryable error: {e}")
+                raise
+
+
+async def call_threat_campaigns_endpoint(
+    ctx: Context, interval: OptionalEndTimeRange, semaphore: asyncio.Semaphore
+) -> List[str]:
+    async with semaphore:
+        params = get_query_params(
+            filter_param=FilterParam.latestTimeRemediated, interval=interval
+        )
+
+        threat_campaigns = set()
+
+        nextPageNumber = 1
+        while nextPageNumber:
+            params["pageNumber"] = nextPageNumber
+            endpoint = compute_url(ctx.BASE_URL, "/threats", params)
+            headers = get_headers(ctx)
+
+            response = await fetch_with_retries(url=endpoint, headers=headers)
+            total = response["total"]
+            assert total >= 0
+
+            threat_campaigns.update(
+                [threat["threatId"] for threat in response.get("threats", [])]
+            )
+
+            nextPageNumber = response.get("nextPageNumber")
+            assert nextPageNumber is None or nextPageNumber > 0
+
+            if nextPageNumber is None or nextPageNumber > ctx.MAX_PAGE_NUMBER:
+                break
+
+        return list(threat_campaigns)
+
+
+async def call_cases_endpoint(
+    ctx: Context, interval: OptionalEndTimeRange, semaphore: asyncio.Semaphore
+) -> List[str]:
+    async with semaphore:
+        params = get_query_params(
+            filter_param=FilterParam.customerVisibleTime, interval=interval
+        )
+
+        case_ids = set()
+
+        nextPageNumber = 1
+        while nextPageNumber:
+            params["pageNumber"] = nextPageNumber
+            endpoint = compute_url(ctx.BASE_URL, "/cases", params)
+            headers = get_headers(ctx)
+
+            response = await fetch_with_retries(url=endpoint, headers=headers)
+            total = response["total"]
+            assert total >= 0
+
+            case_ids.update([case["caseId"] for case in response.get("cases", [])])
+
+            nextPageNumber = response.get("nextPageNumber")
+            assert nextPageNumber is None or nextPageNumber > 0
+
+            if nextPageNumber is None or nextPageNumber > ctx.MAX_PAGE_NUMBER:
+                break
+
+        return list(case_ids)
+
+
+async def call_single_threat_endpoint(
+    ctx: Context, threat_id: str, semaphore: asyncio.Semaphore
+) -> List[str]:
+    async with semaphore:
+        endpoint = compute_url(ctx.BASE_URL, f"/threats/{threat_id}", params={})
+        headers = get_headers(ctx)
+
+        response = await fetch_with_retries(url=endpoint, headers=headers)
+
+        filtered_messages = []
+        for message in response["messages"]:
+            message_id = message["abxMessageId"]
+            remediation_time_str = message["remediationTimestamp"]
+
+            remediation_time = try_str_to_datetime(remediation_time_str)
+            if (
+                remediation_time >= ctx.CLIENT_FILTER_TIME_RANGE.start
+                and remediation_time < ctx.CLIENT_FILTER_TIME_RANGE.end
+            ):
+                filtered_messages.append(json.dumps(message, sort_keys=True))
+                logging.debug(f"Successfully processed threat message: {message_id}")
+            else:
+                logging.debug(f"Skipped processing threat message: {message_id}")
+
+        return filtered_messages
+
+
+async def call_single_case_endpoint(
+    ctx: Context, case_id: str, semaphore: asyncio.Semaphore
+) -> str:
+    async with semaphore:
+        endpoint = compute_url(ctx.BASE_URL, f"/cases/{case_id}", params={})
+        headers = get_headers(ctx)
+
+        response = await fetch_with_retries(url=endpoint, headers=headers)
+
+        return json.dumps(response, sort_keys=True)
+
+
+async def get_threats(ctx: Context, output_queue: asyncio.Queue) -> asyncio.Queue:
+    intervals = compute_intervals(ctx)
+    logging.info(
+        "Computed threats intervals\n"
+        + "\n".join(map(lambda x: f"{str(x.start)} : {str(x.end)}", intervals))
+    )
+
+    assert len(intervals) <= 5, "Intervals more than 5"
+    semaphore = asyncio.Semaphore(ctx.NUM_CONCURRENCY)
+
+    campaign_result = await asyncio.gather(
+        *[
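+            # One /threats listing request per computed interval, bounded by the shared semaphore.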
+            call_threat_campaigns_endpoint(
+                ctx=ctx, interval=interval, semaphore=semaphore
+            )
+            for interval in intervals
+        ]
+    )
+    threat_ids = set(itertools.chain(*campaign_result))
+
+    single_result = await asyncio.gather(
+        *[
+            call_single_threat_endpoint(
+                ctx=ctx, threat_id=threat_id, semaphore=semaphore
+            )
+            for threat_id in threat_ids
+        ]
+    )
+    messages = set(itertools.chain(*single_result))
+
+    for message in messages:
+        record = (MAP_RESOURCE_TO_LOGTYPE[Resource.threats], json.loads(message))
+        logging.debug(f"Inserting threat message record {record}")
+        await output_queue.put(record)
+
+    return
+
+
+async def get_cases(ctx: Context, output_queue: asyncio.Queue) -> asyncio.Queue:
+    intervals = compute_intervals(ctx)
+    logging.info(
+        "Computed cases intervals\n"
+        + "\n".join(map(lambda x: f"{str(x.start)} : {str(x.end)}", intervals))
+    )
+
+    assert len(intervals) <= 5, "Intervals more than 5"
+    semaphore = asyncio.Semaphore(ctx.NUM_CONCURRENCY)
+
+    result = await asyncio.gather(
+        *[
+            call_cases_endpoint(ctx=ctx, interval=interval, semaphore=semaphore)
+            for interval in intervals
+        ]
+    )
+    case_ids = set(itertools.chain(*result))
+
+    cases = await asyncio.gather(
+        *[
+            call_single_case_endpoint(ctx=ctx, case_id=case_id, semaphore=semaphore)
+            for case_id in case_ids
+        ]
+    )
+
+    for case in cases:
+        loaded_case = json.loads(case)
+        record = (MAP_RESOURCE_TO_LOGTYPE[Resource.cases], loaded_case)
+        visible_time = try_str_to_datetime(loaded_case["customerVisibleTime"])
+        if (
+            visible_time >= ctx.CLIENT_FILTER_TIME_RANGE.start
+            and visible_time < ctx.CLIENT_FILTER_TIME_RANGE.end
+        ):
+            logging.debug(f"Inserting case record {record}")
+            await output_queue.put(record)
+        else:
+            logging.debug(f"Skipping case record {record}")
+
+    return
+
+
diff --git a/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/soar_connector_async_v2_local_run.py b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/soar_connector_async_v2_local_run.py
new file mode 100644
index 00000000000..b72325bd826
--- /dev/null
+++ b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/soar_connector_async_v2_local_run.py
@@ -0,0 +1,100 @@
+import logging
+import os
+import asyncio
+import time
+from datetime import datetime, timedelta
+from soar_connector_async_v2 import get_cases, get_threats
+from utils import get_context, TIME_FORMAT
+
+
+def find_duplicates(arr):
+    from collections import Counter
+
+    counts = Counter(arr)
+    return [item for item, count in counts.items() if count > 1]
+
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    os.environ["ABNORMAL_SECURITY_REST_API_TOKEN"] = "121"
+    os.environ["API_HOST"] = "http://localhost:3000"
+    os.environ["ABNORMAL_LAG_ON_BACKEND_SEC"] = "10"
+    os.environ["ABNORMAL_FREQUENCY_MIN"] = "1"
+    os.environ["ABNORMAL_LIMIT_MIN"] = "2"
+
+    stored_threat_time = datetime.now() - timedelta(minutes=3)
+    stored_cases_time = datetime.now() - timedelta(minutes=3)
+    output_threats_queue = asyncio.Queue()
+    output_cases_queue = asyncio.Queue()
+    try:
+        while True:
+            threats_ctx = get_context(stored_date_time=stored_threat_time.strftime(TIME_FORMAT))
+            logging.info(
+                f"Filtering messages in range {threats_ctx.CLIENT_FILTER_TIME_RANGE.start} : {threats_ctx.CLIENT_FILTER_TIME_RANGE.end}"
+            )
+            asyncio.run(get_threats(ctx=threats_ctx, output_queue=output_threats_queue))
+
+            stored_threat_time = threats_ctx.CURRENT_TIME
+            logging.info(f"Sleeping for {threats_ctx.FREQUENCY.total_seconds()} seconds\n\n")
+
+
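+            # Repeat the same polling cycle for cases, then sleep for one FREQUENCY period.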
+            cases_ctx = get_context(stored_date_time=stored_cases_time.strftime(TIME_FORMAT))
+            logging.info(
+                f"Filtering messages in range {cases_ctx.CLIENT_FILTER_TIME_RANGE.start} : {cases_ctx.CLIENT_FILTER_TIME_RANGE.end}"
+            )
+            asyncio.run(get_cases(ctx=cases_ctx, output_queue=output_cases_queue))
+
+            stored_cases_time = cases_ctx.CURRENT_TIME
+            logging.info(f"Sleeping for {cases_ctx.FREQUENCY.total_seconds()} seconds\n\n")
+            time.sleep(cases_ctx.FREQUENCY.total_seconds())
+
+    except KeyboardInterrupt:
+        pass
+
+    idlist = []
+    while not output_threats_queue.empty():
+        current = output_threats_queue.get_nowait()
+        print(current)
+        idlist.append(current[1]["abxMessageId"])
+
+    idset = set(idlist)
+    maxid = max(idlist)
+    duplicates = find_duplicates(idlist)
+    missedids = list(filter(lambda x: x not in idset, list(range(1, maxid + 1))))
+
+    print("\n\n\nSummary of the operation")
+
+    print("Ingested values", idlist)
+    print(f"Max ID: {maxid}")
+    print(f"Duplicates: {duplicates}")
+    print(f"Missed IDs: {missedids}")
+
+    assert len(idset) == len(idlist), "Duplicate threats exist"
+    assert len(duplicates) == 0, "There are duplicate threats"
+    assert len(missedids) == 0, "There are missed threat IDs"
+
+
+    idlist = []
+    while not output_cases_queue.empty():
+        current = output_cases_queue.get_nowait()
+        print(current)
+        idlist.append(current[1]["caseId"])
+
+    idset = set(idlist)
+    maxid = max(idlist)
+    duplicates = find_duplicates(idlist)
+    missedids = list(filter(lambda x: x not in idset, list(range(1, maxid + 1))))
+
+    print("\n\n\nSummary of the operation")
+
+    print("Ingested values", idlist)
+    print(f"Max ID: {maxid}")
+    print(f"Duplicates: {duplicates}")
+    print(f"Missed IDs: {missedids}")
+
+    assert len(idset) == len(idlist), "Duplicate cases exist"
+    assert len(duplicates) == 0, "There are duplicate cases"
+    assert len(missedids) == 0, "There are missed case IDs"
diff --git a/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/utils.py b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/utils.py
new file mode 100644
index 00000000000..fa6fd7438a9
--- /dev/null
+++ b/Solutions/AbnormalSecurity/Data Connectors/SentinelFunctionsOrchestrator/utils.py
@@ -0,0 +1,179 @@
+from datetime import datetime, timedelta
+from enum import Enum
+from typing import List, Optional
+import os
+from uuid import uuid4, UUID
+from pydantic import BaseModel, model_validator
+import azure.durable_functions as df
+
+TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
+TIME_FORMAT_WITHMS = "%Y-%m-%dT%H:%M:%S.%fZ"
+
+
+def try_str_to_datetime(time: str) -> datetime:
+    try:
+        return datetime.strptime(time, TIME_FORMAT)
+    except Exception as _:
+        pass
+    return datetime.strptime(time, TIME_FORMAT_WITHMS)
+
+
+class TimeRange(BaseModel):
+    start: datetime
+    end: datetime
+
+    @model_validator(mode="before")
+    def check_start_less_than_end(cls, values):
+        start = values.get("start")
+        end = values.get("end")
+
+        if start > end:
+            raise ValueError(f"Start time {start} is greater than end time {end}")
+        return values
+
+
+class OptionalEndTimeRange(BaseModel):
+    start: datetime
+    end: Optional[datetime]
+
+    @model_validator(mode="before")
+    def check_start_less_than_end(cls, values):
+        start = values.get("start")
+        end = values.get("end")
+
+        if end is not None and start > end:
+            raise ValueError(f"Start time {start} is greater than end time {end}")
+        return values
+
+
+class Context(BaseModel):
+    LAG_ON_BACKEND: timedelta
+    OUTAGE_TIME: timedelta
+    FREQUENCY: timedelta
+    LIMIT: timedelta
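+    # NUM_CONCURRENCY bounds the number of simultaneous API requests (semaphore size).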
+    NUM_CONCURRENCY: int
+    MAX_PAGE_NUMBER: int
+    BASE_URL: str
+    API_TOKEN: str
+    TIME_RANGE: TimeRange
+    CLIENT_FILTER_TIME_RANGE: TimeRange
+    STORED_TIME: datetime
+    CURRENT_TIME: datetime
+    TRACE_ID: UUID
+
+
+class Resource(Enum):
+    threats = 0
+    cases = 1
+
+
+class FilterParam(Enum):
+    receivedTime = 0
+    createdTime = 1
+    firstObserved = 2
+    latestTimeRemediated = 3
+    customerVisibleTime = 4
+
+
+MAP_RESOURCE_TO_LOGTYPE = {
+    Resource.threats: "ABNORMAL_THREAT_MESSAGES",
+    Resource.cases: "ABNORMAL_CASES",
+}
+
+MAP_RESOURCE_TO_ENTITY_VALUE = {
+    Resource.threats: "threats_date",
+    Resource.cases: "cases_date",
+}
+
+
+def compute_intervals(ctx: Context) -> List[OptionalEndTimeRange]:
+    """
+    Splits the time range [X, Y] into intervals of FREQUENCY size, shifting the
+    start back by LAG_ON_BACKEND and clamping it so that the overall span never
+    exceeds OUTAGE_TIME (15 minutes by default):
+    [
+        [X - lag_on_backend, X - lag_on_backend + frequency],
+        ...
+        [Z, None]
+    ]
+    The last interval is left open-ended (end=None).
+    """
+    timerange = ctx.TIME_RANGE
+
+    start_time, current_time = timerange.start, timerange.end
+    print(f"Specified timerange: {start_time} : {current_time}")
+
+    if current_time - start_time > ctx.OUTAGE_TIME:
+        start_time = current_time - ctx.OUTAGE_TIME
+
+    assert current_time - start_time <= ctx.OUTAGE_TIME
+
+    start = start_time.replace() - ctx.LAG_ON_BACKEND
+    current = current_time.replace()
+
+    print(f"Modified timerange: {start} : {current}")
+
+    assert current > start
+
+    limit = ctx.LIMIT
+    add = ctx.FREQUENCY
+
+    assert limit >= add
+
+    intervals: List[OptionalEndTimeRange] = []
+    while current - start > limit:
+        intervals.append(OptionalEndTimeRange(start=start, end=start + add))
+        start = start + add
+
+    intervals.append(OptionalEndTimeRange(start=start, end=None))
+
+    return intervals
+
+
+def should_use_v2_logic() -> bool:
+    return bool(os.environ.get("ABNORMAL_ENABLE_V2_LOGIC"))
+
+
+def get_context(stored_date_time: str) -> Context:
+    BASE_URL = os.environ.get("API_HOST", "https://api.abnormalplatform.com/v1")
+    API_TOKEN = os.environ["ABNORMAL_SECURITY_REST_API_TOKEN"]
+    OUTAGE_TIME = timedelta(
+        minutes=int(os.environ.get("ABNORMAL_OUTAGE_TIME_MIN", "15"))
+    )
+    LAG_ON_BACKEND = timedelta(
+        seconds=int(os.environ.get("ABNORMAL_LAG_ON_BACKEND_SEC", "30"))
+    )
+    FREQUENCY = timedelta(minutes=int(os.environ.get("ABNORMAL_FREQUENCY_MIN", "5")))
+    LIMIT = timedelta(minutes=int(os.environ.get("ABNORMAL_LIMIT_MIN", "6")))
+    NUM_CONCURRENCY = int(os.environ.get("ABNORMAL_NUM_CONCURRENCY", "5"))
+    MAX_PAGE_NUMBER = int(os.environ.get("ABNORMAL_MAX_PAGE_NUMBER", "3"))
+
+    STORED_TIME = try_str_to_datetime(stored_date_time)
+    CURRENT_TIME = try_str_to_datetime(datetime.now().strftime(TIME_FORMAT))
+    TIME_RANGE = TimeRange(start=STORED_TIME, end=CURRENT_TIME)
+    CLIENT_FILTER_TIME_RANGE = TimeRange(
+        start=STORED_TIME - LAG_ON_BACKEND, end=CURRENT_TIME - LAG_ON_BACKEND
+    )
+
+    return Context(
+        LAG_ON_BACKEND=LAG_ON_BACKEND,
+        OUTAGE_TIME=OUTAGE_TIME,
+        NUM_CONCURRENCY=NUM_CONCURRENCY,
+        FREQUENCY=FREQUENCY,
+        BASE_URL=BASE_URL,
+        API_TOKEN=API_TOKEN,
+        TIME_RANGE=TIME_RANGE,
+        CLIENT_FILTER_TIME_RANGE=CLIENT_FILTER_TIME_RANGE,
+        MAX_PAGE_NUMBER=MAX_PAGE_NUMBER,
+        STORED_TIME=STORED_TIME,
+        CURRENT_TIME=CURRENT_TIME,
+        LIMIT=LIMIT,
+        TRACE_ID=uuid4(),
+    )
+
+
+def set_date_on_entity(
+    context: df.DurableOrchestrationContext, time: str, resource: Resource
+):
+    entity_value = MAP_RESOURCE_TO_ENTITY_VALUE[resource]
+    datetimeEntityId = df.EntityId("SoarDatetimeEntity", "latestDatetime")
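+    # Signal the durable entity so the next orchestration run resumes from this timestamp.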
+    context.signal_entity(datetimeEntityId, "set", {"type": entity_value, "date": time})
diff --git a/Solutions/AbnormalSecurity/Data Connectors/requirements.txt b/Solutions/AbnormalSecurity/Data Connectors/requirements.txt
index a0a6779521b..0dcf11f7387 100644
--- a/Solutions/AbnormalSecurity/Data Connectors/requirements.txt
+++ b/Solutions/AbnormalSecurity/Data Connectors/requirements.txt
@@ -4,4 +4,5 @@
 azure-functions==1.8.0
 azure-functions-durable==1.1.3
-requests==2.26.0
\ No newline at end of file
+requests==2.26.0
+pydantic==2.9.2
\ No newline at end of file