Moved Serializer to AfterValidator to honor exclusion of empty strings
SteveMcGrath committed Feb 13, 2025
1 parent 27eda44 commit b4af3a9
Showing 11 changed files with 346 additions and 321 deletions.
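
The rationale for the change: in pydantic v2, model_dump(exclude_none=True) decides exclusion from the stored field value before serializers run, so a PlainSerializer that maps an empty value to None still emits the field with a None value, while an AfterValidator normalizes the stored value to None at validation time and the field is genuinely dropped. A minimal sketch of the difference (not code from this commit; the 'tags' field and model names are illustrative, mirroring the UniqueList change in tenable/io/sync/models/common.py below):

from typing import Annotated, Optional

from pydantic import AfterValidator, BaseModel
from pydantic.functional_serializers import PlainSerializer

# Old approach: dedupe/normalize at dump time.
SerializedList = Annotated[
    Optional[list[str]], PlainSerializer(lambda v: list(set(v)) if v else None)
]
# New approach: dedupe/normalize at validation time.
ValidatedList = Annotated[
    Optional[list[str]], AfterValidator(lambda v: list(set(v)) if v else None)
]

class Old(BaseModel):
    tags: SerializedList = None

class New(BaseModel):
    tags: ValidatedList = None

print(Old(tags=[]).model_dump(exclude_none=True))  # {'tags': None} -- leaks through
print(New(tags=[]).model_dump(exclude_none=True))  # {} -- exclusion honored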
36 changes: 18 additions & 18 deletions tenable/io/sync/api.py
@@ -30,7 +30,7 @@ class SynchronizationJobCreationError(Exception):


class SynchronizationAPI(APIEndpoint):
_path = "api/v3/data/synchronizations"
_path = 'api/v3/data/synchronizations'
_box = True

def list(
@@ -65,11 +65,11 @@ def list(
"""
if return_json:
params = {
"after": after,
"limit": limit,
"states": [s.upper() for s in states] if states else None,
'after': after,
'limit': limit,
'states': [s.upper() for s in states] if states else None,
}
path = f"{sync_id}/jobs" if sync_id else "jobs"
path = f'{sync_id}/jobs' if sync_id else 'jobs'
return self._get(path, params=params).jobs
return JobListIterator(
self._api, _sync_id=sync_id, _after=after, _limit=limit, _states=states
@@ -89,7 +89,7 @@ def get(self, sync_id: str, job_id: UUID | str) -> Job:
Job:
The job resource object
"""
response = self._get(f"{sync_id}/jobs/{job_id}")
response = self._get(f'{sync_id}/jobs/{job_id}')
return Job(**response)

def create(
@@ -133,19 +133,19 @@ def create(
>>> print(job.counters)
"""
payload = {
"active_state_timeout_seconds": submission_timeout,
"expire_after_seconds": job_lifetime,
'active_state_timeout_seconds': submission_timeout,
'expire_after_seconds': job_lifetime,
}
resp = self._post(
path=f"{sync_id}/jobs",
path=f'{sync_id}/jobs',
json=payload,
)
if return_json:
return resp
if resp.success:
return JobManager(api=self._api, sync_id=sync_id, job_uuid=resp.id)
raise SynchronizationJobCreationError(
"Could not create the job, received {dict(resp)}"
'Could not create the job, received {dict(resp)}'
)

def upload_chunk(
@@ -182,8 +182,8 @@ def upload_chunk(
"""
data = SyncChunkObjects(objects=objects)
return self._post(
f"{sync_id}/jobs/{job_id}/chunks/{chunk_id}",
json=data.model_dump(exclude_none=True, mode="json"),
f'{sync_id}/jobs/{job_id}/chunks/{chunk_id}',
json=data.model_dump(exclude_none=True, mode='json'),
).success

def delete(self, sync_id: str, job_id: UUID | str) -> None:
@@ -203,7 +203,7 @@ def delete(self, sync_id: str, job_id: UUID | str) -> None:
Example:
>>> tvm.sync.delete('example_id', '12345678-1234-1234-1234-123456789012')
"""
self._delete(f"{sync_id}/jobs/{job_id}")
self._delete(f'{sync_id}/jobs/{job_id}')

def submit(self, sync_id: str, job_id: str, num_chunks: int) -> bool:
"""
@@ -229,15 +229,15 @@ def submit(self, sync_id: str, job_id: str, num_chunks: int) -> bool:
... )
"""
return self._post(
path=f"{sync_id}/jobs/{job_id}/_submit",
json={"number_of_chunks": num_chunks},
path=f'{sync_id}/jobs/{job_id}/_submit',
json={'number_of_chunks': num_chunks},
).success

def audit_logs(
self,
sync_id: str,
job_id: UUID | str,
levels: list[Literal["INFO", "WARN", "ERROR"]] | None = None,
levels: list[Literal['INFO', 'WARN', 'ERROR']] | None = None,
after: str | None = None,
limit: int | None = 25,
return_json: bool = False,
@@ -270,8 +270,8 @@ def audit_logs(
"""
if return_json:
resp = self._get(
f"{sync_id}/jobs/{job_id}/audits",
params={"after": after, "levels": levels},
f'{sync_id}/jobs/{job_id}/audits',
params={'after': after, 'levels': levels},
)
return [LogLine(**i) for i in resp.audits]
return JobLogIterator(
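
Taken together, these methods form the low-level job lifecycle: create a job, upload numbered chunks, then submit with the final chunk count. A hedged sketch (assumes an authenticated TenableIO client named tvm; 'example_id', the asset dict, and upload_chunk's keyword names are inferred from the request paths above, not confirmed signatures):

asset = {'device': {'serial_number': 'example'}}  # placeholder device-asset dict
resp = tvm.sync.create('example_id', return_json=True)
tvm.sync.upload_chunk(
    sync_id='example_id',
    job_id=resp.id,
    chunk_id=1,
    objects=[asset],
)
tvm.sync.submit(sync_id='example_id', job_id=resp.id, num_chunks=1)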
4 changes: 2 additions & 2 deletions tenable/io/sync/iterator.py
@@ -9,7 +9,7 @@


class JobListIterator(APIIterator):
_api: "TenableIO"
_api: 'TenableIO'
_sync_id: str | None = None
_limit: int = 25
_states: list[str] | None = None
@@ -34,7 +34,7 @@ def _get_page(self):


class JobLogIterator(APIIterator):
_api: "TenableIO"
_api: 'TenableIO'
_sync_id: str
_job_id: str
_levels: list[str] | None = None
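
Both iterators let callers consume paged results as a flat stream. A hedged sketch (the attribute names on the yielded Job objects are assumptions; the Job model is defined elsewhere in the package):

for job in tvm.sync.list(limit=25):  # JobListIterator when return_json=False
    print(job.id, job.state)         # assumed Job fields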
44 changes: 22 additions & 22 deletions tenable/io/sync/job_manager.py
@@ -74,22 +74,22 @@ class JobManager:

sync_id: str
uuid: UUID
_api: "TenableIO"
_api: 'TenableIO'
_log: logging.Logger
chunk_id: int
max_retries: int = 3
terminate_on_failure: bool = True
object_map: dict[str, Any] = {
"cve-finding": CVEFinding,
"device-asset": DeviceAsset,
'cve-finding': CVEFinding,
'device-asset': DeviceAsset,
}
counters: dict[str, dict[str, int]]
cache_max: int = 1000
_cache: list["BaseModel"]
_cache: list['BaseModel']

def __init__(
self,
api: "TenableIO",
api: 'TenableIO',
sync_id: str,
job_uuid: UUID,
terminate_on_failue: bool = True,
@@ -107,9 +107,9 @@ def __init__(
self.chunks = {}
self.chunk_id = 1
self.terminate_on_failure = terminate_on_failue
self._log = logging.getLogger(f"JobManager[{sync_id}:{job_uuid}]")
self._log.info(f"Starting management of job {sync_id} :: {job_uuid}")
self.counters = {k: {"accepted": 0} for k in self.object_map}
self._log = logging.getLogger(f'JobManager[{sync_id}:{job_uuid}]')
self._log.info(f'Starting management of job {sync_id} :: {job_uuid}')
self.counters = {k: {'accepted': 0} for k in self.object_map}

def __enter__(self):
return self
@@ -123,14 +123,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
# Otherwise we will check to see if we should be terminating
# the job and perform the appropriate action.
else:
self._log.error(f"failed with {exc_type}:{exc_value}")
self._log.error(f'failed with {exc_type}:{exc_value}')
if self.terminate_on_failure:
self.terminate()

def add(
self,
object: dict[str, Any],
object_type: Literal["device-asset", "cve-finding"],
object_type: Literal['device-asset', 'cve-finding'],
) -> None:
"""
Adds an object to the job for processing.
@@ -140,7 +140,7 @@ def add(
object_type: The type of sync object that is being added to the job.
"""
self._cache.append(self.object_map[object_type](**object))
self.counters[object_type]["accepted"] += 1
self.counters[object_type]['accepted'] += 1

# If the number of items added to the cache exceeds the cache limits, then we
# should upload a new chunk of data and flush the cache.
@@ -159,8 +159,8 @@ def upload(self) -> None:
resp = None
while retry_counter < self.max_retries and not resp:
self._log.info(
f"Uploading chunk={self.chunk_id} "
f"attempt {retry_counter + 1} of {self.max_retries}"
f'Uploading chunk={self.chunk_id} '
f'attempt {retry_counter + 1} of {self.max_retries}'
)
try:
resp = self._api.sync.upload_chunk(
@@ -174,20 +174,20 @@ def upload(self) -> None:
# raise job termination error detailing why the job was terminated.
self.terminate()
raise SyncJobTerminated(
f"Job {self.uuid} has been terminated due to server error. "
f"The upstream server error is {err.response.content}"
f'Job {self.uuid} has been terminated due to server error. '
f'The upstream server error is {err.response.content}'
) from err
except APIError as _:
# If we receive any other API error, we will log the error and wait
# for the retry delay to expire before attempting again.
self._log.exception("Upstream API Error Occurred.")
self._log.exception('Upstream API Error Occurred.')
time.sleep(self._retry_delay)
except Exception as err:
# If any other exception is fired then we will terminate the job and
# raise a job termination error.
self.terminate()
raise SyncJobTerminated(
"An unknown error has occurred and caused the sync job to terminate"
'An unknown error has occurred and caused the sync job to terminate'
) from err

# We always want to increment the retry counter by one no matter what.
@@ -199,13 +199,13 @@ def upload(self) -> None:
# failed.
self.terminate()
raise SyncJobTerminated(
f"Chunk {self.chunk_id} failed to upload "
"and caused the job to be terminated."
f'Chunk {self.chunk_id} failed to upload '
'and caused the job to be terminated.'
)
elif not resp:
# If there isn't any response, but the automatic termination of the job is
# not enabled, then raise a job failure error instead.
raise SyncJobFailure(f"Chunk {self.chunk_id} failed to upload")
raise SyncJobFailure(f'Chunk {self.chunk_id} failed to upload')

# Increment the chunk counter and flush the cache.
self.chunk_id += 1
@@ -216,7 +216,7 @@ def terminate(self):
Terminates the current job, abandoning any data that
has already been uploaded.
"""
self._log.info(f"terminating Job {self.sync_id} :: {self.uuid}")
self._log.info(f'terminating Job {self.sync_id} :: {self.uuid}')
self._api.sync.delete(self.sync_id, self.uuid)

def submit(self) -> bool:
@@ -227,7 +227,7 @@ def submit(self) -> bool:
# before submitting the job.
if len(self._cache) > 0:
self.upload()
self._log.info(f"Submitting {self.sync_id} :: {self.uuid} for processing.")
self._log.info(f'Submitting {self.sync_id} :: {self.uuid} for processing.')

# Submit the job
return self._api.sync.submit(
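
The intended high-level flow is the context-manager pattern implied by __enter__/__exit__: add objects one at a time, let the cache flush full chunks automatically, and submit at the end. A hedged sketch (assumes tvm.sync.create() returned this manager, as in api.py above; 'assets' is an iterable of device-asset dicts you supply):

with tvm.sync.create('example_id') as job:
    for asset in assets:                            # your data source
        job.add(asset, object_type='device-asset')
    job.submit()                                    # uploads any cached remainder first
print(job.counters)                                 # e.g. {'device-asset': {'accepted': 42}, ...}

If the block raises, __exit__ logs the failure and, when terminate_on_failure is set, deletes the job so partially uploaded chunks are abandoned.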
40 changes: 19 additions & 21 deletions tenable/io/sync/models/common.py
@@ -3,9 +3,7 @@
from typing import TYPE_CHECKING, Any

from pydantic import (
-    BaseModel as PydanticBaseModel,
-)
-from pydantic import (
+    AfterValidator,
BeforeValidator,
ConfigDict,
Field,
@@ -14,18 +12,20 @@
WrapValidator,
model_validator,
)
-from pydantic.functional_serializers import PlainSerializer
+from pydantic import (
+    BaseModel as PydanticBaseModel,
+)
from restfly.utils import trunc
from typing_extensions import Annotated

if TYPE_CHECKING:
from pydantic_core.core_schema import ValidationInfo, ValidatorFunctionWrapHandler

logger = logging.getLogger("tenable.io.sync.schema")
logger = logging.getLogger('tenable.io.sync.schema')


def trunc_str(
v: Any, handler: "ValidatorFunctionWrapHandler", info: "ValidationInfo"
v: Any, handler: 'ValidatorFunctionWrapHandler', info: 'ValidationInfo'
) -> str:
"""
Ensure that a constrained string does not extend beyond the length limit.
@@ -39,22 +39,22 @@ def trunc_str(
return handler(v)
except ValidationError as err:
errors = err.errors()
if len(errors) > 1 or errors[0]["type"] != "string_too_long":
if len(errors) > 1 or errors[0]['type'] != 'string_too_long':
raise err
model = info.config["title"]
model = info.config['title']
name = info.field_name
limit = errors[0]["ctx"]["max_length"]
limit = errors[0]['ctx']['max_length']
value = trunc(v, limit)
warnings.warn(
f"{model}.{name} has been truncated to {limit} chars, originally {len(str(v))}",
f'{model}.{name} has been truncated to {limit} chars, originally {len(str(v))}',
SyntaxWarning,
stacklevel=0,
)
return value


def trunc_list(
v: Any, handler: "ValidatorFunctionWrapHandler", info: "ValidationInfo"
v: Any, handler: 'ValidatorFunctionWrapHandler', info: 'ValidationInfo'
) -> list:
"""
Ensure that a constrained list does not extend beyond the length limit. Please
@@ -65,14 +65,14 @@ def trunc_list(
return handler(v)
except ValidationError as err:
errors = err.errors()
-        if len(errors) > 1 or errors[0]["type"] != "too_long":
+        if len(errors) > 1 or errors[0]['type'] != 'too_long':
raise err
model = info.config["title"]
model = info.config['title']
name = info.field_name
limit = errors[0]["ctx"]["max_length"]
limit = errors[0]['ctx']['max_length']
value = v[:limit]
warnings.warn(
f"{model}.{name} has been truncated to {limit} items, originally {len(v)}",
f'{model}.{name} has been truncated to {limit} items, originally {len(v)}',
SyntaxWarning,
stacklevel=0,
)
@@ -91,9 +91,7 @@ def upper_if_exist(value: Any, default=None) -> Any:


UpperCaseStr = BeforeValidator(lambda v: str(v).upper() if v else None)
-UniqueListSerializer = PlainSerializer(
-    lambda v: list(set(v)) if v and len(v) > 0 else None
-)
+UniqueList = AfterValidator(lambda v: list(set(v)) if v and len(v) > 0 else None)
TruncListValidator = WrapValidator(trunc_list)

intpos = Annotated[int, Field(gt=0)]
@@ -111,9 +109,9 @@


class BaseModel(PydanticBaseModel):
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra='forbid')

@model_validator(mode="before")
@model_validator(mode='before')
@classmethod
def empty_list_to_none(cls, data):
if isinstance(data, dict):
@@ -126,7 +124,7 @@ class CustomAttribute(BaseModel):
value: str512 | None

def __hash__(self) -> int:
return hash(f"{self.name}:{str(self.value)}")
return hash(f'{self.name}:{str(self.value)}')

def __eq__(self, other: Any) -> bool:
return self.name == other.name and self.value == other.value
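
The net effect of these validators is lenient ingestion: oversized values are truncated with a SyntaxWarning rather than rejected, and empty collections collapse to None so they serialize away. A hedged sketch reusing trunc_str (Str8 is an illustrative alias; the module's real aliases such as str512 live in lines not shown, presumably built the same way as intpos above):

from typing import Annotated

from pydantic import Field, WrapValidator

Str8 = Annotated[str, Field(max_length=8), WrapValidator(trunc_str)]

class Example(BaseModel):  # the BaseModel subclass defined above
    name: Str8 | None = None

obj = Example(name='far too long for the limit')
# Validation succeeds: the value is truncated via restfly's trunc() and a
# SyntaxWarning reports the original length instead of raising string_too_long.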
(The remaining 7 changed files are not rendered in this view.)
