Commit

finished sentinel refactor

rednithin committed Sep 30, 2024
1 parent e373814 commit 6021051
Showing 5 changed files with 624 additions and 2 deletions.
@@ -11,10 +11,19 @@
import os
import re


import azure.durable_functions as df

from .soar_connector_async import AbnormalSoarConnectorAsync
from .sentinel_connector_async import AzureSentinelConnectorAsync
from .soar_connector_async_v2 import get_cases, get_threats
from .utils import (
get_context,
should_use_v2_logic,
set_date_on_entity,
TIME_FORMAT,
Resource,
)

RESET_ORCHESTRATION = os.environ.get("RESET_OPERATION", "false")
PERSIST_TO_SENTINEL = os.environ.get("PERSIST_TO_SENTINEL", "true")
@@ -44,13 +53,29 @@ def orchestrator_function(context: df.DurableOrchestrationContext):
logging.info(f"Retrieved stored cases datetime: {stored_cases_datetime}")

current_datetime = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")


if should_use_v2_logic():
logging.info("Using v2 fetching logic which accounts for Eventual consistentcy")
asyncio.run(
fetch_and_store_abnormal_data_v2(
context,
stored_threats_datetime,
stored_cases_datetime,
)
)
logging.info("V2 orchestration finished")
return
else:
logging.info("Going with legacy logic")

asyncio.run(transfer_abnormal_data_to_sentinel(stored_threats_datetime, stored_cases_datetime, current_datetime, context))
logging.info("Orchestrator execution finished")

def should_reset_date_params():
return RESET_ORCHESTRATION == "true"


async def transfer_abnormal_data_to_sentinel(stored_threats_datetime,stored_cases_datetime, current_datetime, context):
threats_date_filter = {"gte_datetime": stored_threats_datetime, "lte_datetime": current_datetime}
cases_date_filter = {"gte_datetime": stored_cases_datetime, "lte_datetime": current_datetime}
@@ -83,4 +108,60 @@ async def consume(sentinel_connector, queue):
logging.error(f"Sentinel send request Failed. Err: {e}")
queue.task_done()

async def fetch_and_store_abnormal_data_v2(
context: df.DurableOrchestrationContext,
stored_threats_datetime: str,
stored_cases_datetime: str,
):
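# Fetch threats and cases from the Abnormal API using the v2 interval-based logic,
# push records onto an in-memory queue, checkpoint the new timestamps on the durable
# entities, and optionally drain the queue into Sentinel.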
queue = asyncio.Queue()
try:
threats_ctx = get_context(stored_date_time=stored_threats_datetime)
cases_ctx = get_context(stored_date_time=stored_cases_datetime)

logging.info(
f"Timestamps (stored, current) \
threats: ({stored_threats_datetime}, {threats_ctx.CURRENT_TIME}); \
cases: ({stored_cases_datetime}, {cases_ctx.CURRENT_TIME})",
)

# Execute threats first and then cases as cases can error out with a 403.
await get_threats(ctx=threats_ctx, output_queue=queue)

logging.info("Fetching threats completed")

await get_cases(ctx=cases_ctx, output_queue=queue)

logging.info("Fetching cases completed")
except Exception as e:
logging.error("Failed to process", exc_info=e)
finally:
set_date_on_entity(
context=context,
time=threats_ctx.CURRENT_TIME.strftime(TIME_FORMAT),
resource=Resource.threats,
)
logging.info("Stored new threats date")

set_date_on_entity(
context=context,
time=cases_ctx.CURRENT_TIME.strftime(TIME_FORMAT),
resource=Resource.cases,
)
logging.info("Stored new cases date")

if should_persist_data_to_sentinel():
logging.info("Persisting to sentinel")
sentinel_connector = AzureSentinelConnectorAsync(
LOG_ANALYTICS_URI, SENTINEL_WORKSPACE_ID, SENTINEL_SHARED_KEY
)
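# Start a small pool of consumer tasks that drain the queue into Sentinel.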
consumers = [
asyncio.create_task(consume(sentinel_connector, queue)) for _ in range(3)
]
await queue.join()  # Wait until every queued record has been processed
for c in consumers:
c.cancel()
await asyncio.gather(*consumers, return_exceptions=True)
await sentinel_connector.flushall()


main = df.Orchestrator.create(orchestrator_function)
@@ -0,0 +1,261 @@
import json
from urllib.parse import urlencode, urljoin
import aiohttp
import logging
import asyncio
import itertools
from typing import Dict, List
from .utils import (
OptionalEndTimeRange,
FilterParam,
MAP_RESOURCE_TO_LOGTYPE,
Resource,
TIME_FORMAT,
compute_intervals,
Context,
try_str_to_datetime,
)


def get_query_params(
filter_param: FilterParam, interval: OptionalEndTimeRange
) -> Dict[str, str]:
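# Build the Abnormal REST filter string, e.g. "latestTimeRemediated gte <start> lte <end>".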
filter = filter_param.name
filter += f" gte {interval.start.strftime(TIME_FORMAT)}"
if interval.end is not None:
filter += f" lte {interval.end.strftime(TIME_FORMAT)}"

return {"filter": filter}


def get_headers(ctx: Context) -> Dict[str, str]:
return {
"X-Abnormal-Trace-Id": str(ctx.TRACE_ID),
"Authorization": f"Bearer {ctx.API_TOKEN}",
"Soar-Integration-Origin": "AZURE SENTINEL",
"Azure-Sentinel-Version": "2024-09-15",
}


def compute_url(base_url: str, pathname: str, params: Dict[str, str]) -> str:
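# Join the base URL, path, and URL-encoded query parameters into a full endpoint URL.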
endpoint = urljoin(base_url, pathname)

params_str = urlencode(params)
if params_str:
endpoint += f"?{params_str}"

return endpoint


async def fetch_with_retries(url, retries=3, backoff=1, timeout=10, headers=None):
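# GET the URL, retrying up to `retries` times with exponential backoff (backoff ** attempt
# seconds) on 5xx responses; non-retryable client errors are raised immediately.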
async def fetch(session, url):
async with session.get(url, headers=headers, timeout=timeout) as response:
if 500 <= response.status < 600:
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=response.reason,
headers=response.headers,
)
# response.raise_for_status()
return json.loads(await response.text())

async with aiohttp.ClientSession() as session:
for attempt in range(1, retries + 1):
try:
response = await fetch(session, url)
return response
except aiohttp.ClientResponseError as e:
if 500 <= e.status < 600:
print(f"Attempt {attempt} failed with error: {e}")
if attempt == retries:
raise
else:
await asyncio.sleep(backoff**attempt)
else:
raise
except aiohttp.ClientError as e:
print(f"Request failed with non-retryable error: {e}")
raise


async def call_threat_campaigns_endpoint(
ctx: Context, interval: OptionalEndTimeRange, semaphore: asyncio.Semaphore
) -> List[str]:
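# Page through /threats for the given interval and collect unique threat IDs,
# stopping once nextPageNumber is exhausted or MAX_PAGE_NUMBER is reached.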
async with semaphore:
params = get_query_params(
filter_param=FilterParam.latestTimeRemediated, interval=interval
)

threat_campaigns = set()

nextPageNumber = 1
while nextPageNumber:
params["pageNumber"] = nextPageNumber
endpoint = compute_url(ctx.BASE_URL, "/threats", params)
headers = get_headers(ctx)

response = await fetch_with_retries(url=endpoint, headers=headers)
total = response["total"]
assert total >= 0

threat_campaigns.update(
[threat["threatId"] for threat in response.get("threats", [])]
)

nextPageNumber = response.get("nextPageNumber")
assert nextPageNumber is None or nextPageNumber > 0

if nextPageNumber is None or nextPageNumber > ctx.MAX_PAGE_NUMBER:
break

return list(threat_campaigns)


async def call_cases_endpoint(
ctx: Context, interval: OptionalEndTimeRange, semaphore: asyncio.Semaphore
) -> List[str]:
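# Page through /cases for the given interval and collect unique case IDs,
# stopping once nextPageNumber is exhausted or MAX_PAGE_NUMBER is reached.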
async with semaphore:
params = get_query_params(
filter_param=FilterParam.customerVisibleTime, interval=interval
)

case_ids = set()

nextPageNumber = 1
while nextPageNumber:
params["pageNumber"] = nextPageNumber
endpoint = compute_url(ctx.BASE_URL, "/cases", params)
headers = get_headers(ctx)

response = await fetch_with_retries(url=endpoint, headers=headers)
total = response["total"]
assert total >= 0

case_ids.update([case["caseId"] for case in response.get("cases", [])])

nextPageNumber = response.get("nextPageNumber")
assert nextPageNumber is None or nextPageNumber > 0

if nextPageNumber is None or nextPageNumber > ctx.MAX_PAGE_NUMBER:
break

return list(case_ids)


async def call_single_threat_endpoint(
ctx: Context, threat_id: str, semaphore: asyncio.Semaphore
) -> List[str]:
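# Fetch one threat campaign and return its messages whose remediation time falls
# inside CLIENT_FILTER_TIME_RANGE, serialized as sorted JSON strings.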
async with semaphore:
endpoint = compute_url(ctx.BASE_URL, f"/threats/{threat_id}", params={})
headers = get_headers(ctx)

response = await fetch_with_retries(url=endpoint, headers=headers)

filtered_messages = []
for message in response["messages"]:
message_id = message["abxMessageId"]
remediation_time_str = message["remediationTimestamp"]

remediation_time = try_str_to_datetime(remediation_time_str)
if (
remediation_time >= ctx.CLIENT_FILTER_TIME_RANGE.start
and remediation_time < ctx.CLIENT_FILTER_TIME_RANGE.end
):
filtered_messages.append(json.dumps(message, sort_keys=True))
logging.debug(f"Successfully processed threat message: {message_id}")
else:
logging.debug(f"Skipped processing threat message: {message_id}")

return filtered_messages


async def call_single_case_endpoint(
ctx: Context, case_id: str, semaphore: asyncio.Semaphore
) -> str:
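# Fetch a single case and return it as a sorted JSON string.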
async with semaphore:
endpoint = compute_url(ctx.BASE_URL, f"/cases/{case_id}", params={})
headers = get_headers(ctx)

response = await fetch_with_retries(url=endpoint, headers=headers)

return json.dumps(response, sort_keys=True)


async def get_threats(ctx: Context, output_queue: asyncio.Queue) -> None:
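# Gather remediated threat messages across the computed intervals and enqueue them
# as (log_type, message) records for the Sentinel consumers.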
intervals = compute_intervals(ctx)
logging.info(
"Computed threats intervals\n"
+ "\n".join(map(lambda x: f"{str(x.start)} : {str(x.end)}", intervals))
)

assert len(intervals) <= 5, "Expected at most 5 intervals"
semaphore = asyncio.Semaphore(ctx.NUM_CONCURRENCY)

campaign_result = await asyncio.gather(
*[
call_threat_campaigns_endpoint(
ctx=ctx, interval=interval, semaphore=semaphore
)
for interval in intervals
]
)
threat_ids = set(itertools.chain(*campaign_result))

single_result = await asyncio.gather(
*[
call_single_threat_endpoint(
ctx=ctx, threat_id=threat_id, semaphore=semaphore
)
for threat_id in threat_ids
]
)
messages = set(itertools.chain(*single_result))

for message in messages:
record = (MAP_RESOURCE_TO_LOGTYPE[Resource.threats], json.loads(message))
logging.debug(f"Inserting threat message record {record}")
await output_queue.put(record)

return


async def get_cases(ctx: Context, output_queue: asyncio.Queue) -> None:
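# Gather cases across the computed intervals, keep those visible inside the client
# filter window, and enqueue them as (log_type, case) records.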
intervals = compute_intervals(ctx)
logging.info(
"Computed cases intervals\n"
+ "\n".join(map(lambda x: f"{str(x.start)} : {str(x.end)}", intervals))
)

assert len(intervals) <= 5, "Expected at most 5 intervals"
semaphore = asyncio.Semaphore(ctx.NUM_CONCURRENCY)

result = await asyncio.gather(
*[
call_cases_endpoint(ctx=ctx, interval=interval, semaphore=semaphore)
for interval in intervals
]
)
case_ids = set(itertools.chain(*result))

cases = await asyncio.gather(
*[
call_single_case_endpoint(ctx=ctx, case_id=case_id, semaphore=semaphore)
for case_id in case_ids
]
)

for case in cases:
loaded_case = json.loads(case)
record = (MAP_RESOURCE_TO_LOGTYPE[Resource.cases], loaded_case)
visible_time = try_str_to_datetime(loaded_case["customerVisibleTime"])
if visible_time >= ctx.CLIENT_FILTER_TIME_RANGE.start and visible_time < ctx.CLIENT_FILTER_TIME_RANGE.end:
logging.debug(f"Inserting case record {record}")
await output_queue.put(record)
else:
logging.debug(f"Skipping case record {record}")

return

