Skip to content

Commit

Permalink
Feature: Ability to sort and filter graph traversal (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchop authored May 2, 2024
1 parent 48d6f1f commit 3303b9d
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 7 deletions.
31 changes: 29 additions & 2 deletions core/database_arango.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

if TYPE_CHECKING:
from core.schemas import entity, indicator, observable
from core.schemas.graph import Relationship, RelationshipTypes, TagRelationship
from core.schemas.graph import (
GraphFilter,
Relationship,
RelationshipTypes,
TagRelationship,
)
from core.schemas.tag import Tag

import requests
Expand Down Expand Up @@ -211,7 +216,12 @@ def _update(self, document_json):
self._get_collection().update_match(filters, document)

logging.debug(f"filters: {filters}")
newdoc = list(self._get_collection().find(filters, limit=1))[0]
try:
newdoc = list(self._get_collection().find(filters, limit=1))[0]
except IndexError as exception:
msg = f"Update failed when adding {document_json}: {exception}"
logging.error(msg)
raise RuntimeError(msg)

newdoc["__id"] = newdoc.pop("_key")
return newdoc
Expand Down Expand Up @@ -544,11 +554,13 @@ def neighbors(
target_types: List[str] = [],
direction: str = "any",
graph: str = "links",
filter: List["GraphFilter"] = [],
include_original: bool = False,
min_hops: int = 1,
max_hops: int = 1,
offset: int = 0,
count: int = 0,
sorting: List[tuple[str, bool]] = [],
) -> tuple[
dict[
str,
Expand Down Expand Up @@ -583,6 +595,11 @@ def neighbors(
"extended_id": self.extended_id,
"@graph": graph,
}
sorts = []
for field, asc in sorting:
sorts.append(f'p.edges[0].{field} {"ASC" if asc else "DESC"}')
sorting_aql = f"SORT {', '.join(sorts)}" if sorts else ""

if link_types:
args["link_types"] = link_types
query_filter = "FILTER e.type IN @link_types"
Expand All @@ -591,6 +608,15 @@ def neighbors(
query_filter = (
"FILTER (v.type IN @target_types OR v.root_type IN @target_types)"
)
if filter:
filters = []
for i, f in enumerate(filter):
filters.append(
f"(p.edges[*].@filter_key{i} {f.operator} @filter_value{i} OR p.vertices[*].@filter_key{i} {f.operator} @filter_value{i})"
)
args[f"filter_key{i}"] = f.key
args[f"filter_value{i}"] = f.value
query_filter += f"FILTER {' OR '.join(filters)}"

limit = ""
if count != 0:
Expand All @@ -613,6 +639,7 @@ def neighbors(
RETURN MERGE(observable, {{tags: MERGE(innertags)}})
)
{limit}
{sorting_aql}
RETURN {{ vertices: v_with_tags, g: p }}
"""
cursor = self._db.aql.execute(aql, bind_vars=args, count=True, full_count=True)
Expand Down
3 changes: 2 additions & 1 deletion core/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
from core.schemas import entity, indicator, observable, tag
from core.schemas.graph import Relationship, TagRelationship
from core.schemas.graph import GraphFilter, Relationship, TagRelationship

TYetiObject = TypeVar("TYetiObject")

Expand Down Expand Up @@ -103,6 +103,7 @@ def neighbors(
target_types: List[str] = [],
direction: str = "any",
graph: str = "links",
filter: List["GraphFilter"] = [],
include_original: bool = False,
min_hops: int = 1,
max_hops: int = 1,
Expand Down
9 changes: 8 additions & 1 deletion core/schemas/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@

from core import database_arango


# Database model


class GraphFilter(BaseModel):
"""A single key / operator / value condition used to filter graph traversals.

The neighbors() query builder in core/database_arango.py turns each filter
into a condition that is tested against both the edges and the vertices of
a traversed path; several filters are combined with OR.
"""
# Name of the attribute to test; passed to the query as a bind variable.
key: str
# Value the attribute is compared against; passed as a bind variable.
value: str
# Comparison operator placed between key and value (the tests use "=~").
# Unlike key and value, it is interpolated verbatim into the AQL text.
# NOTE(review): nothing here validates it, and the graph search API accepts
# filters from the request body -- consider restricting it to an allowlist.
operator: str


# Relationship and TagRelationship do not inherit from YetiModel
# because they represent an id in the form of collection_name/id
class Relationship(BaseModel, database_arango.ArangoYetiConnector):
Expand Down
5 changes: 5 additions & 0 deletions core/web/apiv2/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic.functional_validators import field_validator

from core.schemas import dfiq, entity, graph, indicator, observable, tag
from core.schemas.graph import GraphFilter
from core.schemas.observable import ObservableType

GRAPH_TYPE_MAPPINGS = {} # type: dict[str, Type[entity.Entity] | Type[observable.Observable] | Type[indicator.Indicator]]
Expand Down Expand Up @@ -34,9 +35,11 @@ class GraphSearchRequest(BaseModel):
max_hops: int | None = None
graph: str
direction: GraphDirection
filter: list[GraphFilter] = []
include_original: bool
count: int = 50
page: int = 0
sorting: list[tuple[str, bool]] = []

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -134,12 +137,14 @@ async def search(request: GraphSearchRequest) -> GraphSearchResponse:
link_types=request.link_types,
target_types=request.target_types,
direction=request.direction,
filter=request.filter,
include_original=request.include_original,
graph=request.graph,
min_hops=request.min_hops or request.hops,
max_hops=request.max_hops or request.hops,
count=request.count,
offset=request.page * request.count,
sorting=request.sorting,
)
return GraphSearchResponse(vertices=vertices, paths=paths, total=total)

Expand Down
2 changes: 1 addition & 1 deletion core/web/apiv2/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def bulk_add(request: NewBulkObservableAddRequest) -> BulkObservableAddRes
observable = cls(value=new_observable.value).save()
if new_observable.tags:
observable = observable.tag(new_observable.tags)
except ValueError:
except (ValueError, RuntimeError):
response.failed.append(new_observable.value)
continue
response.added.append(observable)
Expand Down
77 changes: 75 additions & 2 deletions tests/schemas/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

from core import database_arango
from core.schemas.entity import Malware
from core.schemas.graph import Relationship
from core.schemas.graph import GraphFilter, Relationship
from core.schemas.observables import hostname, ipv4
from core.web import webapp

client = TestClient(webapp.app)


class ObservableTest(unittest.TestCase):
class GraphTest(unittest.TestCase):
def setUp(self) -> None:
database_arango.db.connect(database="yeti_test")
database_arango.db.clear()
Expand Down Expand Up @@ -130,3 +130,76 @@ def test_neighbors_target_types(self):
vertices, edges, edge_count = self.observable1.neighbors()
self.assertEqual(len(vertices), 2)
self.assertEqual(edge_count, 2)

def test_neighbors_filter(self):
"""Tests that neighbors() results can be narrowed with a GraphFilter.

One filter matches on an edge attribute (description), another on a
vertex attribute (value); an unfiltered call must still return both
neighbors.
"""
# Two direct links from observable1, each with a distinct edge description.
self.observable1.link_to(self.observable2, "a", "description_aaaa")
self.observable1.link_to(self.observable3, "c", "description_ccc")

# Filter on edge description: only the link to observable3 contains "_ccc".
vertices, edges, edge_count = self.observable1.neighbors(
filter=[GraphFilter(key="description", value="_ccc", operator="=~")]
)
self.assertEqual(len(vertices), 1)
self.assertEqual(edge_count, 1)
self.assertEqual(
vertices[self.observable3.extended_id].value, self.observable3.value
)

# Filter on vertex value: expected to match observable3 only.
# (presumably observable3's value contains "8.8"; it is created in setUp,
# which is not visible here -- verify)
vertices, edges, edge_count = self.observable1.neighbors(
filter=[GraphFilter(key="value", value="8.8", operator="=~")]
)
self.assertEqual(len(vertices), 1)
self.assertEqual(edge_count, 1)
self.assertEqual(
vertices[self.observable3.extended_id].value, self.observable3.value
)

# Control: without a filter both neighbors come back.
vertices, edges, edge_count = self.observable1.neighbors()
self.assertEqual(len(vertices), 2)
self.assertEqual(edge_count, 2)

def test_neighbors_filter_two_hops(self):
"""Tests that GraphFilter applies to multi-hop traversals.

A filter that matches only the last edge or last vertex of a two-hop
path must return that whole path (both edges) and nothing else.
"""
# Graph under test:
#   observable1 -> observable2 -> observable3
#                  observable2 -> observable4 (1.1.1.1)
self.observable1.link_to(self.observable2, "a", "description_aaaa_to_b")
self.observable2.link_to(self.observable3, "b", "description_bbbb_to_c")
observable4 = ipv4.IPv4(value="1.1.1.1").save()
self.observable2.link_to(observable4, "c", "description_bbbb_to_d")

# Unfiltered baseline: three vertices reached through three paths.
vertices, edges, edge_count = self.observable1.neighbors(min_hops=1, max_hops=2)
self.assertEqual(len(vertices), 3)
self.assertEqual(len(edges), 3)
self.assertEqual(len(edges[0]), 1)  # First hop counts as a path
self.assertEqual(len(edges[1]), 2)  # First two-hop path
self.assertEqual(len(edges[2]), 2)  # Second two-hop path

# Filter on an edge description that only the observable2 -> observable4
# link carries; the single returned path still includes the first hop.
vertices, edges, edge_count = self.observable1.neighbors(
min_hops=1,
max_hops=2,
filter=[GraphFilter(key="description", value="bbbb_to_d", operator="=~")],
)
self.assertEqual(len(vertices), 2)
self.assertEqual(len(edges), 1)
self.assertEqual(len(edges[0]), 2)
self.assertEqual(edges[0][0].source, self.observable1.extended_id)
self.assertEqual(edges[0][0].target, self.observable2.extended_id)
self.assertEqual(edges[0][1].source, self.observable2.extended_id)
self.assertEqual(edges[0][1].target, observable4.extended_id)
self.assertEqual(edges[0][0].description, "description_aaaa_to_b")
self.assertEqual(edges[0][1].description, "description_bbbb_to_d")

# Filter on a vertex value (observable4's) with a larger hop budget; the
# same two-hop path is expected.
vertices, edges, edge_count = self.observable1.neighbors(
min_hops=1,
max_hops=3,
filter=[GraphFilter(key="value", value="1.1.1.1", operator="=~")],
)
self.assertEqual(len(vertices), 2)
self.assertEqual(len(edges), 1)  # Only one path, two-hops
self.assertEqual(len(edges[0]), 2)
self.assertEqual(edges[0][0].source, self.observable1.extended_id)
self.assertEqual(edges[0][0].target, self.observable2.extended_id)
self.assertEqual(edges[0][1].source, self.observable2.extended_id)
self.assertEqual(edges[0][1].target, observable4.extended_id)
self.assertEqual(edges[0][0].description, "description_aaaa_to_b")
self.assertEqual(edges[0][1].description, "description_bbbb_to_d")

0 comments on commit 3303b9d

Please sign in to comment.