Skip to content

Commit

Permalink
Remove use of .value on enums in code
Browse files Browse the repository at this point in the history
now that they have been more
strictly defined as IntEnum or
StrEnum.
This has not yet been tested.
  • Loading branch information
pyth0n1c committed Dec 4, 2024
1 parent 0999270 commit 31f46a2
Show file tree
Hide file tree
Showing 16 changed files with 155 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def test_detection(self, detection: Detection) -> None:
self.format_pbar_string(
TestReportingType.GROUP,
test_group.name,
FinalTestingStates.SKIP.value,
FinalTestingStates.SKIP,
start_time=time.time(),
set_pbar=False,
)
Expand Down Expand Up @@ -483,7 +483,7 @@ def test_detection(self, detection: Detection) -> None:
self.format_pbar_string(
TestReportingType.GROUP,
test_group.name,
TestingStates.DONE_GROUP.value,
TestingStates.DONE_GROUP,
start_time=setup_results.start_time,
set_pbar=False,
)
Expand All @@ -504,7 +504,7 @@ def setup_test_group(self, test_group: TestGroup) -> SetupTestGroupResults:
self.format_pbar_string(
TestReportingType.GROUP,
test_group.name,
TestingStates.BEGINNING_GROUP.value,
TestingStates.BEGINNING_GROUP,
start_time=setup_start_time
)
# https://github.com/WoLpH/python-progressbar/issues/164
Expand Down Expand Up @@ -544,7 +544,7 @@ def cleanup_test_group(
self.format_pbar_string(
TestReportingType.GROUP,
test_group.name,
TestingStates.DELETING.value,
TestingStates.DELETING,
start_time=test_group_start_time,
)

Expand Down Expand Up @@ -632,7 +632,7 @@ def execute_unit_test(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
FinalTestingStates.SKIP.value,
FinalTestingStates.SKIP,
start_time=test_start_time,
set_pbar=False,
)
Expand Down Expand Up @@ -664,7 +664,7 @@ def execute_unit_test(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
FinalTestingStates.ERROR.value,
FinalTestingStates.ERROR,
start_time=test_start_time,
set_pbar=False,
)
Expand Down Expand Up @@ -724,7 +724,7 @@ def execute_unit_test(
res = "ERROR"
link = detection.search
else:
res = test.result.status.value.upper() # type: ignore
res = test.result.status.upper() # type: ignore
link = test.result.get_summary_dict()["sid_link"]

self.format_pbar_string(
Expand Down Expand Up @@ -755,7 +755,7 @@ def execute_unit_test(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
FinalTestingStates.PASS.value,
FinalTestingStates.PASS,
start_time=test_start_time,
set_pbar=False,
)
Expand All @@ -766,7 +766,7 @@ def execute_unit_test(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
FinalTestingStates.SKIP.value,
FinalTestingStates.SKIP,
start_time=test_start_time,
set_pbar=False,
)
Expand All @@ -777,7 +777,7 @@ def execute_unit_test(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
FinalTestingStates.FAIL.value,
FinalTestingStates.FAIL,
start_time=test_start_time,
set_pbar=False,
)
Expand All @@ -788,7 +788,7 @@ def execute_unit_test(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
FinalTestingStates.ERROR.value,
FinalTestingStates.ERROR,
start_time=test_start_time,
set_pbar=False,
)
Expand Down Expand Up @@ -821,7 +821,7 @@ def execute_integration_test(
test_start_time = time.time()

# First, check to see if the test should be skipped (Hunting or Correlation)
if detection.type in [AnalyticsType.Hunting.value, AnalyticsType.Correlation.value]:
if detection.type in [AnalyticsType.Hunting, AnalyticsType.Correlation]:
test.skip(
f"TEST SKIPPED: detection is type {detection.type} and cannot be integration "
"tested at this time"
Expand All @@ -843,11 +843,11 @@ def execute_integration_test(
# Determine the reporting state (we should only encounter SKIP/FAIL/ERROR)
state: str
if test.result.status == TestResultStatus.SKIP:
state = FinalTestingStates.SKIP.value
state = FinalTestingStates.SKIP
elif test.result.status == TestResultStatus.FAIL:
state = FinalTestingStates.FAIL.value
state = FinalTestingStates.FAIL
elif test.result.status == TestResultStatus.ERROR:
state = FinalTestingStates.ERROR.value
state = FinalTestingStates.ERROR
else:
raise ValueError(
f"Status for (integration) '{detection.name}:{test.name}' was preemptively set"
Expand Down Expand Up @@ -891,7 +891,7 @@ def execute_integration_test(
self.format_pbar_string(
TestReportingType.INTEGRATION,
f"{detection.name}:{test.name}",
FinalTestingStates.FAIL.value,
FinalTestingStates.FAIL,
start_time=test_start_time,
set_pbar=False,
)
Expand Down Expand Up @@ -935,7 +935,7 @@ def execute_integration_test(
if test.result is None:
res = "ERROR"
else:
res = test.result.status.value.upper() # type: ignore
res = test.result.status.upper() # type: ignore

# Get the link to the saved search in this specific instance
link = f"https://{self.infrastructure.instance_address}:{self.infrastructure.web_ui_port}"
Expand Down Expand Up @@ -968,7 +968,7 @@ def execute_integration_test(
self.format_pbar_string(
TestReportingType.INTEGRATION,
f"{detection.name}:{test.name}",
FinalTestingStates.PASS.value,
FinalTestingStates.PASS,
start_time=test_start_time,
set_pbar=False,
)
Expand All @@ -979,7 +979,7 @@ def execute_integration_test(
self.format_pbar_string(
TestReportingType.INTEGRATION,
f"{detection.name}:{test.name}",
FinalTestingStates.SKIP.value,
FinalTestingStates.SKIP,
start_time=test_start_time,
set_pbar=False,
)
Expand All @@ -990,7 +990,7 @@ def execute_integration_test(
self.format_pbar_string(
TestReportingType.INTEGRATION,
f"{detection.name}:{test.name}",
FinalTestingStates.FAIL.value,
FinalTestingStates.FAIL,
start_time=test_start_time,
set_pbar=False,
)
Expand All @@ -1001,7 +1001,7 @@ def execute_integration_test(
self.format_pbar_string(
TestReportingType.INTEGRATION,
f"{detection.name}:{test.name}",
FinalTestingStates.ERROR.value,
FinalTestingStates.ERROR,
start_time=test_start_time,
set_pbar=False,
)
Expand Down Expand Up @@ -1077,7 +1077,7 @@ def retry_search_until_timeout(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
TestingStates.PROCESSING.value,
TestingStates.PROCESSING,
start_time=start_time
)

Expand All @@ -1086,7 +1086,7 @@ def retry_search_until_timeout(
self.format_pbar_string(
TestReportingType.UNIT,
f"{detection.name}:{test.name}",
TestingStates.SEARCHING.value,
TestingStates.SEARCHING,
start_time=start_time,
)

Expand Down Expand Up @@ -1289,7 +1289,7 @@ def replay_attack_data_file(
self.format_pbar_string(
TestReportingType.GROUP,
test_group.name,
TestingStates.DOWNLOADING.value,
TestingStates.DOWNLOADING,
start_time=test_group_start_time
)

Expand All @@ -1307,7 +1307,7 @@ def replay_attack_data_file(
self.format_pbar_string(
TestReportingType.GROUP,
test_group.name,
TestingStates.REPLAYING.value,
TestingStates.REPLAYING,
start_time=test_group_start_time
)

Expand Down
12 changes: 6 additions & 6 deletions contentctl/actions/detection_testing/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import time
from enum import Enum
from enum import StrEnum
from tqdm import tqdm
import datetime


class TestReportingType(str, Enum):
class TestReportingType(StrEnum):
"""
5-char identifiers for the type of testing being reported on
"""
Expand All @@ -21,7 +21,7 @@ class TestReportingType(str, Enum):
INTEGRATION = "INTEG"


class TestingStates(str, Enum):
class TestingStates(StrEnum):
"""
Defined testing states
"""
Expand All @@ -40,10 +40,10 @@ class TestingStates(str, Enum):


# the longest length of any state
LONGEST_STATE = max(len(w.value) for w in TestingStates)
LONGEST_STATE = max(len(w) for w in TestingStates)


class FinalTestingStates(str, Enum):
class FinalTestingStates(StrEnum):
"""
The possible final states for a test (for pbar reporting)
"""
Expand Down Expand Up @@ -82,7 +82,7 @@ def format_pbar_string(
:returns: a formatted string for use w/ pbar
"""
# Extract and ljust our various fields
field_one = test_reporting_type.value
field_one = test_reporting_type
field_two = test_name.ljust(MAX_TEST_NAME_LENGTH)
field_three = state.ljust(LONGEST_STATE)
field_four = datetime.timedelta(seconds=round(time.time() - start_time))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def getSummaryObject(
total_skipped += 1

# Aggregate production status metrics
if detection.status == DetectionStatus.production.value: # type: ignore
if detection.status == DetectionStatus.production:
total_production += 1
elif detection.status == DetectionStatus.experimental.value: # type: ignore
elif detection.status == DetectionStatus.experimental:
total_experimental += 1
elif detection.status == DetectionStatus.deprecated.value: # type: ignore
elif detection.status == DetectionStatus.deprecated:
total_deprecated += 1

# Check if the detection is manual_test
Expand Down Expand Up @@ -178,7 +178,7 @@ def getSummaryObject(
# Construct and return the larger results dict
result_dict = {
"summary": {
"mode": self.config.getModeName(),
"mode": self.config.mode.mode_name,
"enable_integration_testing": self.config.enable_integration_testing,
"success": overall_success,
"total_detections": total_detections,
Expand Down
9 changes: 4 additions & 5 deletions contentctl/actions/test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import List

from contentctl.objects.config import test_common
from contentctl.objects.config import test_common, Selected, Changes
from contentctl.objects.enums import DetectionTestingMode, DetectionStatus, AnalyticsType
from contentctl.objects.detection import Detection

Expand Down Expand Up @@ -78,19 +78,18 @@ def execute(self, input_dto: TestInputDto) -> bool:
input_dto=manager_input_dto, output_dto=output_dto
)

mode = input_dto.config.getModeName()
if len(input_dto.detections) == 0:
print(
f"With Detection Testing Mode '{mode}', there were [0] detections found to test."
f"With Detection Testing Mode '{input_dto.config.mode.mode_name}', there were [0] detections found to test."
"\nAs such, we will quit immediately."
)
# Directly call stop so that the summary.yml will be generated. Of course it will not
# have any test results, but we still want it to contain a summary showing that now
# detections were tested.
file.stop()
else:
print(f"MODE: [{mode}] - Test [{len(input_dto.detections)}] detections")
if mode in [DetectionTestingMode.changes.value, DetectionTestingMode.selected.value]:
print(f"MODE: [{input_dto.config.mode.mode_name}] - Test [{len(input_dto.detections)}] detections")
if isinstance(input_dto.config.mode, Selected) or isinstance(input_dto.config.mode, Changes):
files_string = '\n- '.join(
[str(pathlib.Path(detection.file_path).relative_to(input_dto.config.path)) for detection in input_dto.detections]
)
Expand Down
Loading

0 comments on commit 31f46a2

Please sign in to comment.