Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce tags to indicator objects #1026

Merged
merged 11 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/schemas/indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from core import database_arango
from core.helpers import now
from core.schemas.model import YetiModel
from core.schemas.model import YetiTagModel


def future():
Expand Down Expand Up @@ -44,7 +44,7 @@ class DiamondModel(Enum):
victim = "victim"


class Indicator(YetiModel, database_arango.ArangoYetiConnector):
class Indicator(YetiTagModel, database_arango.ArangoYetiConnector):
_collection_name: ClassVar[str] = "indicators"
_type_filter: ClassVar[str] = ""
_root_type: Literal["indicator"] = "indicator"
Expand Down
49 changes: 41 additions & 8 deletions core/web/apiv2/indicators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Iterable

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict

from core.schemas import indicator
from core.schemas import graph, indicator


# Request schemas
Expand Down Expand Up @@ -37,13 +35,23 @@ class IndicatorSearchResponse(BaseModel):
total: int


# API endpoints
router = APIRouter()
class IndicatorTagRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

ids: list[str]
tags: list[str]
strict: bool = False


class IndicatorTagResponse(BaseModel):
model_config = ConfigDict(extra="forbid")

@router.get("/")
async def indicators_root() -> Iterable[indicator.IndicatorTypes]:
return indicator.Indicator.list()
tagged: int
tags: dict[str, dict[str, graph.TagRelationship]]


# API endpoints
router = APIRouter()


@router.post("/")
Expand Down Expand Up @@ -87,6 +95,7 @@ async def details(indicator_id) -> indicator.IndicatorTypes:
db_indicator: indicator.IndicatorTypes = indicator.Indicator.get(indicator_id) # type: ignore
if not db_indicator:
raise HTTPException(status_code=404, detail="indicator not found")
db_indicator.get_tags()
return db_indicator


Expand All @@ -105,13 +114,37 @@ async def delete(indicator_id: str) -> None:
async def search(request: IndicatorSearchRequest) -> IndicatorSearchResponse:
"""Searches for indicators."""
query = request.query
tags = query.pop("tags", [])
if request.type:
query["type"] = request.type
indicators, total = indicator.Indicator.filter(
query_args=query,
tag_filter=tags,
offset=request.page * request.count,
count=request.count,
sorting=request.sorting,
aliases=request.filter_aliases,
graph_queries=[("tags", "tagged", "outbound", "name")],
)
return IndicatorSearchResponse(indicators=indicators, total=total)


@router.post("/tag")
async def tag(request: IndicatorTagRequest) -> IndicatorTagResponse:
"""Tags entities."""
indicators = []
for indicator_id in request.ids:
db_indicator = indicator.Indicator.get(indicator_id)
if not db_indicator:
raise HTTPException(
status_code=404,
detail=f"Tagging request contained an unknown indicator: ID:{indicator_id}",
)
indicators.append(db_indicator)

indicator_tags = {}
for db_indicator in indicators:
db_indicator.tag(request.tags, strict=request.strict)
indicator_tags[db_indicator.extended_id] = db_indicator.tags

return IndicatorTagResponse(tagged=len(indicators), tags=indicator_tags)
24 changes: 18 additions & 6 deletions tests/apiv2/indicators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def setUp(self) -> None:
location="filesystem",
diamond=indicator.DiamondModel.capability,
).save()
self.indicator1.tag(["hextag"])
self.indicator2 = indicator.Regex(
name="localhost",
pattern="127.0.0.1",
Expand All @@ -39,10 +40,6 @@ def setUp(self) -> None:
def tearDown(self) -> None:
database_arango.db.clear()

def test_get_indicators(self):
response = client.get("/api/v2/indicators/")
self.assertEqual(response.status_code, 200)

def test_new_indicator(self):
indicator_dict = {
"name": "otherRegex",
Expand All @@ -62,12 +59,12 @@ def test_new_indicator(self):

def test_get_indicator(self):
response = client.get(f"/api/v2/indicators/{self.indicator1.id}")
self.assertEqual(response.status_code, 200)
data = response.json()
self.assertEqual(response.status_code, 200, data)
self.assertEqual(data["name"], "hex")
self.assertEqual(data["type"], "regex")

def test_sarch_indicators(self):
def test_search_indicators(self):
response = client.post(
"/api/v2/indicators/search", json={"query": {"name": "he"}, "type": "regex"}
)
Expand All @@ -77,6 +74,21 @@ def test_sarch_indicators(self):
self.assertEqual(data["indicators"][0]["name"], "hex")
self.assertEqual(data["indicators"][0]["type"], "regex")

# check tag
self.assertEqual(len(data["indicators"][0]["tags"]), 1)
self.assertIn("hextag", data["indicators"][0]["tags"])

def test_search_indicators_tagged(self):
response = client.post(
"/api/v2/indicators/search",
json={"query": {"name": "", "tags": ["hextag"]}, "type": "regex"},
)
self.assertEqual(response.status_code, 200)
data = response.json()
self.assertEqual(len(data["indicators"]), 1)
self.assertEqual(data["indicators"][0]["name"], "hex")
self.assertEqual(data["indicators"][0]["type"], "regex")

def test_search_indicators_subfields(self):
response = client.post(
"/api/v2/indicators/search",
Expand Down
Loading