Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 29 additions & 37 deletions scheduler/redis_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,9 @@ def deserialize(cls, data: Dict[str, Any]) -> Self:
class HashModel(BaseModel):
created_at: Optional[datetime] = None
parent: Optional[str] = None
_dirty_fields: Set[str] = dataclasses.field(default_factory=set) # fields that were changed
_save_all: bool = True # Save all fields to broker, after init, or after delete
_list_key: ClassVar[str] = ":list_all:"
_children_key_template: ClassVar[str] = ":children:{}:"

def __post_init__(self):
self._dirty_fields = set()
self._save_all = True

def __setattr__(self, key, value):
if key != "_dirty_fields" and hasattr(self, "_dirty_fields"):
self._dirty_fields.add(key)
super(HashModel, self).__setattr__(key, value)

@property
def _parent_key(self) -> Optional[str]:
if self.parent is None:
Expand All @@ -155,8 +144,10 @@ def exists(cls, name: str, connection: ConnectionType) -> bool:

@classmethod
def delete_many(cls, names: List[str], connection: ConnectionType) -> None:
for name in names:
connection.delete(cls._element_key_template.format(name))
with connection.pipeline() as pipeline:
for name in names:
pipeline.delete(cls._element_key_template.format(name))
pipeline.execute()

@classmethod
def get(cls, name: str, connection: ConnectionType) -> Optional[Self]:
Expand All @@ -171,34 +162,35 @@ def get(cls, name: str, connection: ConnectionType) -> Optional[Self]:

@classmethod
def get_many(cls, names: Sequence[str], connection: ConnectionType) -> List[Optional[Self]]:
pipeline = connection.pipeline()
for name in names:
pipeline.hgetall(cls._element_key_template.format(name))
values = pipeline.execute()
return [(cls.deserialize(decode_dict(v, set())) if v else None) for v in values]
with connection.pipeline() as pipeline:
for name in names:
pipeline.hgetall(cls._element_key_template.format(name))
values = pipeline.execute()
return [(cls.deserialize(decode_dict(v, set())) if v else None) for v in values]

def save(self, connection: ConnectionType) -> None:
connection.sadd(self._list_key, self.name)
if self._parent_key is not None:
connection.sadd(self._parent_key, self.name)
mapping = self.serialize(with_nones=True)
if not self._save_all and len(self._dirty_fields) > 0:
mapping = {k: v for k, v in mapping.items() if k in self._dirty_fields}
none_values = {k for k, v in mapping.items() if v is None}
if none_values:
connection.hdel(self._key, *none_values)
mapping = {k: v for k, v in mapping.items() if v is not None}
if mapping:
connection.hset(self._key, mapping=mapping)
self._dirty_fields = set()
self._save_all = False
with connection.pipeline() as pipeline:
pipeline.sadd(self._list_key, self.name)
if self._parent_key is not None:
pipeline.sadd(self._parent_key, self.name)
mapping = self.serialize(with_nones=True)
none_values = {k for k, v in mapping.items() if v is None}
if none_values:
pipeline.hdel(self._key, *none_values)
mapping = {k: v for k, v in mapping.items() if v is not None}
if mapping:
pipeline.hset(self._key, mapping=mapping)

pipeline.execute()

def delete(self, connection: ConnectionType) -> None:
connection.srem(self._list_key, self._key)
if self._parent_key is not None:
connection.srem(self._parent_key, 0, self._key)
connection.delete(self._key)
self._save_all = True
with connection.pipeline() as pipeline:
pipeline.srem(self._list_key, self._key)
if self._parent_key is not None:
pipeline.srem(self._parent_key, 0, self._key)
pipeline.delete(self._key)

pipeline.execute()

@classmethod
def count(cls, connection: ConnectionType, parent: Optional[str] = None) -> int:
Expand Down