Skip to content

Commit

Permalink
Graph search can now select target vertices based on root_type (previ…
Browse files Browse the repository at this point in the history
…ously only leaf types) (#1065)
  • Loading branch information
tomchop authored Apr 19, 2024
1 parent f0df5c4 commit 15d5344
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
7 changes: 5 additions & 2 deletions core/database_arango.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,8 @@ def neighbors(
count: int = 0,
) -> tuple[
dict[
str, "observable.Observable | entity.Entity | indicator.Indicator | tag.Tag"
str,
"observable.ObservableTypes | entity.EntityTypes | indicator.IndicatorTypes | tag.Tag",
],
List[List["Relationship | TagRelationship"]],
int,
Expand Down Expand Up @@ -582,7 +583,9 @@ def neighbors(
query_filter = "FILTER e.type IN @link_types"
if target_types:
args["target_types"] = target_types
query_filter = "FILTER v.type IN @target_types"
query_filter = (
"FILTER (v.type IN @target_types OR v.root_type IN @target_types)"
)

limit = ""
if count != 0:
Expand Down
45 changes: 45 additions & 0 deletions tests/apiv2/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,51 @@ def test_neighbors_strongly_typed(self):
self.assertEqual(neighbor["query_type"], "opensearch")
self.assertEqual(neighbor["target_systems"], ["system1"])

def test_neighbors_target_types(self):
self.entity1.link_to(self.observable1, "uses", "asd")
self.entity1.link_to(self.observable2, "uses", "asd")
response = client.post(
"/api/v2/graph/search",
json={
"source": self.entity1.extended_id,
"hops": 1,
"graph": "links",
"direction": "any",
"target_types": ["hostname"],
"include_original": False,
},
)
data = response.json()
self.assertEqual(response.status_code, 200, data)
self.assertEqual(len(data["vertices"]), 1)
self.assertEqual(
data["vertices"][self.observable1.extended_id]["value"], "tomchop.me"
)

def test_neighbors_target_types_root_type(self):
self.entity1.link_to(self.observable1, "uses", "asd")
self.entity1.link_to(self.observable2, "uses", "asd")
response = client.post(
"/api/v2/graph/search",
json={
"source": self.entity1.extended_id,
"hops": 1,
"graph": "links",
"direction": "any",
"target_types": ["observable"],
"include_original": False,
},
)
data = response.json()
self.assertEqual(response.status_code, 200, data)
self.assertEqual(len(data["vertices"]), 2)
self.assertEqual(
data["vertices"][self.observable1.extended_id]["value"], "tomchop.me"
)
self.assertEqual(
data["vertices"][self.observable2.extended_id]["value"], "127.0.0.1"
)

def test_add_link(self):
response = client.post(
"/api/v2/graph/add",
Expand Down

0 comments on commit 15d5344

Please sign in to comment.