Skip to content

Commit

Permalink
Fix MyPy errors
Browse files Browse the repository at this point in the history
Signed-off-by: Jean Snyman <git@jsnyman.com>
  • Loading branch information
stringlytyped committed May 31, 2024
1 parent 55eacb9 commit 2496836
Show file tree
Hide file tree
Showing 24 changed files with 605 additions and 345 deletions.
4 changes: 2 additions & 2 deletions keylime/cmd/registrar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = keylime_logging.init_logging("registrar")


def _check_devid_requirements():
def _check_devid_requirements() -> None:
"""Checks that the cryptography package is the version needed for DevID support (>= 38). Exits if this requirement
is not met and DevID is the only identity allowable by the config.
"""
Expand All @@ -32,7 +32,7 @@ def _check_devid_requirements():


def main() -> None:
logger.info("Starting Keylime registrar service...")
logger.info("Starting Keylime registrar...")

config.check_version("registrar", logger=logger)

Expand Down
11 changes: 7 additions & 4 deletions keylime/keylime_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
import logging
from logging import Logger
from logging import config as logging_config
from typing import Any, Callable, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict

from keylime import config

if TYPE_CHECKING:
from logging import LogRecord

try:
logging_config.fileConfig(config.get_config("logging"))
except KeyError:
logging.basicConfig(format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s", level=logging.DEBUG)


request_id_var = contextvars.ContextVar("request_id")
request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")


def set_log_func(loglevel: int, logger: Logger) -> Callable[..., None]:
Expand Down Expand Up @@ -49,7 +52,7 @@ def log_http_response(logger: Logger, loglevel: int, response_body: Dict[str, An
return True


def annotate_logger(logger):
def annotate_logger(logger: Logger) -> None:
request_id_filter = RequestIDFilter()

for handler in logger.handlers:
Expand All @@ -72,7 +75,7 @@ def init_logging(loggername: str) -> Logger:


class RequestIDFilter(logging.Filter):
def filter(self, record):
def filter(self, record: LogRecord) -> bool:
reqid = request_id_var.get("")

record.reqid = reqid
Expand Down
4 changes: 3 additions & 1 deletion keylime/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Checks whether script is being invoked by tox in a virtual environment
def is_tox_env():
def is_tox_env() -> bool:
# Import 'os' inside function to avoid polluting the namespace of any module which imports 'keylime.models'
import os # pylint: disable=import-outside-toplevel

Expand All @@ -13,3 +13,5 @@ def is_tox_env():
from keylime.models.base.da import da_manager
from keylime.models.base.db import db_manager
from keylime.models.registrar import *

__all__ = ["da_manager", "db_manager"]
19 changes: 18 additions & 1 deletion keylime/models/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,24 @@
from keylime.models.base.da import da_manager
from keylime.models.base.db import db_manager
from keylime.models.base.persistable_model import PersistableModel
from keylime.models.base.type import ModelType
from keylime.models.base.types.certificate import Certificate
from keylime.models.base.types.dictionary import Dictionary
from keylime.models.base.types.one_of import OneOf

__all__ = [
"BigInteger",
"Boolean",
"Float",
"Integer",
"LargeBinary",
"SmallInteger",
"String",
"Text",
"BasicModel",
"da_manager",
"db_manager",
"PersistableModel",
"Certificate",
"Dictionary",
"OneOf",
]
14 changes: 7 additions & 7 deletions keylime/models/base/associations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from keylime.models.base.model import Model
from keylime.models.base.basic_model import BasicModel


class ModelAssociation:
Expand Down Expand Up @@ -76,16 +76,16 @@ def name(self) -> str:
...

@property
def other_model(self) -> "Model":
def other_model(self) -> "BasicModel":
...

def __init__(self, name, other_model, *args, foreign_key=None, preload=True, **kwargs):
self._foreign_key = foreign_key
self._preload = preload

args = [name, other_model, *args]
args = (name, other_model, *args)

super().__init__(*args, **kwargs)
super().__init__(*args, **kwargs) # type: ignore

@property
def foreign_key(self):
Expand All @@ -107,7 +107,7 @@ def foreign_key(self):
return self._foreign_key

@property
def foreign_key_type(self):
def foreign_key_type(self) -> None:
foreign_key_info = self.other_model.fields.get(self.foreign_key)

if not foreign_key_info:
Expand All @@ -116,7 +116,7 @@ def foreign_key_type(self):
f"correspond to an field defined by the associated model '{self.other_model}'"
)

self._foreign_key_type = foreign_key_info.type
self._foreign_key_type = foreign_key_info.data_type

@property
def preload(self):
Expand Down Expand Up @@ -164,7 +164,7 @@ class HasManyAssociation(DatabasePersistenceMixin, GenericToManyAssociation):
# and rename BelongsToAssociation as ReferencesAssociation


class AssociatedRecordSet(set):
class AssociatedRecordSet(set["BasicModel"]):
def __init__(self, parent_record, association, *args, **kwargs):
self._parent_record = parent_record
self._association = association
Expand Down
34 changes: 23 additions & 11 deletions keylime/models/base/basic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import re
from abc import ABC, abstractmethod
from types import MappingProxyType
from typing import TypeAlias, Union

from sqlalchemy.types import TypeEngine

from keylime.models.base.associations import ModelAssociation
from keylime.models.base.errors import FieldValueInvalid, UndefinedField
from keylime.models.base.field import ModelField
from keylime.models.base.type import ModelType


class BasicModel(ABC):
Expand Down Expand Up @@ -42,6 +47,8 @@ def _schema(cls):
"""

DeclaredFieldType: TypeAlias = Union[ModelType, TypeEngine, type[ModelType], type[TypeEngine]]

BUILT_IN_INST_ATTRS = [
"record_values",
"_record_values",
Expand All @@ -53,20 +60,22 @@ def _schema(cls):
"changes_valid",
]

_schema_processed = False
_schema_processed: bool = False
_fields: dict[str, ModelField] = {}
_associations: dict[str, ModelAssociation] = {}

@classmethod
@abstractmethod
def _schema(cls):
def _schema(cls) -> None:
pass

@classmethod
def _clear_model(cls):
def _clear_model(cls) -> None:
cls._fields = {}
cls._associations = {}

@classmethod
def _process_schema(cls):
def _process_schema(cls) -> None:
# If schema has already been processed (and no changes have been made since), do not process again
if cls._schema_processed:
return
Expand All @@ -82,7 +91,7 @@ def _process_schema(cls):
cls._schema()

@classmethod
def _new_field(cls, name, type, nullable=False):
def _new_field(cls, name: str, type: DeclaredFieldType, nullable: bool = False) -> ModelField:
# pylint: disable=redefined-builtin

# Create new model field
Expand Down Expand Up @@ -226,11 +235,11 @@ def render(self, only=None):

value = getattr(self, name)

data[name] = field.type.render(value)
data[name] = field.data_type.render(value)

return data

def clear_changes(self):
def clear_changes(self) -> None:
self._changes.clear()
self._errors.clear()

Expand All @@ -246,10 +255,10 @@ def change(self, name, value):

try:
# Attempt to cast incoming value to field's declared type
self._changes[name] = field.type.cast(value)
self._changes[name] = field.data_type.cast(value)
except Exception:
# If above casting fails, produce a type mismatch message and add it the field's list of errors
self._add_error(name, field.type.generate_error_msg(value))
self._add_error(name, field.data_type.generate_error_msg(value))

def cast_changes(self, changes, permitted=None):
if permitted is None:
Expand All @@ -261,16 +270,19 @@ def cast_changes(self, changes, permitted=None):

self.change(name, value)

def commit_changes(self):
def commit_changes(self) -> None:
if not self.changes_valid:
raise FieldValueInvalid(f"pending changes for model '{self.__class__.__name__}' have validation errors")

for name, value in self._changes.items():
if value is None and not self.__class__.fields[name].nullable:
raise FieldValueInvalid(f"field 'name' for model '{self.__class__.__name__}' is not nullable")

self._record_values[name] = value

self.clear_changes()

def _force_commit_changes(self):
def _force_commit_changes(self) -> None:
for name, value in self._changes.items():
self._record_values[name] = value

Expand Down
14 changes: 9 additions & 5 deletions keylime/models/base/da.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import sys
from typing import TYPE_CHECKING, Optional

from keylime import config, keylime_logging
from keylime.da import record

if TYPE_CHECKING:
from keylime.da.record import BaseRecordManagement

logger = keylime_logging.init_logging("keylime_da")


class DAManager:
def __init__(self) -> None:
self._service = None
self._backend = None
self._service: Optional[str] = None
self._backend: Optional["BaseRecordManagement"] = None

def make_backend(self, service):
def make_backend(self, service: str) -> None:
self._service = service

try:
Expand All @@ -25,11 +29,11 @@ def make_backend(self, service):
sys.exit(1)

@property
def service(self):
def service(self) -> Optional[str]:
return self._service

@property
def backend(self):
def backend(self) -> Optional["BaseRecordManagement"]:
return self._backend


Expand Down
14 changes: 7 additions & 7 deletions keylime/models/base/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from configparser import NoOptionError
from contextlib import contextmanager
from sqlite3 import Connection as SQLite3Connection
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, Iterator, Optional, cast

from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, registry, scoped_session, sessionmaker
from sqlalchemy.orm import Session, registry, scoped_session, sessionmaker # type: ignore[attr-defined]

from keylime import config, keylime_logging
from keylime.models.base.errors import BackendMissing
Expand All @@ -26,8 +26,8 @@ def _set_sqlite_pragma(dbapi_connection: SQLite3Connection, _) -> None:

class DBManager:
def __init__(self) -> None:
self._service = None
self._engine = None
self._service: Optional[str] = None
self._engine: Optional[Engine] = None
self._registry = None
self._scoped_session = None

Expand Down Expand Up @@ -87,9 +87,9 @@ def make_engine(self, service: str) -> Engine:
if config.DEBUG_DB and config.INSECURE_DEBUG:
engine_args["echo"] = True

self._engine = create_engine(url, **engine_args)
self._engine = create_engine(url, **engine_args) # type: ignore
self._registry = registry()
return self._engine
return self._engine # type: ignore

@property
def service(self) -> Optional[str]:
Expand Down Expand Up @@ -131,7 +131,7 @@ def session(self) -> Session:
return cast(Session, self._scoped_session())

@contextmanager
def session_context(self):
def session_context(self) -> Iterator[Session]:
session = self.session()

try:
Expand Down
Loading

0 comments on commit 2496836

Please sign in to comment.