Skip to content

Commit

Permalink
[Bug] Query validation failing to capture InSet edge case with ip fie…
Browse files Browse the repository at this point in the history
…ld types (#3572)

* Move test case to separate file

---------

Co-authored-by: Mika Ayenson <Mikaayenson@users.noreply.github.com>
Co-authored-by: shashank-elastic <91139415+shashank-elastic@users.noreply.github.com>

(cherry picked from commit a4a0bc6)
  • Loading branch information
eric-forte-elastic authored and github-actions[bot] committed May 6, 2024
1 parent e2a040b commit 7bd349b
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 51 deletions.
75 changes: 73 additions & 2 deletions detection_rules/rule_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
# 2.0.

"""Validation logic for rules containing queries."""
from functools import cached_property
from typing import List, Optional, Tuple, Union
from enum import Enum
from functools import cached_property, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import eql
from eql import ast
from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint
from eql.parser import _parse as base_parse
from marshmallow import ValidationError
from semver import Version

Expand All @@ -31,6 +35,73 @@
KQL_ERROR_TYPES = Union[kql.KqlCompileError, kql.KqlParseError]


class ExtendedTypeHint(Enum):
IP = "ip"

@classmethod
def primitives(cls):
"""Get all primitive types."""
return TypeHint.Boolean, TypeHint.Numeric, TypeHint.Null, TypeHint.String, ExtendedTypeHint.IP

def is_primitive(self):
"""Check if a type is a primitive."""
return self in self.primitives()


def custom_in_set(self, node: KvTree) -> NodeInfo:
"""Override and address the limitations of the eql in_set method."""
# return BaseInSetMethod(self, node)
outer, container = self.visit(node.child_trees) # type: (NodeInfo, list[NodeInfo])

if not outer.validate_type(ExtendedTypeHint.primitives()):
# can't compare non-primitives to sets
raise self._type_error(outer, ExtendedTypeHint.primitives())

# Check that everything inside the container has the same type as outside
error_message = "Unable to compare {expected_type} to {actual_type}"
for inner in container:
if not inner.validate_type(outer):
raise self._type_error(inner, outer, error_message)

if self._elasticsearch_syntax and hasattr(outer, "type_info"):
# Check edge case of in_set and ip/string comparison
outer_type = outer.type_info
if isinstance(self._schema, ecs.KqlSchema2Eql):
type_hint = self._schema.kql_schema.get(str(outer.node), "unknown")
if hasattr(self._schema, "type_mapping") and type_hint == "ip":
outer.type_info = ExtendedTypeHint.IP
for inner in container:
if not inner.validate_type(outer):
raise self._type_error(inner, outer, error_message)

# reset the type
outer.type_info = outer_type

# This will always evaluate to true/false, so it should be a boolean
term = ast.InSet(outer.node, [c.node for c in container])
nullable = outer.nullable or any(c.nullable for c in container)
return NodeInfo(term, TypeHint.Boolean, nullable=nullable, source=node)


def custom_base_parse_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
"""Override and address the limitations of the eql in_set method."""

@wraps(func)
def wrapper(query: str, start: Optional[str] = None, **kwargs: Dict[str, Any]) -> Any:
original_in_set = LarkToEQL.in_set
LarkToEQL.in_set = custom_in_set
try:
result = func(query, start=start, **kwargs)
finally: # Using finally to ensure that the original method is restored
LarkToEQL.in_set = original_in_set
return result

return wrapper


eql.parser._parse = custom_base_parse_decorator(base_parse)


class KQLValidator(QueryValidator):
"""Specific fields for KQL query event types."""

Expand Down
68 changes: 68 additions & 0 deletions tests/test_python_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

from detection_rules.rule_loader import RuleCollection

from .base import BaseRuleTest


class TestEQLInSet(BaseRuleTest):
"""Test EQL rule query in set override."""

def test_eql_in_set(self):
"""Test that the query validation is working correctly."""
rc = RuleCollection()
eql_rule = {
"metadata": {
"creation_date": "2020/12/15",
"integration": ["endpoint", "windows"],
"maturity": "production",
"min_stack_comments": "New fields added: required_fields, related_integrations, setup",
"min_stack_version": "8.3.0",
"updated_date": "2024/03/26",
},
"rule": {
"author": ["Elastic"],
"description": """
Test Rule.
""",
"false_positives": ["Fake."],
"from": "now-9m",
"index": ["winlogbeat-*", "logs-endpoint.events.*", "logs-windows.sysmon_operational-*"],
"language": "eql",
"license": "Elastic License v2",
"name": "Fake Test Rule",
"references": [
"https://example.com",
],
"risk_score": 47,
"rule_id": "4fffae5d-8b7d-4e48-88b1-979ed42fd9a3",
"severity": "medium",
"tags": [
"Domain: Endpoint",
"OS: Windows",
"Use Case: Threat Detection",
"Tactic: Execution",
"Data Source: Elastic Defend",
"Data Source: Sysmon",
],
"type": "eql",
"query": """
sequence by host.id, process.entity_id with maxspan = 5s
[network where destination.ip in ("127.0.0.1", "::1")]
""",
},
}
expected_error_message = r"Error in both stack and integrations checks:.*Unable to compare ip to string.*"
with self.assertRaisesRegex(ValueError, expected_error_message):
rc.load_dict(eql_rule)
# Change to appropriate destination.address field
eql_rule["rule"][
"query"
] = """
sequence by host.id, process.entity_id with maxspan = 10s
[network where destination.address in ("192.168.1.1", "::1")]
"""
rc.load_dict(eql_rule)
124 changes: 75 additions & 49 deletions tests/test_specific_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

import kql
from detection_rules.integrations import (
find_latest_compatible_version, load_integrations_manifests, load_integrations_schemas
find_latest_compatible_version,
load_integrations_manifests,
load_integrations_schemas,
)
from detection_rules.misc import load_current_package_version
from detection_rules.packaging import current_stack_version
Expand All @@ -23,31 +25,34 @@
from detection_rules.utils import get_path, load_rule_contents

from .base import BaseRuleTest

PACKAGE_STACK_VERSION = Version.parse(current_stack_version(), optional_minor_and_patch=True)


class TestEndpointQuery(BaseRuleTest):
"""Test endpoint-specific rules."""

@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"),
"Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.3.0"),
"Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0",
)
def test_os_and_platform_in_query(self):
"""Test that all endpoint rules have an os defined and linux includes platform."""
for rule in self.production_rules:
if not rule.contents.data.get('language') in ('eql', 'kuery'):
if not rule.contents.data.get("language") in ("eql", "kuery"):
continue
if rule.path.parent.name not in ('windows', 'macos', 'linux'):
if rule.path.parent.name not in ("windows", "macos", "linux"):
# skip cross-platform for now
continue

