From cf4a776731fff7442f8b0356180e6653087019f2 Mon Sep 17 00:00:00 2001
From: Justin Ibarra
Date: Tue, 6 Sep 2022 15:53:47 -0600
Subject: [PATCH] Cleanup rule survey code (#1923)

* Cleanup rule survey code
* default to only unique-ing on process name for lucene rules
* fix bug in kibana url parsing by removing redundant port from domain
* update search-alerts columns and nest fields
* fix rule.contents.data.index

Co-authored-by: Mika Ayenson
(cherry picked from commit 332ea401004e436938f1413585a85b76f2852d53)
---
 detection_rules/devtools.py |  12 +++-
 detection_rules/ecs.py      |  15 -----
 detection_rules/eswrap.py   | 114 ++++++++++++++++++++----------------
 detection_rules/kbwrap.py   |  38 ++++++++----
 detection_rules/mappings.py |   4 +-
 detection_rules/misc.py     |  17 +++++-
 detection_rules/rule.py     |   2 +-
 kibana/connector.py         |   3 +
 kibana/resources.py         |  11 ++--
 9 files changed, 127 insertions(+), 89 deletions(-)

diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py
index 01460607d3e..9353de4ee35 100644
--- a/detection_rules/devtools.py
+++ b/detection_rules/devtools.py
@@ -1024,6 +1024,7 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
     """Survey rule counts."""
     from kibana.resources import Signal
     from .main import search_rules
+    # from .eswrap import parse_unique_field_results
 
     survey_results = []
     start_time, end_time = date_range
@@ -1039,15 +1040,20 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun
         click.echo(f'Saving detailed dump to: {dump_file}')
 
     collector = CollectEvents(elasticsearch_client)
-    details = collector.search_from_rule(*rules, start_time=start_time, end_time=end_time)
-    counts = collector.count_from_rule(*rules, start_time=start_time, end_time=end_time)
+    details = collector.search_from_rule(rules, start_time=start_time, end_time=end_time)
+    counts = collector.count_from_rule(rules, start_time=start_time, end_time=end_time)
 
     # add alerts
     with kibana_client:
         range_dsl = {'query': {'bool': {'filter': []}}}
         add_range_to_dsl(range_dsl['query']['bool']['filter'], start_time, end_time)
         alerts = {a['_source']['signal']['rule']['rule_id']: a['_source']
-                  for a in Signal.search(range_dsl)['hits']['hits']}
+                  for a in Signal.search(range_dsl, size=10000)['hits']['hits']}
+
+    # for alert in alerts:
+    #     rule_id = alert['signal']['rule']['rule_id']
+    #     rule = rules.id_map[rule_id]
+    #     unique_results = parse_unique_field_results(rule.contents.data.type, rule.contents.data.unique_fields, alert)
 
     for rule_id, count in counts.items():
         alert_count = len(alerts.get(rule_id, []))
diff --git a/detection_rules/ecs.py b/detection_rules/ecs.py
index e03555c134a..e3c7bfd0264 100644
--- a/detection_rules/ecs.py
+++ b/detection_rules/ecs.py
@@ -35,21 +35,6 @@ def add_field(schema, name, info):
         add_field(schema, remaining, info)
 
 
-def nest_from_dot(dots, value):
-    """Nest a dotted field and set the inner most value."""
-    fields = dots.split('.')
-
-    if not fields:
-        return {}
-
-    nested = {fields.pop(): value}
-
-    for field in reversed(fields):
-        nested = {field: nested}
-
-    return nested
-
-
 def _recursive_merge(existing, new, depth=0):
     """Return an existing dict merged into a new one."""
     for key, value in existing.items():
diff --git a/detection_rules/eswrap.py b/detection_rules/eswrap.py
index 8b9e722fa13..75bb05f09c0 100644
--- a/detection_rules/eswrap.py
+++ b/detection_rules/eswrap.py
@@ -8,7 +8,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import Union
+from typing import List, Union
 
 import click
 import elasticsearch
@@ -17,7 +17,7 @@
 import kql
 
 from .main import root
-from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client
+from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client, nested_get
 from .rule import TOMLRule
 from .rule_loader import rta_mappings, RuleCollection
 from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path
@@ -33,7 +33,23 @@ def add_range_to_dsl(dsl_filter, start_time, end_time='now'):
     )
 
 
-class RtaEvents(object):
+def parse_unique_field_results(rule_type: str, unique_fields: List[str], search_results: dict):
+    parsed_results = defaultdict(lambda: defaultdict(int))
+    hits = search_results['hits']
+    hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
+    for hit in hits:
+        for field in unique_fields:
+            match = nested_get(hit['_source'], field)
+            if not match:
+                continue
+
+            match = ','.join(sorted(match)) if isinstance(match, list) else match
+            parsed_results[field][match] += 1
+    # if rule.type == eql, structure is different
+    return {'results': parsed_results} if parsed_results else {}
+
+
+class RtaEvents:
     """Events collected from Elasticsearch."""
 
     def __init__(self, events):
@@ -64,7 +80,7 @@ def evaluate_against_rule_and_update_mapping(self, rule_id, rta_name, verbose=Tr
         """Evaluate a rule against collected events and update mapping."""
         from .utils import combine_sources, evaluate
 
-        rule = next((rule for rule in RuleCollection.default() if rule.id == rule_id), None)
+        rule = RuleCollection.default().id_map.get(rule_id)
         assert rule is not None, f"Unable to find rule with ID {rule_id}"
         merged_events = combine_sources(*self.events.values())
         filtered = evaluate(rule, merged_events)
@@ -112,7 +128,7 @@ def _build_timestamp_map(self, index_str):
 
     def _get_last_event_time(self, index_str, dsl=None):
         """Get timestamp of most recent event."""
-        last_event = self.client.search(dsl, index_str, size=1, sort='@timestamp:desc')['hits']['hits']
+        last_event = self.client.search(query=dsl, index=index_str, size=1, sort='@timestamp:desc')['hits']['hits']
         if not last_event:
             return
 
@@ -146,7 +162,7 @@ def _prep_query(query, language, index, start_time=None, end_time=None):
         elif language == 'dsl':
            formatted_dsl = {'query': query}
         else:
-            raise ValueError('Unknown search language')
+            raise ValueError(f'Unknown search language: {language}')
 
         if start_time or end_time:
             end_time = end_time or 'now'
@@ -172,84 +188,78 @@ def search(self, query, language, index: Union[str, list] = '*', start_time=None
 
         return results
 
-    def search_from_rule(self, *rules: TOMLRule, start_time=None, end_time='now', size=None):
+    def search_from_rule(self, rules: RuleCollection, start_time=None, end_time='now', size=None):
         """Search an elasticsearch instance using a rule."""
-        from .misc import nested_get
-
         async_client = AsyncSearchClient(self.client)
         survey_results = {}
-
-        def parse_unique_field_results(rule_type, unique_fields, search_results):
-            parsed_results = defaultdict(lambda: defaultdict(int))
-            hits = search_results['hits']
-            hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', [])
-            for hit in hits:
-                for field in unique_fields:
-                    match = nested_get(hit['_source'], field)
-                    match = ','.join(sorted(match)) if isinstance(match, list) else match
-                    parsed_results[field][match] += 1
-            # if rule.type == eql, structure is different
-            return {'results': parsed_results} if parsed_results else {}
 
         multi_search = []
         multi_search_rules = []
-        async_searches = {}
-        eql_searches = {}
+        async_searches = []
+        eql_searches = []
 
         for rule in rules:
-            if not rule.query:
+            if not rule.contents.data.get('query'):
                 continue
 
-            index_str, formatted_dsl, lucene_query = self._prep_query(query=rule.query,
-                                                                      language=rule.contents.get('language'),
-                                                                      index=rule.contents.get('index', '*'),
+            language = rule.contents.data.get('language')
+            query = rule.contents.data.query
+            rule_type = rule.contents.data.type
+            index_str, formatted_dsl, lucene_query = self._prep_query(query=query,
                                                                      language=language,
                                                                      index=rule.contents.data.get('index', '*'),
                                                                       start_time=start_time,
                                                                       end_time=end_time)
             formatted_dsl.update(size=size or self.max_events)
 
             # prep for searches: msearch for kql | async search for lucene | eql client search for eql
-            if rule.contents['language'] == 'kuery':
+            if language == 'kuery':
                 multi_search_rules.append(rule)
-                multi_search.append(json.dumps(
-                    {'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'}))
-                multi_search.append(json.dumps(formatted_dsl))
-            elif rule.contents['language'] == 'lucene':
+                multi_search.append({'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'})
+                multi_search.append(formatted_dsl)
+            elif language == 'lucene':
                 # wait for 0 to try and force async with no immediate results (not guaranteed)
-                result = async_client.submit(body=formatted_dsl, q=rule.query, index=index_str,
+                result = async_client.submit(body=formatted_dsl, q=query, index=index_str,
                                              allow_no_indices=True, ignore_unavailable=True,
                                              wait_for_completion_timeout=0)
                 if result['is_running'] is True:
-                    async_searches[rule] = result['id']
+                    async_searches.append((rule, result['id']))
                 else:
-                    survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields,
+                    survey_results[rule.id] = parse_unique_field_results(rule_type, ['process.name'],
                                                                          result['response'])
-            elif rule.contents['language'] == 'eql':
+            elif language == 'eql':
                 eql_body = {
                     'index': index_str,
                     'params': {'ignore_unavailable': 'true', 'allow_no_indices': 'true'},
-                    'body': {'query': rule.query, 'filter': formatted_dsl['filter']}
+                    'body': {'query': query, 'filter': formatted_dsl['filter']}
                 }
-                eql_searches[rule] = eql_body
+                eql_searches.append((rule, eql_body))
 
         # assemble search results
-        multi_search_results = self.client.msearch('\n'.join(multi_search) + '\n')
+        multi_search_results = self.client.msearch(searches=multi_search)
         for index, result in enumerate(multi_search_results['responses']):
             try:
                 rule = multi_search_rules[index]
-                survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
+                survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
+                                                                     rule.contents.data.unique_fields, result)
             except KeyError:
                 survey_results[multi_search_rules[index].id] = {'error_retrieving_results': True}
 
-        for rule, search_args in eql_searches.items():
+        for entry in eql_searches:
+            rule: TOMLRule
+            search_args: dict
+            rule, search_args = entry
             try:
                 result = self.client.eql.search(**search_args)
-                survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
+                survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type,
+                                                                     rule.contents.data.unique_fields, result)
             except (elasticsearch.NotFoundError, elasticsearch.RequestError) as e:
                 survey_results[rule.id] = {'error_retrieving_results': True, 'error': e.info['error']['reason']}
 
-        for rule, async_id in async_searches.items():
-            result = async_client.get(async_id)['response']
-            survey_results[rule.id] = parse_unique_field_results(rule.type, rule.unique_fields, result)
+        for entry in async_searches:
+            rule: TOMLRule
+            rule, async_id = entry
+            result = async_client.get(id=async_id)['response']
+            survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ['process.name'], result)
 
         return survey_results
 
@@ -267,19 +277,21 @@ def count(self, query, language, index: Union[str, list], start_time=None, end_t
         return self.client.count(body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True,
                                  ignore_unavailable=True)['count']
 
-    def count_from_rule(self, *rules, start_time=None, end_time='now'):
+    def count_from_rule(self, rules: RuleCollection, start_time=None, end_time='now'):
         """Get a count of documents from elasticsearch using a rule."""
         survey_results = {}
 
-        for rule in rules:
+        for rule in rules.rules:
             rule_results = {'rule_id': rule.id, 'name': rule.name}
 
-            if not rule.query:
+            if not rule.contents.data.get('query'):
                 continue
 
             try:
-                rule_results['search_count'] = self.count(query=rule.query, language=rule.contents.get('language'),
-                                                          index=rule.contents.get('index', '*'), start_time=start_time,
+                rule_results['search_count'] = self.count(query=rule.contents.data.query,
+                                                          language=rule.contents.data.language,
+                                                          index=rule.contents.data.get('index', '*'),
+                                                          start_time=start_time,
                                                           end_time=end_time)
             except (elasticsearch.NotFoundError, elasticsearch.RequestError):
                 rule_results['search_count'] = -1
diff --git a/detection_rules/kbwrap.py b/detection_rules/kbwrap.py
index b13f7d04609..ae1b63bc765 100644
--- a/detection_rules/kbwrap.py
+++ b/detection_rules/kbwrap.py
@@ -12,7 +12,7 @@
 from kibana import Signal, RuleResource
 
 from .cli_utils import multi_collection
 from .main import root
-from .misc import add_params, client_error, kibana_options, get_kibana_client
+from .misc import add_params, client_error, kibana_options, get_kibana_client, nested_set
 from .schemas import downgrade
 from .utils import format_command_options
@@ -82,8 +82,9 @@ def upload_rule(ctx, rules, replace_id):
 @click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search')
 @click.option('--columns', '-c', multiple=True, help='Columns to display in table')
 @click.option('--extend', '-e', is_flag=True, help='If columns are specified, extend the original columns')
+@click.option('--max-count', '-m', default=100, help='The max number of alerts to return')
 @click.pass_context
-def search_alerts(ctx, query, date_range, columns, extend):
+def search_alerts(ctx, query, date_range, columns, extend, max_count):
     """Search detection engine alerts with KQL."""
     from eql.table import Table
     from .eswrap import MATCH_ALL, add_range_to_dsl
@@ -94,15 +95,30 @@ def search_alerts(ctx, query, date_range, columns, extend):
     add_range_to_dsl(kql_query['bool'].setdefault('filter', []), start_time, end_time)
 
     with kibana:
-        alerts = [a['_source'] for a in Signal.search({'query': kql_query})['hits']['hits']]
-
-    table_columns = ['host.hostname', 'rule.name', '@timestamp']
+        alerts = [a['_source'] for a in Signal.search({'query': kql_query}, size=max_count)['hits']['hits']]
 
     # check for events with nested signal fields
-    if alerts and 'signal' in alerts[0]:
-        table_columns = ['host.hostname', 'signal.rule.name', 'signal.status', 'signal.original_time']
-    if columns:
-        columns = list(columns)
-        table_columns = table_columns + columns if extend else columns
-    click.echo(Table.from_list(table_columns, alerts))
+    if alerts:
+        table_columns = ['host.hostname']
+
+        if 'signal' in alerts[0]:
+            table_columns += ['signal.rule.name', 'signal.status', 'signal.original_time']
+        elif 'kibana.alert.rule.name' in alerts[0]:
+            table_columns += ['kibana.alert.rule.name', 'kibana.alert.status', 'kibana.alert.original_time']
+        else:
+            table_columns += ['rule.name', '@timestamp']
+        if columns:
+            columns = list(columns)
+            table_columns = table_columns + columns if extend else columns
+
+        # Table requires the data to be nested, but depending on the version, some data uses dotted keys, so
+        # they must be nested explicitly
+        for alert in alerts:
+            for key in table_columns:
+                if key in alert:
+                    nested_set(alert, key, alert[key])
+
+        click.echo(Table.from_list(table_columns, alerts))
+    else:
+        click.echo('No alerts detected')
     return alerts
diff --git a/detection_rules/mappings.py b/detection_rules/mappings.py
index fc64495c528..e67bc00dcb4 100644
--- a/detection_rules/mappings.py
+++ b/detection_rules/mappings.py
@@ -13,12 +13,12 @@
 RTA_DIR = get_path("rta")
 
 
-class RtaMappings(object):
+class RtaMappings:
     """Rta-mapping helper class."""
 
     def __init__(self):
         """Rta-mapping validation and prep."""
-        self.mapping = load_etc_dump('rule-mapping.yml')  # type: dict
+        self.mapping: dict = load_etc_dump('rule-mapping.yml')
         self.validate()
 
         self._rta_mapping = defaultdict(list)
diff --git a/detection_rules/misc.py b/detection_rules/misc.py
index 4194505cd59..bb0de69381d 100644
--- a/detection_rules/misc.py
+++ b/detection_rules/misc.py
@@ -89,7 +89,7 @@ def nested_get(_dict, dot_key, default=None):
 
 
 def nested_set(_dict, dot_key, value):
-    """Set a nested field from a a key in dot notation."""
+    """Set a nested field from a key in dot notation."""
     keys = dot_key.split('.')
     for key in keys[:-1]:
         _dict = _dict.setdefault(key, {})
@@ -100,6 +100,21 @@ def nested_set(_dict, dot_key, value):
         raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key))
 
 
+def nest_from_dot(dots, value):
+    """Nest a dotted field and set the innermost value."""
+    fields = dots.split('.')
+
+    if not fields:
+        return {}
+
+    nested = {fields.pop(): value}
+
+    for field in reversed(fields):
+        nested = {field: nested}
+
+    return nested
+
+
 def schema_prompt(name, value=None, required=False, **options):
     """Interactively prompt based on schema requirements."""
     name = str(name)
diff --git a/detection_rules/rule.py b/detection_rules/rule.py
index 307b7edbb42..98825856bdb 100644
--- a/detection_rules/rule.py
+++ b/detection_rules/rule.py
@@ -338,7 +338,7 @@ class QueryValidator:
 
     @property
     def ast(self) -> Any:
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @property
     def unique_fields(self) -> Any:
diff --git a/kibana/connector.py b/kibana/connector.py
index 7f79661cd2d..55aa3d0c4de 100644
--- a/kibana/connector.py
+++ b/kibana/connector.py
@@ -43,6 +43,9 @@ def __init__(self, cloud_id=None, kibana_url=None, verify=True, elasticsearch=No
             self.domain, self.es_uuid, self.kibana_uuid = \
                 base64.b64decode(cloud_info.encode("utf-8")).decode("utf-8").split("$")
 
+            if self.domain.endswith(':443'):
+                self.domain = self.domain[:-4]
+
             kibana_url_from_cloud = f"https://{self.kibana_uuid}.{self.domain}:9243"
             if self.kibana_url and self.kibana_url != kibana_url_from_cloud:
                 raise ValueError(f'kibana_url provided ({self.kibana_url}) does not match url derived from cloud_id '
diff --git a/kibana/resources.py b/kibana/resources.py
index 8ec45cf19ef..8855bdfe1e9 100644
--- a/kibana/resources.py
+++ b/kibana/resources.py
@@ -4,7 +4,7 @@
 # 2.0.
 
 import datetime
-from typing import List, Type
+from typing import List, Optional, Type
 
 from .connector import Kibana
 
@@ -150,8 +150,9 @@ def __init__(self):
         raise NotImplementedError("Signals can't be instantiated yet")
 
     @classmethod
-    def search(cls, query_dsl: dict):
-        return Kibana.current().post(f"{cls.BASE_URI}/search", data=query_dsl)
+    def search(cls, query_dsl: dict, size: Optional[int] = 10):
+        payload = dict(size=size, **query_dsl)
+        return Kibana.current().post(f"{cls.BASE_URI}/search", data=payload)
 
     @classmethod
     def last_signal(cls) -> (int, datetime.datetime):
@@ -179,8 +180,8 @@ def last_signal(cls) -> (int, datetime.datetime):
         return num_signals, last_seen
 
     @classmethod
-    def all(cls):
-        return cls.search({"query": {"bool": {"filter": {"match_all": {}}}}})
+    def all(cls, size: Optional[int] = 10):
+        return cls.search({"query": {"bool": {"filter": {"match_all": {}}}}}, size=size)
 
     @classmethod
     def set_status_many(cls, signal_ids: List[str], status: str) -> dict:
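
A note on the dotted-key helpers this patch consolidates in detection_rules/misc.py
(nested_set, plus nest_from_dot relocated from ecs.py): search-alerts now expands
dotted alert keys into nested dicts before rendering, because eql's Table resolves
column paths through nested dicts. Below is a minimal sketch of that behavior; it
is illustrative only, assumes the detection_rules package is importable, and the
sample alert dict is hypothetical.

    from detection_rules.misc import nest_from_dot, nested_set

    # Hypothetical alert: newer stacks can return flattened, dotted keys.
    alert = {'kibana.alert.rule.name': 'Suspicious Activity', 'host.hostname': 'web-01'}

    # Mirrors the loop added to search_alerts: expand each dotted key into
    # nested dicts so eql.table.Table can resolve the column paths.
    for key in list(alert):
        nested_set(alert, key, alert[key])

    assert alert['kibana']['alert']['rule']['name'] == 'Suspicious Activity'

    # nest_from_dot builds the same nesting as a brand-new dict instead:
    assert nest_from_dot('signal.rule.name', 'x') == {'signal': {'rule': {'name': 'x'}}}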