Some generic deleter improvements and fixed autoserialization with "typed" list, creating a simple mechanism to restore original types
dkmstr committed Jul 8, 2024
1 parent 1cb96fd commit 525e9b9
Showing 8 changed files with 90 additions and 43 deletions.
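The central change is that autoserializable list fields can now take a cast callable that restores the original element type when stored data is reloaded. A minimal sketch of the problem and the fix; only Operation.from_int mirrors a name used in this diff, the rest is illustrative:

    import enum
    import json

    class Operation(enum.IntEnum):  # stand-in for the UDS Operation enum
        CREATE = 0
        START = 1

        @staticmethod
        def from_int(value: int) -> 'Operation':
            # Rebuild the enum member from its serialized integer form
            return Operation(value)

    queue = [Operation.CREATE, Operation.START]
    raw = json.loads(json.dumps(queue))           # -> [0, 1]: plain ints after the round-trip
    typed = [Operation.from_int(i) for i in raw]  # what ListField(cast=...) now applies
    assert typed == queue and all(isinstance(op, Operation) for op in typed)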
@@ -397,7 +397,7 @@ def test_deletion_fails_add(self) -> None:
self.assertEqual(self.count_entries_on_storage(deferred_deletion.TO_DELETE_GROUP), 1)
self.assertEqual(self.count_entries_on_storage(deferred_deletion.DELETING_GROUP), 0)
# Test that MAX_TOTAL_RETRIES works fine
- deferred_deletion.MAX_TOTAL_RETRIES = 2
+ deferred_deletion.MAX_RETRAYABLE_ERROR_RETRIES = 2
# reset last_check, or it will not retry
self.set_last_check_expired()
job.run()
@@ -448,7 +448,7 @@ def test_deletion_fails_is_deleted(self) -> None:
self.assertEqual(self.count_entries_on_storage(deferred_deletion.TO_DELETE_GROUP), 0)
self.assertEqual(self.count_entries_on_storage(deferred_deletion.DELETING_GROUP), 1)
# Test that MAX_TOTAL_RETRIES works fine
- deferred_deletion.MAX_TOTAL_RETRIES = 2
+ deferred_deletion.MAX_RETRAYABLE_ERROR_RETRIES = 2
# reset last_check, or it will not retry
self.set_last_check_expired()
job.run()
@@ -575,7 +575,7 @@ def _running(*args: typing.Any, **kwargs: typing.Any) -> bool:

def test_stop_retry_stop(self) -> None:
deferred_deletion.RETRIES_TO_RETRY = 2
- deferred_deletion.MAX_TOTAL_RETRIES = 4
+ deferred_deletion.MAX_RETRAYABLE_ERROR_RETRIES = 4

with self.patch_for_worker(
is_running=helpers.returns_true,
@@ -659,7 +659,7 @@ def test_stop_retry_stop(self) -> None:

def test_delete_retry_delete(self) -> None:
deferred_deletion.RETRIES_TO_RETRY = 2
- deferred_deletion.MAX_TOTAL_RETRIES = 4
+ deferred_deletion.MAX_RETRAYABLE_ERROR_RETRIES = 4

with self.patch_for_worker(
is_running=helpers.returns_true,
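The tests above shrink the module-level retry caps and expire last_check so the cleanup job retries immediately. MAX_RETRAYABLE_ERROR_RETRIES and RETRIES_TO_RETRY are the real constants; the sketch below is only a hypothetical reduction of the capped-retry behaviour being asserted:

    MAX_RETRAYABLE_ERROR_RETRIES = 2  # total retryable failures before giving up for good
    RETRIES_TO_RETRY = 2              # soft failures before requeueing the operation

    def on_retryable_error(entry: dict) -> str:
        entry['total'] = entry.get('total', 0) + 1
        if entry['total'] > MAX_RETRAYABLE_ERROR_RETRIES:
            return 'error'    # cap reached, the deletion is abandoned
        entry['soft'] = entry.get('soft', 0) + 1
        if entry['soft'] >= RETRIES_TO_RETRY:
            entry['soft'] = 0
            return 'requeue'  # push the operation back and try again later
        return 'wait'         # wait for the next last_check expiration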
4 changes: 2 additions & 2 deletions server/src/uds/REST/methods/system.py
@@ -54,7 +54,7 @@

# Enclosed methods under /stats path
POINTS = 70
- SINCE = 180 # Days, if higer values used, ensure mysql/mariadb has a bigger sort buffer
+ SINCE = 90 # Days, if higer values used, ensure mysql/mariadb has a bigger sort buffer
USE_MAX = True
CACHE_TIME = 60 * 60 # 1 hour

@@ -91,7 +91,7 @@ def get_servicepools_counters(
val = [
{
'stamp': x.stamp,
- 'value': (x.sum // x.count if x.count > 0 else 0) if not USE_MAX else x.max,
+ 'value': (x.sum / x.count if x.count > 0 else 0) if not USE_MAX else x.max,
}
for x in stats
]
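The counter fix swaps floor division for true division, so averaged datapoints keep their fractional part when USE_MAX is off:

    sum_, count = 7, 2
    assert sum_ // count == 3   # old behaviour: average floored to an int
    assert sum_ / count == 3.5  # new behaviour: true average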
25 changes: 11 additions & 14 deletions server/src/uds/core/services/generics/dynamic/publication.py
@@ -57,7 +57,7 @@ class DynamicPublication(services.Publication, autoserializable.AutoSerializable

_name = autoserializable.StringField(default='')
_vmid = autoserializable.StringField(default='')
- _queue = autoserializable.ListField[Operation]()
+ _queue = autoserializable.ListField[Operation](cast=Operation.from_int)
_reason = autoserializable.StringField(default='')
_is_flagged_for_destroy = autoserializable.BoolField(default=False)

@@ -83,12 +83,12 @@ def _reset_checks_counter(self) -> None:
data['exec_count'] = 0

@typing.final
- def _inc_checks_counter(self, info: typing.Optional[str] = None) -> typing.Optional[types.states.TaskState]:
+ def _inc_checks_counter(self, op: Operation) -> typing.Optional[types.states.TaskState]:
with self.storage.as_dict() as data:
count = data.get('exec_count', 0) + 1
data['exec_count'] = count
if count > self.max_state_checks:
- return self._error(f'Max checks reached on {info or "unknown"}')
+ return self._error(f'Max checks reached on {op}')
return None

@typing.final
@@ -132,7 +132,7 @@ def _error(self, reason: typing.Union[str, Exception]) -> types.states.TaskState
Returns:
State.ERROR, so we can do "return self._error(reason)"
"""
- self._error_debug_info = self._debug(repr(reason))
+ self._error_debug_info = self._debug(f'{repr(reason)} {getattr(reason, "__backtrace__", "")}')
reason = str(reason)
logger.error(reason)

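__backtrace__ is not a standard Python exception attribute, so presumably other UDS code attaches it before the exception reaches _error; a purely illustrative sketch of that pattern:

    import traceback

    def with_backtrace(e: Exception) -> Exception:
        # Attach the formatted traceback so later handlers can log it via getattr
        setattr(e, '__backtrace__', traceback.format_exc())
        return e

    try:
        1 / 0
    except Exception as e:
        err = with_backtrace(e)
        print(f'{repr(err)} {getattr(err, "__backtrace__", "")}')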
@@ -175,7 +175,7 @@ def _execute_queue(self) -> types.states.TaskState:
# This is a retryable error, so we will retry later
return self.retry_later()
except Exception as e:
- logger.exception('Unexpected FixedUserService exception: %s', e)
+ logger.debug('Exception on %s: %s', op, e, exc_info=True)
return self._error(str(e))

@typing.final
@@ -202,11 +202,11 @@ def generate_name(self) -> str:
else:
# Get the service pool name, and remove all {} macros
name = self.servicepool_name()

return self.service().sanitized_name(f'UDS-Pub-{name}-v{self.revision()}')

def generate_annotation(self) -> str:
- return (f'UDS publication for {self.servicepool_name()} created on {time.strftime("%Y-%m-%d %H:%M:%S")}')
+ return f'UDS publication for {self.servicepool_name()} created on {time.strftime("%Y-%m-%d %H:%M:%S")}'

def check_space(self) -> bool:
"""
@@ -245,7 +245,7 @@ def check_state(self) -> types.states.TaskState:

if op != Operation.WAIT:
# All operations except WAIT will check against checks counter
- counter_state = self._inc_checks_counter(self._op2str(op))
+ counter_state = self._inc_checks_counter(op)
if counter_state is not None:
return counter_state # Error, Finished or None

@@ -272,6 +272,7 @@ def check_state(self) -> types.states.TaskState:
# And it has not been removed from the queue
return types.states.TaskState.RUNNING
except Exception as e:
+ logger.debug('Exception on %s: %s', op, e, exc_info=True)
return self._error(e)

@typing.final
@@ -530,12 +531,8 @@ def op_unsupported(self) -> None:
def op_unsupported_checker(self) -> types.states.TaskState:
raise Exception('Operation not defined')

- @staticmethod
- def _op2str(op: Operation) -> str:
-     return op.name

def _debug(self, txt: str) -> str:
- msg = f'Queue at {txt} for {self._name}: {self._queue}, vmid:{self._vmid}'
+ msg = f'{txt} on {self._name}: {self._queue}, vmid:{self._vmid}'
logger.debug(
msg,
)
9 changes: 6 additions & 3 deletions server/src/uds/core/services/generics/dynamic/userservice.py
@@ -90,7 +90,8 @@ class DynamicUserService(services.UserService, autoserializable.AutoSerializable
_ip = autoserializable.StringField(default='')
_vmid = autoserializable.StringField(default='')
_reason = autoserializable.StringField(default='')
- _queue = autoserializable.ListField[types.services.Operation]() # Default is empty list
+ # cast is used to ensure that when data is reloaded, it's casted to the correct type
+ _queue = autoserializable.ListField[types.services.Operation](cast=types.services.Operation.from_int)
_is_flagged_for_destroy = autoserializable.BoolField(default=False)

# Extra info, not serializable, to keep information in case of exception and debug it
@@ -313,7 +314,7 @@ def get_vmname(self) -> str:
return consts.NO_MORE_NAMES

return self.service().sanitized_name(f'UDS_{name}') # Default implementation

# overridable, to allow receiving notifications from, for example, services
def notify(self, message: str, data: typing.Any = None) -> None:
pass
@@ -463,7 +464,7 @@ def check_state(self) -> types.states.TaskState:
if state == types.states.TaskState.FINISHED:
# Remove finished operation from queue
top_op = self._queue.pop(0)
- if top_op != types.services.Operation.RETRY: # Inserted if a retrayable error occurs on execution queue
+ if (
+     top_op != types.services.Operation.RETRY
+ ): # Inserted if a retrayable error occurs on execution queue
self._reset_retries_counter()
return self._execute_queue()

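The reformatted condition keeps the earlier behaviour: the soft-retry counter is only reset when the finished operation was a real one, so RETRY markers inserted after retryable errors keep counting toward the cap. A compact, hypothetical reduction:

    def on_operation_finished(queue: list, counters: dict) -> None:
        top_op = queue.pop(0)
        if top_op != 'RETRY':        # stand-in for types.services.Operation.RETRY
            counters['retries'] = 0  # a real operation completed, forget soft failures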
53 changes: 43 additions & 10 deletions server/src/uds/core/util/autoserializable.py
@@ -131,7 +131,7 @@ class _MarshalInfo:
- 2 bytes -> name length, little endian
- 2 bytes -> type name length, little endian
- 4 bytes -> data length, little endian
(Previous is defined by PACKED_LENGHS struct)
- n bytes -> name
- n bytes -> type name
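The field header described in this docstring is a fixed 8-byte prefix. A sketch of packing and unpacking it, assuming PACKED_LENGHS is struct.Struct('<HHI') (its definition is not shown in this diff):

    import struct

    # Assumed layout: name length (H), type name length (H), data length (I), little endian
    PACKED_LENGHS = struct.Struct('<HHI')

    name, type_name, payload = b'_queue', b'ListField', b'\x01[0,1]'
    record = PACKED_LENGHS.pack(len(name), len(type_name), len(payload)) + name + type_name + payload

    n_len, t_len, d_len = PACKED_LENGHS.unpack_from(record)
    offset = PACKED_LENGHS.size  # 8 bytes
    assert record[offset:offset + n_len] == name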
@@ -212,10 +212,10 @@ def __get__(
instance {SerializableFields} -- Instance of class with field
"""
if hasattr(instance, '_fields'):
if self.name in getattr(instance, '_fields'):
return getattr(instance, '_fields')[self.name]

if self.default is None:
raise AttributeError(f"Field {self.name} is not set")
# Set default using setter
@@ -310,16 +310,26 @@ def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:

class ListField(_SerializableField[list[T]], list[T]):
"""List field
Args:
default: Default value for the field. Can be a list or a callable that returns a list.
cast: Optional function to cast the values of the list to the desired type. If not provided, the values will be "deserialized" as they are. (see notes)
Note:
All elements in the list must be serializable in JSON, but can be of different types.
In case of serilization of enumerations, they will be serialized as integers or strings.
(Take into account this when using enumerations in lists. The values will be compatible, but not the types)
"""

_cast: typing.Optional[typing.Callable[[typing.Any], T]]

def __init__(
self,
default: typing.Union[list[T], collections.abc.Callable[[], list[T]]] = lambda: [],
cast: typing.Optional[typing.Callable[[typing.Any], T]] = None,
):
super().__init__(list, default)
self._cast = cast

def marshal(self, instance: 'AutoSerializable') -> bytes:
# \x01 is the version of this field marshal format, so we can change it in the future
@@ -328,19 +338,31 @@ def marshal(self, instance: 'AutoSerializable') -> bytes:
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
if data[0] != 1:
raise ValueError('Invalid list data')
- self.__set__(instance, json.loads(data[1:]))
+ self.__set__(
+     instance, [self._cast(i) for i in json.loads(data[1:])] if self._cast else json.loads(data[1:])
+ )


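A usage sketch of the new cast hook; the field and enum here are illustrative, not UDS code. Any IntEnum class also works directly as the cast, since calling it with the stored int returns the member:

    import enum

    class State(enum.IntEnum):
        IDLE = 0
        BUSY = 1

    class Worker(AutoSerializable):
        # Without cast, unmarshal would hand back plain ints
        states = ListField[State](cast=State)

    w = Worker()
    w.states = [State.IDLE, State.BUSY]
    restored = Worker()
    restored.unmarshal(w.marshal())
    assert restored.states == [State.IDLE, State.BUSY]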
class DictField(_SerializableField[dict[T, V]], dict[T, V]):
"""Dict field
Args:
default: Default value for the field. Can be a dict or a callable that returns a dict.
cast: Optional function to cast the values of the dict to the desired type. If not provided, the values will be "deserialized" as they are. (see notes)
Note:
All elements in the dict must be serializable.
Note that due to the use of json as serialization format, keys Will be converted to strings.
Also, values of enumerations will be serialized as integers or strings.
"""

_cast: typing.Optional[typing.Callable[[T, V], tuple[T, V]]]

def __init__(
self,
default: typing.Union[dict[T, V], collections.abc.Callable[[], dict[T, V]]] = lambda: {},
cast: typing.Optional[typing.Callable[[typing.Any], tuple[T, V]]] = None,
):
super().__init__(dict, default)

@@ -351,7 +373,10 @@ def marshal(self, instance: 'AutoSerializable') -> bytes:
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
if data[0] != 1:
raise ValueError('Invalid dict data')
- self.__set__(instance, json.loads(data[1:]))
+ self.__set__(
+     instance,
+     dict(self._cast(k, v) for k, v in json.loads(data[1:])) if self._cast else json.loads(data[1:]),
+ )


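For DictField the cast receives each key/value pair and returns the corrected pair, which matters because, as the note above warns, JSON turns keys into strings. An illustrative sketch (State as in the previous sketch):

    class Pool(AutoSerializable):
        # Turn each serialized ('0', 0) pair back into (0, State.IDLE)
        slots = DictField[int, State](cast=lambda k, v: (int(k), State(v)))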
class ObjectField(_SerializableField[T]):
@@ -475,7 +500,7 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
"""

_fields: dict[str, typing.Any]

serialization_version: int = 0 # So autoserializable classes can keep their version if needed

def _autoserializable_fields(self) -> collections.abc.Iterator[tuple[str, _SerializableField[typing.Any]]]:
@@ -546,7 +571,11 @@ def marshal(self) -> bytes:
# Calculate checksum
checksum = zlib.crc32(data)
# Compose header, that is V1_HEADER + checksum (4 bytes, big endian)
- header = HEADER_BASE + self.serialization_version.to_bytes(VERSION_SIZE, 'big') + checksum.to_bytes(CRC_SIZE, 'big')
+ header = (
+     HEADER_BASE
+     + self.serialization_version.to_bytes(VERSION_SIZE, 'big')
+     + checksum.to_bytes(CRC_SIZE, 'big')
+ )
# Return data processed with header
return header + self.process_data(header, data)

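A sketch of building and checking that header, with assumed values for HEADER_BASE, VERSION_SIZE and CRC_SIZE (the real constants are defined elsewhere in this module):

    import zlib

    HEADER_BASE = b'MGB1'  # assumed magic prefix
    VERSION_SIZE = 2       # assumed
    CRC_SIZE = 4           # crc32 always fits in 4 bytes

    data = b'{"example": "payload"}'
    header = HEADER_BASE + (0).to_bytes(VERSION_SIZE, 'big') + zlib.crc32(data).to_bytes(CRC_SIZE, 'big')

    # unmarshal slices the same offsets back out and verifies against the payload
    stored_crc = int.from_bytes(header[len(HEADER_BASE) + VERSION_SIZE:], 'big')
    assert stored_crc == zlib.crc32(data)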
@@ -559,9 +588,13 @@ def unmarshal(self, data: bytes) -> None:

header = data[: len(HEADER_BASE) + VERSION_SIZE + CRC_SIZE]
# extract version
- self._serialization_version = int.from_bytes(header[len(HEADER_BASE) : len(HEADER_BASE) + VERSION_SIZE], 'big')
+ self._serialization_version = int.from_bytes(
+     header[len(HEADER_BASE) : len(HEADER_BASE) + VERSION_SIZE], 'big'
+ )
# Extract checksum
- checksum = int.from_bytes(header[len(HEADER_BASE) + VERSION_SIZE : len(HEADER_BASE) + VERSION_SIZE + CRC_SIZE], 'big')
+ checksum = int.from_bytes(
+     header[len(HEADER_BASE) + VERSION_SIZE : len(HEADER_BASE) + VERSION_SIZE + CRC_SIZE], 'big'
+ )
# Unprocess data
data = self.unprocess_data(header, data[len(header) :])

@@ -614,7 +647,7 @@ def __str__(self) -> str:
return ', '.join(
[f"{k}={v.obj_type.__name__}({v.__get__(self)})" for k, v in self._autoserializable_fields()]
)

def as_dict(self) -> dict[str, typing.Any]:
return {k: v.__get__(self) for k, v in self._autoserializable_fields()}

3 changes: 3 additions & 0 deletions server/src/uds/core/util/storage.py
@@ -183,6 +183,9 @@ def __len__(self) -> int:
# Optimized methods, avoid re-reading from DB
def items(self) -> typing.Iterator[tuple[str, typing.Any]]: # type: ignore # compatible type
return iter(_decode_value(i.key, i.data) for i in self._filtered)

+ def keys(self) -> typing.Iterator[str]: # type: ignore # compatible type
+     return iter(_decode_value(i.key, i.data)[0] for i in self._filtered)

def values(self) -> typing.Iterator[typing.Any]: # type: ignore # compatible type
return iter(_decode_value(i.key, i.data)[1] for i in self._filtered)
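A brief usage sketch, assuming a storage object whose as_dict() (seen in the publication diff above) yields this mapping; the point of these accessors is one decoded pass over the filtered queryset instead of a DB read per key:

    with storage.as_dict() as data:
        for key, value in data.items():
            print(key, value)
        print(list(data.keys()))  # the new keys() accessor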