ast = rule.contents.data.ast
fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field, eql.ast.Field))]

err_msg = f'{self.rule_str(rule)} missing required field for endpoint rule'
if 'host.os.type' not in fields:
err_msg = f"{self.rule_str(rule)} missing required field for endpoint rule"
if "host.os.type" not in fields:
# Exception for Forwarded Events which contain Windows-only fields.
if rule.path.parent.name == 'windows' and not any(field.startswith('winlog.') for field in fields):
self.assertIn('host.os.type', fields, err_msg)
if rule.path.parent.name == "windows" and not any(field.startswith("winlog.") for field in fields):
self.assertIn("host.os.type", fields, err_msg)

# going to bypass this for now
# if rule.path.parent.name == 'linux':
Expand All @@ -58,48 +63,59 @@ def test_os_and_platform_in_query(self):
class TestNewTerms(BaseRuleTest):
"""Test new term rules."""

@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_history_window_start(self):
"""Test new terms history window start field."""

for rule in self.production_rules:
if rule.contents.data.type == "new_terms":

# validate history window start field exists and is correct
assert rule.contents.data.new_terms.history_window_start, \
"new terms field found with no history_window_start field defined"
assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \
f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"

@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
assert (
rule.contents.data.new_terms.history_window_start
), "new terms field found with no history_window_start field defined"
assert (
rule.contents.data.new_terms.history_window_start[0].field == "history_window_start"
), f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"

