Cleanup rule survey code #1923

Merged Sep 6, 2022 · 12 commits
12 changes: 9 additions & 3 deletions detection_rules/devtools.py
@@ -1024,6 +1024,7 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
     """Survey rule counts."""
     from kibana.resources import Signal
     from .main import search_rules
+    # from .eswrap import parse_unique_field_results

     survey_results = []
     start_time, end_time = date_range
@@ -1039,15 +1040,20 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
     click.echo(f'Saving detailed dump to: {dump_file}')

     collector = CollectEvents(elasticsearch_client)
-    details = collector.search_from_rule(*rules, start_time=start_time, end_time=end_time)
-    counts = collector.count_from_rule(*rules, start_time=start_time, end_time=end_time)
+    details = collector.search_from_rule(rules, start_time=start_time, end_time=end_time)
+    counts = collector.count_from_rule(rules, start_time=start_time, end_time=end_time)

    # add alerts
    with kibana_client:
        range_dsl = {'query': {'bool': {'filter': []}}}
        add_range_to_dsl(range_dsl['query']['bool']['filter'], start_time, end_time)
        alerts = {a['_source']['signal']['rule']['rule_id']: a['_source']
-                  for a in Signal.search(range_dsl)['hits']['hits']}
+                  for a in Signal.search(range_dsl, size=10000)['hits']['hits']}

+    # for alert in alerts:
+    #     rule_id = alert['signal']['rule']['rule_id']
+    #     rule = rules.id_map[rule_id]
+    #     unique_results = parse_unique_field_results(rule.contents.data.type, rule.contents.data.unique_fields, alert)
+
    for rule_id, count in counts.items():
        alert_count = len(alerts.get(rule_id, []))
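For context, add_range_to_dsl (defined in eswrap.py below) appends a time-window clause to the filter list passed in. A minimal sketch of the DSL it produces, assuming it emits a gt/lte range on @timestamp (the window values here are illustrative):

from detection_rules.eswrap import add_range_to_dsl

range_dsl = {'query': {'bool': {'filter': []}}}
add_range_to_dsl(range_dsl['query']['bool']['filter'], 'now-7d')
# range_dsl is now roughly:
# {'query': {'bool': {'filter': [
#     {'range': {'@timestamp': {'gt': 'now-7d', 'lte': 'now'}}}
# ]}}}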
15 changes: 0 additions & 15 deletions detection_rules/ecs.py
@@ -35,21 +35,6 @@ def add_field(schema, name, info):
        add_field(schema, remaining, info)


-def nest_from_dot(dots, value):
-    """Nest a dotted field and set the inner most value."""
-    fields = dots.split('.')
-
-    if not fields:
-        return {}
-
-    nested = {fields.pop(): value}
-
-    for field in reversed(fields):
-        nested = {field: nested}
-
-    return nested
-
-
 def _recursive_merge(existing, new, depth=0):
     """Return an existing dict merged into a new one."""
     for key, value in existing.items():
114 changes: 63 additions & 51 deletions detection_rules/eswrap.py
@@ -8,7 +8,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import Union
+from typing import List, Union

 import click
 import elasticsearch
@@ -17,7 +17,7 @@

 import kql
 from .main import root
-from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client
+from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client, nested_get
 from .rule import TOMLRule
 from .rule_loader import rta_mappings, RuleCollection
 from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path
@@ -33,7 +33,23 @@ def add_range_to_dsl(dsl_filter, start_time, end_time='now'):
     )


-class RtaEvents(object):
+def parse_unique_field_results(rule_type: str, unique_fields: List[str], search_results: dict):
+    parsed_results = defaultdict(lambda: defaultdict(int))
+    hits = search_results['hits']
+    hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
+    for hit in hits:
+        for field in unique_fields:
+            match = nested_get(hit['_source'], field)
+            if not match:
+                continue
+
+            match = ','.join(sorted(match)) if isinstance(match, list) else match
+            parsed_results[field][match] += 1
+    # if rule.type == eql, structure is different
+    return {'results': parsed_results} if parsed_results else {}
+
+
+class RtaEvents:
     """Events collected from Elasticsearch."""

     def __init__(self, events):
@@ -64,7 +80,7 @@ def evaluate_against_rule_and_update_mapping(self, rule_id, rta_name, verbose=Tr
        """Evaluate a rule against collected events and update mapping."""
        from .utils import combine_sources, evaluate

-        rule = next((rule for rule in RuleCollection.default() if rule.id == rule_id), None)
+        rule = RuleCollection.default().id_map.get(rule_id)
        assert rule is not None, f"Unable to find rule with ID {rule_id}"
        merged_events = combine_sources(*self.events.values())
        filtered = evaluate(rule, merged_events)
@@ -112,7 +128,7 @@ def _build_timestamp_map(self, index_str):

    def _get_last_event_time(self, index_str, dsl=None):
        """Get timestamp of most recent event."""
-        last_event = self.client.search(dsl, index_str, size=1, sort='@timestamp:desc')['hits']['hits']
+        last_event = self.client.search(query=dsl, index=index_str, size=1, sort='@timestamp:desc')['hits']['hits']
        if not last_event:
            return
@@ -146,7 +162,7 @@ def _prep_query(query, language, index, start_time=None, end_time=None):
        elif language == 'dsl':
            formatted_dsl = {'query': query}
        else:
-            raise ValueError('Unknown search language')
+            raise ValueError(f'Unknown search language: {language}')

        if start_time or end_time:
            end_time = end_time or 'now'
@@ -172,84 +188,78 @@ def search(self, query, language, index: Union[str, list] = '*', start_time=None

        return results

-    def search_from_rule(self, *rules: TOMLRule, start_time=None, end_time='now', size=None):
+    def search_from_rule(self, rules: RuleCollection, start_time=None, end_time='now', size=None):
        """Search an elasticsearch instance using a rule."""
-        from .misc import nested_get
-
        async_client = AsyncSearchClient(self.client)
        survey_results = {}

-        def parse_unique_field_results(rule_type, unique_fields, search_results):
-            parsed_results = defaultdict(lambda: defaultdict(int))
-            hits = search_results['hits']
-            hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
-            for hit in hits:
-                for field in unique_fields:
-                    match = nested_get(hit['_source'], field)
-                    match = ','.join(sorted(match)) if isinstance(match, list) else match
-                    parsed_results[field][match] += 1
-            # if rule.type == eql, structure is different
-            return {'results': parsed_results} if parsed_results else {}
-
        multi_search = []
        multi_search_rules = []
-        async_searches = {}
-        eql_searches = {}
+        async_searches = []
+        eql_searches = []

        for rule in rules:
-            if not rule.query:
+            if not rule.contents.data.get('query'):
                continue

-            index_str, formatted_dsl, lucene_query = self._prep_query(query=rule.query,
-                                                                      language=rule.contents.get('language'),
-                                                                      index=rule.contents.get('index', '*'),
+            language = rule.contents.data.get('language')
+            query = rule.contents.data.query
+            rule_type = rule.contents.data.type
+            index_str, formatted_dsl, lucene_query = self._prep_query(query=query,
+                                                                      language=language,
+                                                                      index=rule.contents.data.get('index', '*'),
                                                                      start_time=start_time,
                                                                      end_time=end_time)
            formatted_dsl.update(size=size or self.max_events)

            # prep for searches: msearch for kql | async search for lucene | eql client search for eql
-            if rule.contents['language'] == 'kuery':
+            if language == 'kuery':
                multi_search_rules.append(rule)
-                multi_search.append(json.dumps(
-                    {'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'}))
-                multi_search.append(json.dumps(formatted_dsl))
-            elif rule.contents['language'] == 'lucene':
+                multi_search.append({'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'})
+                multi_search.append(formatted_dsl)
+            elif language == 'lucene':
                # wait for 0 to try and force async with no immediate results (not guaranteed)
-                result = async_client.submit(body=formatted_dsl, q=rule.query, index=index_str,
+                result = async_client.submit(body=formatted_dsl, q=query, index=index_str,
                                             allow_no_indices=True, ignore_unavailable=True,
                                             wait_for_completion_timeout=0)
                if result['is_running'] is True:
-                    async_searches[rule] = result['id']
+                    async_searches.append((rule, result['id']))
                else:
-                    survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields,
+                    survey_results[rule.id] = parse_unique_field_results(rule_type, ['process.name'],
                                                                         result['response'])
-            elif rule.contents['language'] == 'eql':
+            elif language == 'eql':
                eql_body = {
                    'index': index_str,
                    'params': {'ignore_unavailable': 'true', 'allow_no_indices': 'true'},
-                    'body': {'query': rule.query, 'filter': formatted_dsl['filter']}
+                    'body': {'query': query, 'filter': formatted_dsl['filter']}
                }
-                eql_searches[rule] = eql_body
+                eql_searches.append((rule, eql_body))

        # assemble search results
-        multi_search_results = self.client.msearch('\n'.join(multi_search) + '\n')
+        multi_search_results = self.client.msearch(searches=multi_search)
        for index, result in enumerate(multi_search_results['responses']):
            try:
                rule = multi_search_rules[index]
-                survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
+                survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
+                                                                     rule.contents.data.unique_fields, result)
            except KeyError:
                survey_results[multi_search_rules[index].id] = {'error_retrieving_results': True}

-        for rule, search_args in eql_searches.items():
+        for entry in eql_searches:
+            rule: TOMLRule
+            search_args: dict
+            rule, search_args = entry
            try:
                result = self.client.eql.search(**search_args)
-                survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
+                survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
+                                                                     rule.contents.data.unique_fields, result)
            except (elasticsearch.NotFoundError, elasticsearch.RequestError) as e:
                survey_results[rule.id] = {'error_retrieving_results': True, 'error': e.info['error']['reason']}

-        for rule, async_id in async_searches.items():
-            result = async_client.get(async_id)['response']
-            survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
+        for entry in async_searches:
+            rule: TOMLRule
+            rule, async_id = entry
+            result = async_client.get(id=async_id)['response']
+            survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ['process.name'], result)

        return survey_results

Expand All @@ -267,19 +277,21 @@ def count(self, query, language, index: Union[str, list], start_time=None, end_t
return self.client.count(body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True,
ignore_unavailable=True)['count']

def count_from_rule(self, *rules, start_time=None, end_time='now'):
def count_from_rule(self, rules: RuleCollection, start_time=None, end_time='now'):
"""Get a count of documents from elasticsearch using a rule."""
survey_results = {}

for rule in rules:
for rule in rules.rules:
rule_results = {'rule_id': rule.id, 'name': rule.name}

if not rule.query:
if not rule.contents.data.get('query'):
continue

try:
rule_results['search_count'] = self.count(query=rule.query, language=rule.contents.get('language'),
index=rule.contents.get('index', '*'), start_time=start_time,
rule_results['search_count'] = self.count(query=rule.contents.data.query,
language=rule.contents.data.language,
index=rule.contents.data.get('index', '*'),
start_time=start_time,
end_time=end_time)
except (elasticsearch.NotFoundError, elasticsearch.RequestError):
rule_results['search_count'] = -1
Expand Down
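A minimal sketch of the now module-level parse_unique_field_results, using a stubbed response for a non-EQL rule (the field names and hit data are illustrative):

from detection_rules.eswrap import parse_unique_field_results

search_results = {
    'hits': {'hits': [
        {'_source': {'process': {'name': 'cmd.exe'}}},
        {'_source': {'process': {'name': 'cmd.exe'}}},
        {'_source': {'process': {'name': 'powershell.exe'}}},
        {'_source': {}},  # hits missing the field are skipped by the new guard
    ]}
}

results = parse_unique_field_results('query', ['process.name'], search_results)
# roughly: {'results': {'process.name': {'cmd.exe': 2, 'powershell.exe': 1}}}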
38 changes: 27 additions & 11 deletions detection_rules/kbwrap.py
@@ -12,7 +12,7 @@

 from kibana import Signal, RuleResource
 from .cli_utils import multi_collection
 from .main import root
-from .misc import add_params, client_error, kibana_options, get_kibana_client
+from .misc import add_params, client_error, kibana_options, get_kibana_client, nested_set
 from .schemas import downgrade
 from .utils import format_command_options

@@ -82,8 +82,9 @@ def upload_rule(ctx, rules, replace_id):
 @click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search')
 @click.option('--columns', '-c', multiple=True, help='Columns to display in table')
 @click.option('--extend', '-e', is_flag=True, help='If columns are specified, extend the original columns')
+@click.option('--max-count', '-m', default=100, help='The max number of alerts to return')
 @click.pass_context
-def search_alerts(ctx, query, date_range, columns, extend):
+def search_alerts(ctx, query, date_range, columns, extend, max_count):
     """Search detection engine alerts with KQL."""
     from eql.table import Table
     from .eswrap import MATCH_ALL, add_range_to_dsl
@@ -94,15 +95,30 @@ def search_alerts(ctx, query, date_range, columns, extend):
    add_range_to_dsl(kql_query['bool'].setdefault('filter', []), start_time, end_time)

    with kibana:
-        alerts = [a['_source'] for a in Signal.search({'query': kql_query})['hits']['hits']]
+        alerts = [a['_source'] for a in Signal.search({'query': kql_query}, size=max_count)['hits']['hits']]

-        table_columns = ['host.hostname', 'rule.name', '@timestamp']
-
    # check for events with nested signal fields
-    if alerts and 'signal' in alerts[0]:
-        table_columns = ['host.hostname', 'signal.rule.name', 'signal.status', 'signal.original_time']
-    if columns:
-        columns = list(columns)
-        table_columns = table_columns + columns if extend else columns
-    click.echo(Table.from_list(table_columns, alerts))
+    if alerts:
+        table_columns = ['host.hostname']
+
+        if 'signal' in alerts[0]:
+            table_columns += ['signal.rule.name', 'signal.status', 'signal.original_time']
+        elif 'kibana.alert.rule.name' in alerts[0]:
+            table_columns += ['kibana.alert.rule.name', 'kibana.alert.status', 'kibana.alert.original_time']
+        else:
+            table_columns += ['rule.name', '@timestamp']
+        if columns:
+            columns = list(columns)
+            table_columns = table_columns + columns if extend else columns
+
+        # Table requires the data to be nested, but depending on the version, some data uses dotted keys, so
+        # they must be nested explicitly
+        for alert in alerts:
+            for key in table_columns:
+                if key in alert:
+                    nested_set(alert, key, alert[key])
+
+        click.echo(Table.from_list(table_columns, alerts))
+    else:
+        click.echo('No alerts detected')
    return alerts
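A minimal sketch of the dotted-key nesting step above, assuming an alert that comes back with flat kibana.alert.* keys (the alert data here is illustrative):

from detection_rules.misc import nested_set

alert = {'kibana.alert.rule.name': 'Suspicious Activity', 'host.hostname': 'web-01'}

# nested_set leaves the original dotted key in place and adds a nested copy,
# which is the shape eql.table.Table can render.
for key in ['kibana.alert.rule.name', 'host.hostname']:
    if key in alert:
        nested_set(alert, key, alert[key])

# alert now also contains:
# {'kibana': {'alert': {'rule': {'name': 'Suspicious Activity'}}},
#  'host': {'hostname': 'web-01'}, ...}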
4 changes: 2 additions & 2 deletions detection_rules/mappings.py
@@ -13,12 +13,12 @@

 RTA_DIR = get_path("rta")


-class RtaMappings(object):
+class RtaMappings:
     """Rta-mapping helper class."""

     def __init__(self):
         """Rta-mapping validation and prep."""
-        self.mapping = load_etc_dump('rule-mapping.yml')  # type: dict
+        self.mapping: dict = load_etc_dump('rule-mapping.yml')
         self.validate()

         self._rta_mapping = defaultdict(list)
17 changes: 16 additions & 1 deletion detection_rules/misc.py
@@ -89,7 +89,7 @@ def nested_get(_dict, dot_key, default=None):


 def nested_set(_dict, dot_key, value):
-    """Set a nested field from a a key in dot notation."""
+    """Set a nested field from a key in dot notation."""
     keys = dot_key.split('.')
     for key in keys[:-1]:
         _dict = _dict.setdefault(key, {})
@@ -100,6 +100,21 @@ def nested_set(_dict, dot_key, value):
        raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))


+def nest_from_dot(dots, value):
+    """Nest a dotted field and set the innermost value."""
+    fields = dots.split('.')
+
+    if not fields:
+        return {}
+
+    nested = {fields.pop(): value}
+
+    for field in reversed(fields):
+        nested = {field: nested}
+
+    return nested
+
+
 def schema_prompt(name, value=None, required=False, **options):
     """Interactively prompt based on schema requirements."""
     name = str(name)
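A quick sketch of nest_from_dot, which moved here from ecs.py unchanged apart from the docstring (the values are illustrative):

from detection_rules.misc import nest_from_dot

print(nest_from_dot('event.category', 'process'))
# {'event': {'category': 'process'}}
print(nest_from_dot('agent.version', '8.4.0'))
# {'agent': {'version': '8.4.0'}}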
2 changes: 1 addition & 1 deletion detection_rules/rule.py
@@ -338,7 +338,7 @@ class QueryValidator:

     @property
     def ast(self) -> Any:
-        raise NotImplementedError
+        raise NotImplementedError()

     @property
     def unique_fields(self) -> Any:
3 changes: 3 additions & 0 deletions kibana/connector.py
@@ -43,6 +43,9 @@ def __init__(self, cloud_id=None, kibana_url=None, verify=True, elasticsearch=No
            self.domain, self.es_uuid, self.kibana_uuid = \
                base64.b64decode(cloud_info.encode("utf-8")).decode("utf-8").split("$")

+            if self.domain.endswith(':443'):
+                self.domain = self.domain[:-4]
+
            kibana_url_from_cloud = f"https://{self.kibana_uuid}.{self.domain}:9243"
            if self.kibana_url and self.kibana_url != kibana_url_from_cloud:
                raise ValueError(f'kibana_url provided ({self.kibana_url}) does not match url derived from cloud_id '
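A minimal sketch of why the :443 strip matters, using a fabricated cloud_id payload (the domain and UUIDs are illustrative; the '$'-delimited layout is taken from the decode above):

import base64

# cloud_info decodes to 'domain$es_uuid$kibana_uuid'; some deployments embed ':443' in the domain
cloud_info = base64.b64encode(b'example.cloud.es.io:443$es-uuid$kib-uuid').decode('utf-8')
domain, es_uuid, kibana_uuid = base64.b64decode(cloud_info.encode('utf-8')).decode('utf-8').split('$')

if domain.endswith(':443'):
    domain = domain[:-4]  # drop ':443' so the derived URL carries only the explicit :9243 port

print(f'https://{kibana_uuid}.{domain}:9243')
# https://kib-uuid.example.cloud.es.io:9243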