diff --git a/core/schemas/indicator.py b/core/schemas/indicator.py index 6742139fe..588ff9e5e 100644 --- a/core/schemas/indicator.py +++ b/core/schemas/indicator.py @@ -12,7 +12,7 @@ from core import database_arango from core.helpers import now -from core.schemas.model import YetiModel +from core.schemas.model import YetiTagModel def future(): @@ -44,7 +44,7 @@ class DiamondModel(Enum): victim = "victim" -class Indicator(YetiModel, database_arango.ArangoYetiConnector): +class Indicator(YetiTagModel, database_arango.ArangoYetiConnector): _collection_name: ClassVar[str] = "indicators" _type_filter: ClassVar[str] = "" _root_type: Literal["indicator"] = "indicator" diff --git a/core/web/apiv2/indicators.py b/core/web/apiv2/indicators.py index af7928018..22a852524 100644 --- a/core/web/apiv2/indicators.py +++ b/core/web/apiv2/indicators.py @@ -1,9 +1,7 @@ -from typing import Iterable - from fastapi import APIRouter, HTTPException from pydantic import BaseModel, ConfigDict -from core.schemas import indicator +from core.schemas import graph, indicator # Request schemas @@ -37,13 +35,23 @@ class IndicatorSearchResponse(BaseModel): total: int -# API endpoints -router = APIRouter() +class IndicatorTagRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + ids: list[str] + tags: list[str] + strict: bool = False + + +class IndicatorTagResponse(BaseModel): + model_config = ConfigDict(extra="forbid") -@router.get("/") -async def indicators_root() -> Iterable[indicator.IndicatorTypes]: - return indicator.Indicator.list() + tagged: int + tags: dict[str, dict[str, graph.TagRelationship]] + + +# API endpoints +router = APIRouter() @router.post("/") @@ -87,6 +95,7 @@ async def details(indicator_id) -> indicator.IndicatorTypes: db_indicator: indicator.IndicatorTypes = indicator.Indicator.get(indicator_id) # type: ignore if not db_indicator: raise HTTPException(status_code=404, detail="indicator not found") + db_indicator.get_tags() return db_indicator @@ -105,13 +114,37 @@ async def delete(indicator_id: str) -> None: async def search(request: IndicatorSearchRequest) -> IndicatorSearchResponse: """Searches for indicators.""" query = request.query + tags = query.pop("tags", []) if request.type: query["type"] = request.type indicators, total = indicator.Indicator.filter( query_args=query, + tag_filter=tags, offset=request.page * request.count, count=request.count, sorting=request.sorting, aliases=request.filter_aliases, + graph_queries=[("tags", "tagged", "outbound", "name")], ) return IndicatorSearchResponse(indicators=indicators, total=total) + + +@router.post("/tag") +async def tag(request: IndicatorTagRequest) -> IndicatorTagResponse: + """Tags entities.""" + indicators = [] + for indicator_id in request.ids: + db_indicator = indicator.Indicator.get(indicator_id) + if not db_indicator: + raise HTTPException( + status_code=404, + detail=f"Tagging request contained an unknown indicator: ID:{indicator_id}", + ) + indicators.append(db_indicator) + + indicator_tags = {} + for db_indicator in indicators: + db_indicator.tag(request.tags, strict=request.strict) + indicator_tags[db_indicator.extended_id] = db_indicator.tags + + return IndicatorTagResponse(tagged=len(indicators), tags=indicator_tags) diff --git a/tests/apiv2/indicators.py b/tests/apiv2/indicators.py index f53754a46..a021d4d1b 100644 --- a/tests/apiv2/indicators.py +++ b/tests/apiv2/indicators.py @@ -29,6 +29,7 @@ def setUp(self) -> None: location="filesystem", diamond=indicator.DiamondModel.capability, ).save() + self.indicator1.tag(["hextag"]) self.indicator2 = indicator.Regex( name="localhost", pattern="127.0.0.1", @@ -39,10 +40,6 @@ def setUp(self) -> None: def tearDown(self) -> None: database_arango.db.clear() - def test_get_indicators(self): - response = client.get("/api/v2/indicators/") - self.assertEqual(response.status_code, 200) - def test_new_indicator(self): indicator_dict = { "name": "otherRegex", @@ -62,12 +59,12 @@ def test_new_indicator(self): def test_get_indicator(self): response = client.get(f"/api/v2/indicators/{self.indicator1.id}") - self.assertEqual(response.status_code, 200) data = response.json() + self.assertEqual(response.status_code, 200, data) self.assertEqual(data["name"], "hex") self.assertEqual(data["type"], "regex") - def test_sarch_indicators(self): + def test_search_indicators(self): response = client.post( "/api/v2/indicators/search", json={"query": {"name": "he"}, "type": "regex"} ) @@ -77,6 +74,21 @@ def test_sarch_indicators(self): self.assertEqual(data["indicators"][0]["name"], "hex") self.assertEqual(data["indicators"][0]["type"], "regex") + # check tag + self.assertEqual(len(data["indicators"][0]["tags"]), 1) + self.assertIn("hextag", data["indicators"][0]["tags"]) + + def test_search_indicators_tagged(self): + response = client.post( + "/api/v2/indicators/search", + json={"query": {"name": "", "tags": ["hextag"]}, "type": "regex"}, + ) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(len(data["indicators"]), 1) + self.assertEqual(data["indicators"][0]["name"], "hex") + self.assertEqual(data["indicators"][0]["type"], "regex") + def test_search_indicators_subfields(self): response = client.post( "/api/v2/indicators/search",