From 50f942191f44f29456e18fe91a7536f02cb56efd Mon Sep 17 00:00:00 2001 From: Stephen Lincoln Date: Thu, 7 Nov 2024 15:10:44 -0500 Subject: [PATCH] Formatting --- sigmaiq/backends/crowdstrike/crowdstrike.py | 2 + .../backends/elasticsearch/elasticsearch.py | 1 + sigmaiq/backends/kusto/kusto.py | 3 + sigmaiq/llm/base.py | 2 +- sigmaiq/llm/toolkits/base.py | 42 ++++++++--- sigmaiq/llm/toolkits/prompts.py | 2 +- sigmaiq/llm/tools/find_sigma_rule.py | 1 + sigmaiq/sigmaiq_backend_factory.py | 5 +- sigmaiq/utils/sigmaiq/sigmaiq_utils.py | 75 ++++++++++--------- tests/test_backend_factory.py | 20 +---- tests/test_llm_components.py | 43 ++++++----- tests/test_sigmaiq_utils.py | 49 ++++++------ 12 files changed, 133 insertions(+), 112 deletions(-) diff --git a/sigmaiq/backends/crowdstrike/crowdstrike.py b/sigmaiq/backends/crowdstrike/crowdstrike.py index 6aa1021..3cacdff 100644 --- a/sigmaiq/backends/crowdstrike/crowdstrike.py +++ b/sigmaiq/backends/crowdstrike/crowdstrike.py @@ -6,6 +6,7 @@ class SigmAIQCrowdstrikeSplunkBackend(AbstractGenericSigmAIQBackendClass, SplunkBackend): """SigmAIQ backend interface for the pySigma Splunk Backend library to translate a SigmaRule object to a Splunk search query with the Crowdstrike FDR format""" + custom_formats = {} associated_pipelines = ["crowdstrike_fdr"] default_pipeline = "crowdstrike_fdr" @@ -15,6 +16,7 @@ class SigmAIQCrowdstrikeSplunkBackend(AbstractGenericSigmAIQBackendClass, Splunk class SigmAIQCrowdstrikeLogscaleBackend(AbstractGenericSigmAIQBackendClass, LogScaleBackend): """SigmAIQ backend interface for the pySigma Logscale Backend library to translate a SigmaRule object to a Logscale search query with the Crowdstrike Falcon format""" + custom_formats = {} associated_pipelines = ["crowdstrike_falcon"] default_pipeline = "crowdstrike_falcon" diff --git a/sigmaiq/backends/elasticsearch/elasticsearch.py b/sigmaiq/backends/elasticsearch/elasticsearch.py index 879846e..5fe2d9b 100644 --- a/sigmaiq/backends/elasticsearch/elasticsearch.py +++ b/sigmaiq/backends/elasticsearch/elasticsearch.py @@ -5,6 +5,7 @@ class SigmAIQElasticsearchBackend(AbstractGenericSigmAIQBackendClass, LuceneBackend): """SigmAIQ backend interface for the pySigma Elasticsearch Backend library to translate a SigmaRule object to an Elasticsearch search query""" + custom_formats = {} associated_pipelines = [ "ecs_windows", diff --git a/sigmaiq/backends/kusto/kusto.py b/sigmaiq/backends/kusto/kusto.py index 2434970..7faf559 100644 --- a/sigmaiq/backends/kusto/kusto.py +++ b/sigmaiq/backends/kusto/kusto.py @@ -5,6 +5,7 @@ class SigmAIQDefenderXDRBackend(AbstractGenericSigmAIQBackendClass, KustoBackend): """SigmAIQ backend interface for the pySigma Kusto Backend library to translate a SigmaRule object to a Kusto search query with the Microsoft Defender XDR format""" + custom_formats = {} associated_pipelines = ["microsoft_xdr"] default_pipeline = "microsoft_xdr" @@ -13,6 +14,7 @@ class SigmAIQDefenderXDRBackend(AbstractGenericSigmAIQBackendClass, KustoBackend class SigmAIQSentinelASIMBackend(AbstractGenericSigmAIQBackendClass, KustoBackend): """SigmAIQ backend interface for the pySigma Kusto Backend library to translate a SigmaRule object to a Kusto search query with the Microsoft Sentinel ASIM format""" + custom_formats = {} associated_pipelines = ["sentinel_asim"] default_pipeline = "sentinel_asim" @@ -21,6 +23,7 @@ class SigmAIQSentinelASIMBackend(AbstractGenericSigmAIQBackendClass, KustoBacken class SigmAIQAzureMonitorBackend(AbstractGenericSigmAIQBackendClass, KustoBackend): """SigmAIQ backend interface for the pySigma Kusto Backend library to translate a SigmaRule object to a Kusto search query with the Microsoft Azure Monitor format""" + custom_formats = {} associated_pipelines = ["azure_monitor"] default_pipeline = "azure_monitor" diff --git a/sigmaiq/llm/base.py b/sigmaiq/llm/base.py index c94de47..2d8bee8 100644 --- a/sigmaiq/llm/base.py +++ b/sigmaiq/llm/base.py @@ -36,7 +36,7 @@ def __init__( rule_dir: str = None, vector_store_dir: str = None, embedding_model: OpenAIEmbeddings = None, - embedding_function: Type[Embeddings] = OpenAIEmbeddings, #TODO RS : Consolidate this with embedding_model + embedding_function: Type[Embeddings] = OpenAIEmbeddings, # TODO RS : Consolidate this with embedding_model vector_store: Type[VectorStore] = FAISS, rule_loader: Type[BaseLoader] = DirectoryLoader, rule_splitter: Type[BaseDocumentTransformer] = CharacterTextSplitter, diff --git a/sigmaiq/llm/toolkits/base.py b/sigmaiq/llm/toolkits/base.py index 6b04b5c..ad10582 100644 --- a/sigmaiq/llm/toolkits/base.py +++ b/sigmaiq/llm/toolkits/base.py @@ -8,16 +8,21 @@ from langchain.agents.format_scratchpad import format_to_openai_function_messages from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.prompts import ChatPromptTemplate + # langchain typing -from langchain.schema import (AgentAction, AgentFinish, OutputParserException, AIMessage, BaseMessage) +from langchain.schema import AgentAction, AgentFinish, OutputParserException, AIMessage, BaseMessage from langchain.schema.agent import AgentActionMessageLog from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import (AIMessage, BaseMessage, ) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, +) from langchain.schema.vectorstore import VectorStore from langchain_core.utils.function_calling import convert_to_openai_function from langchain_openai import ChatOpenAI from sigmaiq.llm.toolkits.prompts import SIGMA_AGENT_PROMPT + # sigmaiq from sigmaiq.llm.toolkits.sigma_toolkit import SigmaToolkit @@ -31,7 +36,7 @@ def create_sigma_agent( ) -> AgentExecutor: if sigma_vectorstore is None: raise ValueError("sigma_vectorstore must be provided") - + if rule_creation_llm is None: rule_creation_llm = ChatOpenAI(model="gpt-4o") @@ -40,20 +45,32 @@ def create_sigma_agent( # Assert if any of the tools does not have arun for tool in tools: - assert hasattr(tool, 'arun'), f"Tool {tool.name} does not have an 'arun' method" + assert hasattr(tool, "arun"), f"Tool {tool.name} does not have an 'arun' method" # Create OpenAI Function for each tool for the agent LLM, so we can create an OpenAI Function AgentExecutor llm_with_tools = rule_creation_llm.bind(functions=[convert_to_openai_function(t) for t in tools]) # Create the agent prompt = SIGMA_AGENT_PROMPT - agent = ({"input": lambda x: x["input"], "agent_scratchpad": lambda x: format_to_openai_function_messages( - x["intermediate_steps"]), } | prompt | llm_with_tools | CustomOpenAIFunctionsAgentOutputParser()) + agent = ( + { + "input": lambda x: x["input"], + "agent_scratchpad": lambda x: format_to_openai_function_messages(x["intermediate_steps"]), + } + | prompt + | llm_with_tools + | CustomOpenAIFunctionsAgentOutputParser() + ) # Create and return the AgentExecutor - agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, - return_intermediate_steps=return_intermediate_steps, handle_parsing_errors=True, - **(agent_executor_kwargs or {})) + agent_executor = AgentExecutor( + agent=agent, + tools=tools, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + handle_parsing_errors=True, + **(agent_executor_kwargs or {}), + ) return agent_executor @@ -70,7 +87,7 @@ def parse(self, message: Union[str, BaseMessage]) -> Union[AgentAction, AgentFin raise ValueError("Expected an AIMessage object, got a string") if not isinstance(message, AIMessage): raise TypeError(f"Expected an AI message got {type(message)}") - + return self._parse_ai_message(message) @staticmethod @@ -83,8 +100,9 @@ def _parse_ai_message(message: AIMessage) -> Union[AgentAction, AgentFinish]: try: _tool_input = json.loads(function_call["arguments"].strip(), strict=False) # HACK except JSONDecodeError: - raise OutputParserException(f"Could not parse tool input: {function_call} because " - f"the `arguments` is not valid JSON.") + raise OutputParserException( + f"Could not parse tool input: {function_call} because " f"the `arguments` is not valid JSON." + ) # HACK HACK HACK: # The code that encodes tool input into Open AI uses a special variable diff --git a/sigmaiq/llm/toolkits/prompts.py b/sigmaiq/llm/toolkits/prompts.py index 9d28ff9..20af7e4 100644 --- a/sigmaiq/llm/toolkits/prompts.py +++ b/sigmaiq/llm/toolkits/prompts.py @@ -12,7 +12,7 @@ "3. create_sigma_rule_vectorstore: Creates new Sigma Rule from the users input, as well as rules in a sigma rule vectorstore to use as context based on the users question. If the user's question already contains a query, use 'query_to_sigma_rule' instead." "4. query_to_sigma_rule: Converts/translates a product/SIEM/backend query or search from the query language into a YAML Sigma Rule." "Do not use 'translate_sigma_rule' unless the user explicitly asks for a Sigma Rule to be converted or translated " - "into a query for a specific backend, pipeline, and/or output format." + "into a query for a specific backend, pipeline, and/or output format.", ), ("user", "{input}"), MessagesPlaceholder(variable_name="agent_scratchpad"), diff --git a/sigmaiq/llm/tools/find_sigma_rule.py b/sigmaiq/llm/tools/find_sigma_rule.py index 6066e1d..254eaba 100644 --- a/sigmaiq/llm/tools/find_sigma_rule.py +++ b/sigmaiq/llm/tools/find_sigma_rule.py @@ -36,6 +36,7 @@ class FindSigmaRuleTool(BaseTool): class Config: """Configuration for this pydantic object.""" + extra = Extra.forbid def _run(self, query: Union[str, dict]) -> str: diff --git a/sigmaiq/sigmaiq_backend_factory.py b/sigmaiq/sigmaiq_backend_factory.py index 8f923bc..5bdd701 100644 --- a/sigmaiq/sigmaiq_backend_factory.py +++ b/sigmaiq/sigmaiq_backend_factory.py @@ -63,7 +63,10 @@ class SigmAIQBackend: """ def __init__( - self, backend: str, processing_pipeline: Optional[Union[str, list, ProcessingPipeline]] = None, output_format: Optional[str] = None + self, + backend: str, + processing_pipeline: Optional[Union[str, list, ProcessingPipeline]] = None, + output_format: Optional[str] = None, ): """Initialize instance attributes. diff --git a/sigmaiq/utils/sigmaiq/sigmaiq_utils.py b/sigmaiq/utils/sigmaiq/sigmaiq_utils.py index d21c93c..c3f4d76 100644 --- a/sigmaiq/utils/sigmaiq/sigmaiq_utils.py +++ b/sigmaiq/utils/sigmaiq/sigmaiq_utils.py @@ -9,70 +9,72 @@ def _is_v1_schema(rule_data: dict) -> bool: """Check if the rule uses v1 schema patterns.""" if not isinstance(rule_data, dict): return False - + # Check date format - date_str = rule_data.get('date') - if date_str and '/' in date_str: + date_str = rule_data.get("date") + if date_str and "/" in date_str: return True - + # Check modified format - modified_str = rule_data.get('modified') - if modified_str and '/' in modified_str: + modified_str = rule_data.get("modified") + if modified_str and "/" in modified_str: return True - + # Check tags format - tags = rule_data.get('tags', []) + tags = rule_data.get("tags", []) for tag in tags: - if any(ns in tag for ns in ['attack-', 'attack_', 'cve-', 'detection-']): + if any(ns in tag for ns in ["attack-", "attack_", "cve-", "detection-"]): return True - + # Check related field - related = rule_data.get('related', []) + related = rule_data.get("related", []) for rel in related: - if rel.get('type') == 'obsoletes': + if rel.get("type") == "obsoletes": return True - + return False + def _convert_to_v2_schema(rule_data: dict) -> dict: """Convert v1 schema rule to v2 schema.""" rule_data = rule_data.copy() - + # Convert date and modified format - if 'date' in rule_data and '/' in rule_data['date']: + if "date" in rule_data and "/" in rule_data["date"]: try: - date_obj = datetime.strptime(rule_data['date'], '%Y/%m/%d') - rule_data['date'] = date_obj.strftime('%Y-%m-%d') + date_obj = datetime.strptime(rule_data["date"], "%Y/%m/%d") + rule_data["date"] = date_obj.strftime("%Y-%m-%d") except ValueError: pass - - if 'modified' in rule_data and '/' in rule_data['modified']: + + if "modified" in rule_data and "/" in rule_data["modified"]: try: - date_obj = datetime.strptime(rule_data['modified'], '%Y/%m/%d') - rule_data['modified'] = date_obj.strftime('%Y-%m-%d') + date_obj = datetime.strptime(rule_data["modified"], "%Y/%m/%d") + rule_data["modified"] = date_obj.strftime("%Y-%m-%d") except ValueError: pass - + # Convert tags - if 'tags' in rule_data: + if "tags" in rule_data: new_tags = [] - for tag in rule_data['tags']: + for tag in rule_data["tags"]: # Convert common namespace patterns - tag = tag.replace('attack-', 'attack.') - tag = tag.replace('attack_', 'attack.') - tag = tag.replace('cve-', 'cve.') - tag = tag.replace('detection-', 'detection.') + tag = tag.replace("attack-", "attack.") + tag = tag.replace("attack_", "attack.") + tag = tag.replace("cve-", "cve.") + tag = tag.replace("detection-", "detection.") new_tags.append(tag) - rule_data['tags'] = new_tags - + rule_data["tags"] = new_tags + # Convert related field - if 'related' in rule_data: - for rel in rule_data['related']: - if rel.get('type') == 'obsoletes': - rel['type'] = 'obsolete' - + if "related" in rule_data: + for rel in rule_data["related"]: + if rel.get("type") == "obsoletes": + rel["type"] = "obsolete" + return rule_data + def create_sigma_rule_obj(sigma_rule: Union[SigmaRule, SigmaCollection, dict, str, list]): """Checks sigma_rule to ensure it's a SigmaRule or SigmaCollection object. It can also be a valid Sigma rule representation in a dict or yaml str (or list of valid dicts/yaml strs) that can be used with SigmaRule class methods to @@ -102,13 +104,14 @@ def create_sigma_rule_obj(sigma_rule: Union[SigmaRule, SigmaCollection, dict, st if isinstance(sigma_rule, dict): # Check and convert v1 schema if needed if _is_v1_schema(sigma_rule): - + sigma_rule = _convert_to_v2_schema(sigma_rule) return SigmaRule.from_dict(sigma_rule) if isinstance(sigma_rule, str): # For YAML strings, we need to parse to dict first try: import yaml + rule_dict = yaml.safe_load(sigma_rule) if _is_v1_schema(rule_dict): rule_dict = _convert_to_v2_schema(rule_dict) diff --git a/tests/test_backend_factory.py b/tests/test_backend_factory.py index 2075ed1..5983f10 100644 --- a/tests/test_backend_factory.py +++ b/tests/test_backend_factory.py @@ -50,23 +50,11 @@ def sigma_rule_dict(): "author": "AttackIQ", "date": "2023-01-01", "modified": "2023-01-02", - "tags": [ - "attack.t1003", - "attack.t1003.001", - "attack.credential_access" - ], - "logsource": { - "category": "process_creation", - "product": "windows" - }, - "detection": { - "sel": { - "CommandLine": "valueA" - }, - "condition": "sel" - }, + "tags": ["attack.t1003", "attack.t1003.001", "attack.credential_access"], + "logsource": {"category": "process_creation", "product": "windows"}, + "detection": {"sel": {"CommandLine": "valueA"}, "condition": "sel"}, "falsepositives": ["None"], - "level": "high" + "level": "high", } diff --git a/tests/test_llm_components.py b/tests/test_llm_components.py index 1bb465d..011790a 100644 --- a/tests/test_llm_components.py +++ b/tests/test_llm_components.py @@ -13,6 +13,7 @@ from sigmaiq.llm.tools.find_sigma_rule import FindSigmaRuleTool from sigmaiq.llm.tools.query_to_sigma_rule import QueryToSigmaRuleTool + class MockLLM(BaseLanguageModel): def invoke(self, *args, **kwargs): return "Mocked LLM response" @@ -48,21 +49,22 @@ def predict_messages(self, *args, **kwargs): async def apredict_messages(self, *args, **kwargs): return "Mocked async predict_messages response" + # Mock OpenAI API calls @pytest.fixture def mock_openai_create(): with patch("openai.ChatCompletion.create") as mock_create: - mock_create.return_value = { - "choices": [{"message": {"content": "Mocked OpenAI response"}}] - } + mock_create.return_value = {"choices": [{"message": {"content": "Mocked OpenAI response"}}]} yield mock_create + @pytest.fixture def mock_openai_embeddings(): with patch.object(OpenAIEmbeddings, "embed_documents") as mock_embed: mock_embed.return_value = [[0.1, 0.2, 0.3]] # Mocked embedding yield mock_embed + @pytest.fixture def mock_vector_store(): class MockVectorStore(VectorStore): @@ -81,15 +83,18 @@ def from_texts(cls, texts, embedding, metadatas=None, **kwargs): return MockVectorStore() + def test_sigma_llm_initialization(mock_openai_embeddings): sigma_llm = SigmaLLM(embedding_model=OpenAIEmbeddings()) assert sigma_llm.embedding_function is not None + def test_create_sigma_agent(mock_vector_store): mock_llm = MockLLM() agent_executor = create_sigma_agent(sigma_vectorstore=mock_vector_store, rule_creation_llm=mock_llm) assert agent_executor is not None - assert hasattr(agent_executor, 'run') + assert hasattr(agent_executor, "run") + def test_sigma_toolkit(): mock_vector_store = create_autospec(VectorStore) @@ -102,6 +107,7 @@ def test_sigma_toolkit(): assert any(isinstance(tool, FindSigmaRuleTool) for tool in tools) assert any(isinstance(tool, QueryToSigmaRuleTool) for tool in tools) + @pytest.mark.asyncio async def test_create_sigma_rule_tool(mock_openai_create, mock_vector_store): tool = CreateSigmaRuleVectorStoreTool(sigmadb=mock_vector_store, llm=MockLLM()) @@ -109,31 +115,29 @@ async def test_create_sigma_rule_tool(mock_openai_create, mock_vector_store): assert isinstance(result, str) assert "title:" in result.lower() + @pytest.mark.asyncio async def test_translate_sigma_rule_tool(mock_openai_create): tool = TranslateSigmaRuleTool() - result = await tool._arun( - sigma_rule="title: Test Rule\ndetection:\n condition: selection", - backend="splunk" - ) + result = await tool._arun(sigma_rule="title: Test Rule\ndetection:\n condition: selection", backend="splunk") assert isinstance(result, str) + @pytest.mark.asyncio async def test_find_sigma_rule_tool(mock_openai_create, mock_vector_store): tool = FindSigmaRuleTool(sigmadb=mock_vector_store, llm=MockLLM()) result = await tool._arun("Find a rule for detecting mimikatz") assert isinstance(result, str) + @pytest.mark.asyncio async def test_query_to_sigma_rule_tool(mock_openai_create): tool = QueryToSigmaRuleTool(llm=MockLLM()) - result = await tool._arun( - query="process_name=powershell.exe", - backend="splunk" - ) + result = await tool._arun(query="process_name=powershell.exe", backend="splunk") assert isinstance(result, str) assert "title:" in result.lower() + @pytest.mark.asyncio async def test_agent_execution(mock_openai_create, mock_vector_store): agent_executor = create_sigma_agent(sigma_vectorstore=mock_vector_store) @@ -141,18 +145,16 @@ async def test_agent_execution(mock_openai_create, mock_vector_store): assert isinstance(result, dict) assert "output" in result + def test_custom_openai_functions_agent_output_parser(): from sigmaiq.llm.toolkits.base import CustomOpenAIFunctionsAgentOutputParser parser = CustomOpenAIFunctionsAgentOutputParser() - + # Test parsing an AgentAction - message = AIMessage(content="", additional_kwargs={ - "function_call": { - "name": "test_function", - "arguments": '{"arg1": "value1"}' - } - }) + message = AIMessage( + content="", additional_kwargs={"function_call": {"name": "test_function", "arguments": '{"arg1": "value1"}'}} + ) result = parser.parse(message) assert isinstance(result, AgentAction) assert result.tool == "test_function" @@ -168,4 +170,5 @@ def test_custom_openai_functions_agent_output_parser(): with pytest.raises(ValueError): parser.parse("This is a string, not an AIMessage") -# Add more tests as needed for other components and edge cases \ No newline at end of file + +# Add more tests as needed for other components and edge cases diff --git a/tests/test_sigmaiq_utils.py b/tests/test_sigmaiq_utils.py index c188174..fd94ea8 100644 --- a/tests/test_sigmaiq_utils.py +++ b/tests/test_sigmaiq_utils.py @@ -1,15 +1,14 @@ import datetime import pytest import yaml -from sigmaiq.utils.sigmaiq.sigmaiq_utils import (create_sigma_rule_obj, - _is_v1_schema, - _convert_to_v2_schema) +from sigmaiq.utils.sigmaiq.sigmaiq_utils import create_sigma_rule_obj, _is_v1_schema, _convert_to_v2_schema # Existing fixtures from tests.test_backend_factory import sigma_rule, sigma_rule_yaml_str, sigma_rule_dict, sigma_collection from sigma.rule import SigmaRule from sigma.collection import SigmaCollection + # New fixtures for schema conversion tests @pytest.fixture def v1_rule_data(): @@ -17,45 +16,33 @@ def v1_rule_data(): "id": "12345678-abcd-abcd-1234-1234567890ab", "title": "Test Rule", "date": "2023/04/15", - "tags": [ - "attack.execution", - "attack_persistence", - "cve.2023.1234", - "detection.threat_hunting" - ], - "related": [ - {"type": "obsoletes", "id": "12345678-abcd-abcd-1234-1234567890ab"} - ], + "tags": ["attack.execution", "attack_persistence", "cve.2023.1234", "detection.threat_hunting"], + "related": [{"type": "obsoletes", "id": "12345678-abcd-abcd-1234-1234567890ab"}], "modified": "2023/04/15", "logsource": {"category": "process_creation", "product": "windows"}, "detection": { "selection_img": {"Image|endswith": "\\regedit.exe", "OriginalFileName": "REGEDIT.EXE"}, - "condition": "all of selection_* and not all of filter_*" - } + "condition": "all of selection_* and not all of filter_*", + }, } + @pytest.fixture def v2_rule_data(): return { "id": "12345678-abcd-abcd-1234-1234567890ab", "title": "Test Rule", "date": "2023-04-15", - "tags": [ - "attack.execution", - "attack.persistence", - "cve.2023.1234", - "detection.threat_hunting" - ], - "related": [ - {"type": "obsolete", "id": "12345678-abcd-abcd-1234-1234567890ab"} - ], + "tags": ["attack.execution", "attack.persistence", "cve.2023.1234", "detection.threat_hunting"], + "related": [{"type": "obsolete", "id": "12345678-abcd-abcd-1234-1234567890ab"}], "logsource": {"category": "process_creation", "product": "windows"}, "detection": { "selection_img": {"Image|endswith": "\\regedit.exe", "OriginalFileName": "REGEDIT.EXE"}, - "condition": "all of selection_* and not all of filter_*" - } + "condition": "all of selection_* and not all of filter_*", + }, } + # Existing tests def test_create_sigma_rule_obj_sigma_rule(sigma_rule): """Tests creating a SigmaRule object from a SigmaRule, aka just return the rule""" @@ -63,35 +50,43 @@ def test_create_sigma_rule_obj_sigma_rule(sigma_rule): print(type(sigma_rule)) assert isinstance(create_sigma_rule_obj(sigma_rule), SigmaRule) + def test_create_sigma_rule_obj_sigma_collection(sigma_collection): """Tests creating a SigmaRule object from a SigmaCollection, aka just return the collection""" assert isinstance(create_sigma_rule_obj(sigma_collection), SigmaCollection) + def test_create_sigma_rule_obj_sigma_rule_yaml_str(sigma_rule_yaml_str): """Tests creating a SigmaRule object from a valid SigmaRule YAML str""" assert isinstance(create_sigma_rule_obj(sigma_rule_yaml_str), SigmaRule) + def test_create_sigma_rule_obj_sigma_rule_dict(sigma_rule_dict): """Tests creating a SigmaRule object from a valid SigmaRule dict""" assert isinstance(create_sigma_rule_obj(sigma_rule_dict), SigmaRule) + def test_create_sigma_rule_obj_invalid_type(): """Tests creating a SigmaRule object from an invalid type""" with pytest.raises(TypeError): create_sigma_rule_obj(1) # Invalid type + def test_create_sigma_rule_obj_invalid_type_list(): """Tests creating a SigmaRule object from an invalid type list""" with pytest.raises(TypeError): create_sigma_rule_obj([1]) # Invalid type list + def test_create_sigma_rule_objsigma_rule_list(sigma_rule, sigma_rule_yaml_str): """Tests creating a SigmaRule objects from a list""" assert isinstance(create_sigma_rule_obj([sigma_rule, sigma_rule_yaml_str]), SigmaCollection) + # New schema conversion tests class TestSchemaDetection: """Tests for v1 schema detection""" + def test_v1_date_detection(self): assert _is_v1_schema({"date": "2023/04/15"}) assert not _is_v1_schema({"date": "2023-04-15"}) @@ -110,8 +105,10 @@ def test_non_dict_input(self): assert not _is_v1_schema([]) assert not _is_v1_schema("string") + class TestSchemaConversion: """Tests for v1 to v2 schema conversion""" + def test_date_conversion(self, v1_rule_data, v2_rule_data): converted = _convert_to_v2_schema(v1_rule_data) assert converted["date"] == v2_rule_data["date"] @@ -134,8 +131,10 @@ def test_missing_fields_handling(self): converted = _convert_to_v2_schema(rule_data) assert converted == rule_data # Should return unchanged + class TestSchemaConversionIntegration: """Integration tests for schema conversion with create_sigma_rule_obj""" + def test_dict_conversion(self, v1_rule_data): rule = create_sigma_rule_obj(v1_rule_data) assert isinstance(rule, SigmaRule)