Cleanup rule survey code (#1923)
* Cleanup rule survey code

* default to only unique-ing on process name for lucene rules

* fix bug in kibana url parsing by removing redundant port from domain

* update search-alerts columns and nest fields

* fix rule.contents.data.index

Co-authored-by: Mika Ayenson <Mikaayenson@users.noreply.github.com>

(cherry picked from commit 332ea40)
brokensound77 authored and github-actions[bot] committed Sep 7, 2022
1 parent 5a6d953 commit dd93603
Showing 9 changed files with 127 additions and 89 deletions.
12 changes: 9 additions & 3 deletions detection_rules/devtools.py
@@ -1024,6 +1024,7 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
"""Survey rule counts."""
from kibana.resources import Signal
from .main import search_rules
# from .eswrap import parse_unique_field_results

survey_results = []
start_time, end_time = date_range
@@ -1039,15 +1040,20 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
click.echo(f'Saving detailed dump to: {dump_file}')

collector = CollectEvents(elasticsearch_client)
details = collector.search_from_rule(*rules, start_time=start_time, end_time=end_time)
counts = collector.count_from_rule(*rules, start_time=start_time, end_time=end_time)
details = collector.search_from_rule(rules, start_time=start_time, end_time=end_time)
counts = collector.count_from_rule(rules, start_time=start_time, end_time=end_time)

# add alerts
with kibana_client:
range_dsl = {'query': {'bool': {'filter': []}}}
add_range_to_dsl(range_dsl['query']['bool']['filter'], start_time, end_time)
alerts = {a['_source']['signal']['rule']['rule_id']: a['_source']
for a in Signal.search(range_dsl)['hits']['hits']}
for a in Signal.search(range_dsl, size=10000)['hits']['hits']}

# for alert in alerts:
# rule_id = alert['signal']['rule']['rule_id']
# rule = rules.id_map[rule_id]
# unique_results = parse_unique_field_results(rule.contents.data.type, rule.contents.data.unique_fields, alert)

for rule_id, count in counts.items():
alert_count = len(alerts.get(rule_id, []))
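The commented-out block above hints at where the alert handling is headed: each alert returned by Signal.search carries the originating rule id under signal.rule.rule_id, and the RuleCollection's id_map lets an alert be joined back to its rule and that rule's unique_fields. A minimal, hypothetical sketch of that lookup (names mirror the surrounding code; the alert document is made up):

from detection_rules.rule_loader import RuleCollection

# hypothetical alert _source, as returned by Signal.search(range_dsl, size=10000)['hits']['hits']
alert_source = {'signal': {'rule': {'rule_id': '00000000-0000-0000-0000-000000000000'}}}

rules = RuleCollection.default()  # or whichever collection the survey ran against
rule = rules.id_map.get(alert_source['signal']['rule']['rule_id'])
if rule is not None:
    unique_fields = rule.contents.data.unique_fields  # the fields the survey tallies per rule
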
15 changes: 0 additions & 15 deletions detection_rules/ecs.py
@@ -35,21 +35,6 @@ def add_field(schema, name, info):
add_field(schema, remaining, info)


def nest_from_dot(dots, value):
"""Nest a dotted field and set the inner most value."""
fields = dots.split('.')

if not fields:
return {}

nested = {fields.pop(): value}

for field in reversed(fields):
nested = {field: nested}

return nested


def _recursive_merge(existing, new, depth=0):
"""Return an existing dict merged into a new one."""
for key, value in existing.items():
114 changes: 63 additions & 51 deletions detection_rules/eswrap.py
@@ -8,7 +8,7 @@
import os
import time
from collections import defaultdict
from typing import Union
from typing import List, Union

import click
import elasticsearch
@@ -17,7 +17,7 @@

import kql
from .main import root
from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client
from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client, nested_get
from .rule import TOMLRule
from .rule_loader import rta_mappings, RuleCollection
from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path
@@ -33,7 +33,23 @@ def add_range_to_dsl(dsl_filter, start_time, end_time='now'):
)


class RtaEvents(object):
def parse_unique_field_results(rule_type: str, unique_fields: List[str], search_results: dict):
parsed_results = defaultdict(lambda: defaultdict(int))
hits = search_results['hits']
hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
for hit in hits:
for field in unique_fields:
match = nested_get(hit['_source'], field)
if not match:
continue

match = ','.join(sorted(match)) if isinstance(match, list) else match
parsed_results[field][match] += 1
# if rule.type == eql, structure is different
return {'results': parsed_results} if parsed_results else {}


class RtaEvents:
"""Events collected from Elasticsearch."""

def __init__(self, events):
@@ -64,7 +80,7 @@ def evaluate_against_rule_and_update_mapping(self, rule_id, rta_name, verbose=Tr
"""Evaluate a rule against collected events and update mapping."""
from .utils import combine_sources, evaluate

rule = next((rule for rule in RuleCollection.default() if rule.id == rule_id), None)
rule = RuleCollection.default().id_map.get(rule_id)
assert rule is not None, f"Unable to find rule with ID {rule_id}"
merged_events = combine_sources(*self.events.values())
filtered = evaluate(rule, merged_events)
@@ -112,7 +128,7 @@ def _build_timestamp_map(self, index_str):

def _get_last_event_time(self, index_str, dsl=None):
"""Get timestamp of most recent event."""
last_event = self.client.search(dsl, index_str, size=1, sort='@timestamp:desc')['hits']['hits']
last_event = self.client.search(query=dsl, index=index_str, size=1, sort='@timestamp:desc')['hits']['hits']
if not last_event:
return

@@ -146,7 +162,7 @@ def _prep_query(query, language, index, start_time=None, end_time=None):
elif language == 'dsl':
formatted_dsl = {'query': query}
else:
raise ValueError('Unknown search language')
raise ValueError(f'Unknown search language: {language}')

if start_time or end_time:
end_time = end_time or 'now'
@@ -172,84 +188,78 @@ def search(self, query, language, index: Union[str, list] = '*', start_time=None

return results

def search_from_rule(self, *rules: TOMLRule, start_time=None, end_time='now', size=None):
def search_from_rule(self, rules: RuleCollection, start_time=None, end_time='now', size=None):
"""Search an elasticsearch instance using a rule."""
from .misc import nested_get

async_client = AsyncSearchClient(self.client)
survey_results = {}

def parse_unique_field_results(rule_type, unique_fields, search_results):
parsed_results = defaultdict(lambda: defaultdict(int))
hits = search_results['hits']
hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
for hit in hits:
for field in unique_fields:
match = nested_get(hit['_source'], field)
match = ','.join(sorted(match)) if isinstance(match, list) else match
parsed_results[field][match] += 1
# if rule.type == eql, structure is different
return {'results': parsed_results} if parsed_results else {}

multi_search = []
multi_search_rules = []
async_searches = {}
eql_searches = {}
async_searches = []
eql_searches = []

for rule in rules:
if not rule.query:
if not rule.contents.data.get('query'):
continue

index_str, formatted_dsl, lucene_query = self._prep_query(query=rule.query,
language=rule.contents.get('language'),
index=rule.contents.get('index', '*'),
language = rule.contents.data.get('language')
query = rule.contents.data.query
rule_type = rule.contents.data.type
index_str, formatted_dsl, lucene_query = self._prep_query(query=query,
language=language,
index=rule.contents.data.get('index', '*'),
start_time=start_time,
end_time=end_time)
formatted_dsl.update(size=size or self.max_events)

# prep for searches: msearch for kql | async search for lucene | eql client search for eql
if rule.contents['language'] == 'kuery':
if language == 'kuery':
multi_search_rules.append(rule)
multi_search.append(json.dumps(
{'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'}))
multi_search.append(json.dumps(formatted_dsl))
elif rule.contents['language'] == 'lucene':
multi_search.append({'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'})
multi_search.append(formatted_dsl)
elif language == 'lucene':
# wait for 0 to try and force async with no immediate results (not guaranteed)
result = async_client.submit(body=formatted_dsl, q=rule.query, index=index_str,
result = async_client.submit(body=formatted_dsl, q=query, index=index_str,
allow_no_indices=True, ignore_unavailable=True,
wait_for_completion_timeout=0)
if result['is_running'] is True:
async_searches[rule] = result['id']
async_searches.append((rule, result['id']))
else:
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields,
survey_results[rule.id] = parse_unique_field_results(rule_type, ['process.name'],
result['response'])
elif rule.contents['language'] == 'eql':
elif language == 'eql':
eql_body = {
'index': index_str,
'params': {'ignore_unavailable': 'true', 'allow_no_indices': 'true'},
'body': {'query': rule.query, 'filter': formatted_dsl['filter']}
'body': {'query': query, 'filter': formatted_dsl['filter']}
}
eql_searches[rule] = eql_body
eql_searches.append((rule, eql_body))

# assemble search results
multi_search_results = self.client.msearch('\n'.join(multi_search) + '\n')
multi_search_results = self.client.msearch(searches=multi_search)
for index, result in enumerate(multi_search_results['responses']):
try:
rule = multi_search_rules[index]
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
rule.contents.data.unique_fields, result)
except KeyError:
survey_results[multi_search_rules[index].id] = {'error_retrieving_results': True}

for rule, search_args in eql_searches.items():
for entry in eql_searches:
rule: TOMLRule
search_args: dict
rule, search_args = entry
try:
result = self.client.eql.search(**search_args)
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
rule.contents.data.unique_fields, result)
except (elasticsearch.NotFoundError, elasticsearch.RequestError) as e:
survey_results[rule.id] = {'error_retrieving_results': True, 'error': e.info['error']['reason']}

for rule, async_id in async_searches.items():
result = async_client.get(async_id)['response']
survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
for entry in async_searches:
rule: TOMLRule
rule, async_id = entry
result = async_client.get(id=async_id)['response']
survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ['process.name'], result)

return survey_results

Expand All @@ -267,19 +277,21 @@ def count(self, query, language, index: Union[str, list], start_time=None, end_t
return self.client.count(body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True,
ignore_unavailable=True)['count']

def count_from_rule(self, *rules, start_time=None, end_time='now'):
def count_from_rule(self, rules: RuleCollection, start_time=None, end_time='now'):
"""Get a count of documents from elasticsearch using a rule."""
survey_results = {}

for rule in rules:
for rule in rules.rules:
rule_results = {'rule_id': rule.id, 'name': rule.name}

if not rule.query:
if not rule.contents.data.get('query'):
continue

try:
rule_results['search_count'] = self.count(query=rule.query, language=rule.contents.get('language'),
index=rule.contents.get('index', '*'), start_time=start_time,
rule_results['search_count'] = self.count(query=rule.contents.data.query,
language=rule.contents.data.language,
index=rule.contents.data.get('index', '*'),
start_time=start_time,
end_time=end_time)
except (elasticsearch.NotFoundError, elasticsearch.RequestError):
rule_results['search_count'] = -1
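To make the new module-level parse_unique_field_results helper concrete: for non-EQL rules it reads hits from hits.hits, resolves each unique field with nested_get, and tallies occurrences per value. A small, hypothetical call (the response document and rule type are made up for illustration):

from detection_rules.eswrap import parse_unique_field_results

search_results = {
    'hits': {
        'hits': [
            {'_source': {'process': {'name': 'cmd.exe'}}},
            {'_source': {'process': {'name': 'cmd.exe'}}},
            {'_source': {'process': {'name': 'powershell.exe'}}},
            {'_source': {'host': {'name': 'ws-01'}}},  # no process.name, so it is skipped
        ]
    }
}

results = parse_unique_field_results('query', ['process.name'], search_results)
# results -> {'results': {'process.name': {'cmd.exe': 2, 'powershell.exe': 1}}}
# (inner mappings are defaultdicts; if nothing matched, an empty dict is returned instead)
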
38 changes: 27 additions & 11 deletions detection_rules/kbwrap.py
@@ -12,7 +12,7 @@
from kibana import Signal, RuleResource
from .cli_utils import multi_collection
from .main import root
from .misc import add_params, client_error, kibana_options, get_kibana_client
from .misc import add_params, client_error, kibana_options, get_kibana_client, nested_set
from .schemas import downgrade
from .utils import format_command_options

@@ -82,8 +82,9 @@ def upload_rule(ctx, rules, replace_id):
@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search')
@click.option('--columns', '-c', multiple=True, help='Columns to display in table')
@click.option('--extend', '-e', is_flag=True, help='If columns are specified, extend the original columns')
@click.option('--max-count', '-m', default=100, help='The max number of alerts to return')
@click.pass_context
def search_alerts(ctx, query, date_range, columns, extend):
def search_alerts(ctx, query, date_range, columns, extend, max_count):
"""Search detection engine alerts with KQL."""
from eql.table import Table
from .eswrap import MATCH_ALL, add_range_to_dsl
@@ -94,15 +95,30 @@ def search_alerts(ctx, query, date_range, columns, extend):
add_range_to_dsl(kql_query['bool'].setdefault('filter', []), start_time, end_time)

with kibana:
alerts = [a['_source'] for a in Signal.search({'query': kql_query})['hits']['hits']]

table_columns = ['host.hostname', 'rule.name', '@timestamp']
alerts = [a['_source'] for a in Signal.search({'query': kql_query}, size=max_count)['hits']['hits']]

# check for events with nested signal fields
if alerts and 'signal' in alerts[0]:
table_columns = ['host.hostname', 'signal.rule.name', 'signal.status', 'signal.original_time']
if columns:
columns = list(columns)
table_columns = table_columns + columns if extend else columns
click.echo(Table.from_list(table_columns, alerts))
if alerts:
table_columns = ['host.hostname']

if 'signal' in alerts[0]:
table_columns += ['signal.rule.name', 'signal.status', 'signal.original_time']
elif 'kibana.alert.rule.name' in alerts[0]:
table_columns += ['kibana.alert.rule.name', 'kibana.alert.status', 'kibana.alert.original_time']
else:
table_columns += ['rule.name', '@timestamp']
if columns:
columns = list(columns)
table_columns = table_columns + columns if extend else columns

# Table requires the data to be nested, but depending on the version, some data uses dotted keys, so
# they must be nested explicitly
for alert in alerts:
for key in table_columns:
if key in alert:
nested_set(alert, key, alert[key])

click.echo(Table.from_list(table_columns, alerts))
else:
click.echo('No alerts detected')
return alerts
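
The re-nesting loop above exists because, as its comment notes, eql.table.Table needs nested data while newer alert documents key some fields with literal dots. A hedged sketch of that transformation with a made-up alert:

from detection_rules.misc import nested_set

alert = {'host.hostname': 'ws-01', 'kibana.alert.rule.name': 'Suspicious Child Process'}

for key in ['host.hostname', 'kibana.alert.rule.name', 'kibana.alert.status']:
    if key in alert:
        nested_set(alert, key, alert[key])

# alert now also holds the nested form the table renderer expects, e.g.
# alert['kibana']['alert']['rule']['name'] == 'Suspicious Child Process'
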
4 changes: 2 additions & 2 deletions detection_rules/mappings.py
@@ -13,12 +13,12 @@
RTA_DIR = get_path("rta")


class RtaMappings(object):
class RtaMappings:
"""Rta-mapping helper class."""

def __init__(self):
"""Rta-mapping validation and prep."""
self.mapping = load_etc_dump('rule-mapping.yml') # type: dict
self.mapping: dict = load_etc_dump('rule-mapping.yml')
self.validate()

self._rta_mapping = defaultdict(list)
17 changes: 16 additions & 1 deletion detection_rules/misc.py
@@ -89,7 +89,7 @@ def nested_get(_dict, dot_key, default=None):


def nested_set(_dict, dot_key, value):
"""Set a nested field from a a key in dot notation."""
"""Set a nested field from a key in dot notation."""
keys = dot_key.split('.')
for key in keys[:-1]:
_dict = _dict.setdefault(key, {})
@@ -100,6 +100,21 @@ def nested_set(_dict, dot_key, value):
raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))


def nest_from_dot(dots, value):
"""Nest a dotted field and set the innermost value."""
fields = dots.split('.')

if not fields:
return {}

nested = {fields.pop(): value}

for field in reversed(fields):
nested = {field: nested}

return nested


def schema_prompt(name, value=None, required=False, **options):
"""Interactively prompt based on schema requirements."""
name = str(name)
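nest_from_dot (relocated here from ecs.py) and the neighboring nested_set cover two directions of the same idea: building a fresh nested dict from a dotted path versus writing into an existing one. A quick illustration based on the bodies shown above:

from detection_rules.misc import nest_from_dot, nested_set

nest_from_dot('process.parent.name', 'explorer.exe')
# -> {'process': {'parent': {'name': 'explorer.exe'}}}

doc = {}
nested_set(doc, 'process.parent.name', 'explorer.exe')
# doc is mutated in place into the same nested shape
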
2 changes: 1 addition & 1 deletion detection_rules/rule.py
@@ -338,7 +338,7 @@ class QueryValidator:

@property
def ast(self) -> Any:
raise NotImplementedError
raise NotImplementedError()

@property
def unique_fields(self) -> Any:
3 changes: 3 additions & 0 deletions kibana/connector.py
@@ -43,6 +43,9 @@ def __init__(self, cloud_id=None, kibana_url=None, verify=True, elasticsearch=No
self.domain, self.es_uuid, self.kibana_uuid = \
base64.b64decode(cloud_info.encode("utf-8")).decode("utf-8").split("$")

if self.domain.endswith(':443'):
self.domain = self.domain[:-4]

kibana_url_from_cloud = f"https://{self.kibana_uuid}.{self.domain}:9243"
if self.kibana_url and self.kibana_url != kibana_url_from_cloud:
raise ValueError(f'kibana_url provided ({self.kibana_url}) does not match url derived from cloud_id '
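The new check addresses the "redundant port" bug called out in the commit message: some cloud_id values decode to a domain that already ends in ':443', which previously yielded a Kibana URL carrying two ports. A sketch with a made-up deployment domain and UUID:

domain = 'my-deployment.us-east-1.aws.found.io:443'  # hypothetical decoded cloud_id segment

if domain.endswith(':443'):
    domain = domain[:-4]

kibana_uuid = 'abc123'  # placeholder
kibana_url_from_cloud = f'https://{kibana_uuid}.{domain}:9243'
# before the fix: https://abc123.my-deployment.us-east-1.aws.found.io:443:9243
# after the fix:  https://abc123.my-deployment.us-east-1.aws.found.io:9243
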