@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_field_exists(self):
# validate new terms and history window start fields are correct
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
assert rule.contents.data.new_terms.field == "new_terms_fields", \
f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"
assert (
rule.contents.data.new_terms.field == "new_terms_fields"
), f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"

@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_fields(self):
"""Test new terms fields are schema validated."""
# ecs validation
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
meta = rule.contents.metadata
feature_min_stack = Version.parse('8.4.0')
feature_min_stack = Version.parse("8.4.0")
current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
min_stack_version = Version.parse(meta.get("min_stack_version")) if \
meta.get("min_stack_version") else None
min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \
current_package_version else min_stack_version

assert min_stack_version >= feature_min_stack, \
f"New Terms rule types only compatible with {feature_min_stack}+"
ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs']
beats_version = get_stack_schemas()[str(min_stack_version)]['beats']
min_stack_version = (
Version.parse(meta.get("min_stack_version")) if meta.get("min_stack_version") else None
)
min_stack_version = (
current_package_version
if min_stack_version is None or min_stack_version < current_package_version
else min_stack_version
)

assert (
min_stack_version >= feature_min_stack
), f"New Terms rule types only compatible with {feature_min_stack}+"
ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"]
beats_version = get_stack_schemas()[str(min_stack_version)]["beats"]

# checks if new terms field(s) are in ecs, beats non-ecs or integration schemas
queryvalidator = QueryValidator(rule.contents.data.query)
Expand All @@ -113,43 +129,53 @@ def test_new_terms_fields(self):
package=tag,
integration="",
rule_stack_version=min_stack_version,
packages_manifest=integration_manifests)
packages_manifest=integration_manifests,
)
if latest_tag_compat_ver:
integration_schema = integration_schemas[tag][latest_tag_compat_ver]
for policy_template in integration_schema.keys():
schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template])
for new_terms_field in rule.contents.data.new_terms.value:
assert new_terms_field in schema.keys(), \
f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"
assert (
new_terms_field in schema.keys()
), f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"

@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_max_limit(self):
"""Test new terms max limit."""
# validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
meta = rule.contents.metadata
feature_min_stack = Version.parse('8.4.0')
feature_min_stack_extended_fields = Version.parse('8.6.0')
feature_min_stack = Version.parse("8.4.0")
feature_min_stack_extended_fields = Version.parse("8.6.0")
current_package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
min_stack_version = Version.parse(meta.get("min_stack_version")) if \
meta.get("min_stack_version") else None
min_stack_version = current_package_version if min_stack_version is None or min_stack_version < \
current_package_version else min_stack_version
min_stack_version = (
Version.parse(meta.get("min_stack_version")) if meta.get("min_stack_version") else None
)
min_stack_version = (
current_package_version
if min_stack_version is None or min_stack_version < current_package_version
else min_stack_version
)
if feature_min_stack <= min_stack_version < feature_min_stack_extended_fields:
assert len(rule.contents.data.new_terms.value) == 1, \
f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"
assert (
len(rule.contents.data.new_terms.value) == 1
), f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"

@unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"),
"Test only applicable to 8.4+ stacks for new terms feature.")
@unittest.skipIf(
PACKAGE_STACK_VERSION < Version.parse("8.6.0"), "Test only applicable to 8.4+ stacks for new terms feature."
)
def test_new_terms_fields_unique(self):
"""Test new terms fields are unique."""
# validate fields are unique
for rule in self.production_rules:
if rule.contents.data.type == "new_terms":
assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), \
f"new terms fields values are not unique - {rule.contents.data.new_terms.value}"
assert len(set(rule.contents.data.new_terms.value)) == len(
rule.contents.data.new_terms.value
), f"new terms fields values are not unique - {rule.contents.data.new_terms.value}"


class TestESQLRules(BaseRuleTest):
Expand Down

0 comments on commit 7bd349b

Please sign in to comment.