Code Review: Push Attestation API #7

Open · wants to merge 15 commits into base: master
30 changes: 29 additions & 1 deletion keylime/cmd/verifier.py
@@ -1,6 +1,12 @@
from keylime import cloud_verifier_tornado, config, keylime_logging
from keylime.common.migrations import apply
from keylime.mba import mba
import asyncio
import tornado.process

from keylime.web import VerifierServer
from keylime.models import da_manager, db_manager


logger = keylime_logging.init_logging("verifier")

@@ -12,7 +18,29 @@ def main() -> None:

# Explicitly load and initialize measured boot components
mba.load_imports()
cloud_verifier_tornado.main()

# TODO: Remove
# cloud_verifier_tornado.main()

""" db_manager.make_engine("cloud_verifier")
server = VerifierServer(http_port=8881)
tornado.process.fork_processes(0)
asyncio.run(server.start()) """

# Prepare to use the cloud_verifier database
db_manager.make_engine("cloud_verifier")
# Prepare backend for durable attestation, if configured
#da_manager.make_backend("cloud_verifier")

# Start HTTP server
server = VerifierServer(http_port=8881)
server.start_multi()

# db_manager.make_engine("cloud_verifier")

# server = VerifierServer()
# tornado.process.fork_processes(0)
# asyncio.run(server.start())


if __name__ == "__main__":
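For context on the new entry point: the commented-out block above suggests that VerifierServer.start_multi() wraps the same pattern of forking one Tornado worker per CPU and then driving the async server, but its implementation is not part of this diff. A minimal sketch under that assumption (the wrapper function name is illustrative):

import asyncio

import tornado.process

from keylime.models import db_manager
from keylime.web import VerifierServer


def start_push_verifier() -> None:
    # Create the cloud_verifier engine before forking so every worker starts
    # from a configured db_manager (whether engines must be re-created after
    # fork is an implementation detail not visible in this diff).
    db_manager.make_engine("cloud_verifier")

    server = VerifierServer(http_port=8881)

    # fork_processes(0) forks one child per CPU and returns in each child;
    # each child then runs the asyncio loop that serves HTTP requests.
    tornado.process.fork_processes(0)
    asyncio.run(server.start())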
3 changes: 2 additions & 1 deletion keylime/db/verifier_db.py
@@ -1,4 +1,4 @@
from sqlalchemy import Column, ForeignKey, Integer, LargeBinary, PickleType, String, Text, schema
from sqlalchemy import Boolean, Column, ForeignKey, Integer, LargeBinary, PickleType, String, Text, schema
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

@@ -51,6 +51,7 @@ class VerfierMain(Base):
last_received_quote = Column(Integer)
last_successful_attestation = Column(Integer)
tpm_clockinfo = Column(JSONPickleType(pickler=JSONPickler))
accept_attestations = Column(Boolean)


class VerifierAllowlist(Base):
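The new accept_attestations flag on VerfierMain can be toggled with an ordinary SQLAlchemy session; a brief sketch (session construction elided, and the pause/resume semantics are assumed from the column name rather than documented in this diff):

from sqlalchemy.orm import Session

from keylime.db.verifier_db import VerfierMain


def set_accept_attestations(session: Session, agent_id: str, accept: bool) -> None:
    # Flip whether the verifier will accept pushed attestations for one agent
    # (semantics inferred from the column name).
    agent = session.query(VerfierMain).filter_by(agent_id=agent_id).one()
    agent.accept_attestations = accept
    session.commit()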
60 changes: 60 additions & 0 deletions (new Alembic migration file; path not shown in this view)
@@ -0,0 +1,60 @@
"""Create attestations table and columns to enable push attestation protocol

Revision ID: 870c218abd9a
Revises: 330024be7bef
Create Date: 2024-02-23 00:02:34.715180

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "870c218abd9a"
down_revision = "330024be7bef"
branch_labels = None
depends_on = None


def upgrade(engine_name):
globals()[f"upgrade_{engine_name}"]()


def downgrade(engine_name):
globals()[f"downgrade_{engine_name}"]()


def upgrade_registrar():
pass


def downgrade_registrar():
pass


def upgrade_cloud_verifier():
op.create_table(
"attestations",
sa.Column("agent_id", sa.String(80), sa.ForeignKey("verifiermain.agent_id")),
sa.Column("index", sa.Integer),
sa.Column("nonce", sa.LargeBinary(128)),
sa.Column("status", sa.String(8), server_default="waiting"),
sa.Column("tpm_quote", sa.Text, nullable=True),
sa.Column("hash_algorithm", sa.String(15)),
sa.Column("signing_scheme", sa.String(15)),
sa.Column("starting_ima_offset", sa.Integer),
sa.Column("ima_entries", sa.Text, nullable=True),
sa.Column("quoted_ima_entries_count", sa.Integer, nullable=True),
sa.Column("mb_entries", sa.LargeBinary, nullable=True),
# ISO8601 datetimes with microsecond precision in the UTC timezone are never more than 32 characters
sa.Column("nonce_created_at", sa.String(32)),
sa.Column("nonce_expires_at", sa.String(32)),
sa.Column("evidence_received_at", sa.String(32), nullable=True),
sa.Column("boottime", sa.String(32)),
sa.PrimaryKeyConstraint("agent_id", "index"),
)
op.add_column("verifiermain", sa.Column("accept_attestations", sa.Boolean))


def downgrade_cloud_verifier():
op.drop_table("attestations")
op.drop_column("verifiermain", "accept_attestations")
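The dispatch above follows Alembic's multi-database template: upgrade(engine_name) resolves upgrade_<engine_name> from the module globals, so only the cloud verifier schema gains the new table and column. After the migration has run, the result can be sanity-checked with SQLAlchemy's inspector; a sketch with a purely illustrative database URL:

import sqlalchemy as sa

# Illustrative URL only; substitute the verifier's configured database.
engine = sa.create_engine("sqlite:////var/lib/keylime/cv_data.sqlite")
inspector = sa.inspect(engine)

assert "attestations" in inspector.get_table_names()
assert "accept_attestations" in {col["name"] for col in inspector.get_columns("verifiermain")}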
1 change: 1 addition & 0 deletions keylime/models/__init__.py
@@ -12,6 +12,7 @@ def is_tox_env() -> bool:
if not is_tox_env():
from keylime.models.base.da import da_manager
from keylime.models.base.db import db_manager
from keylime.models.common import *
from keylime.models.registrar import *

__all__ = ["da_manager", "db_manager"]
13 changes: 11 additions & 2 deletions keylime/models/base/__init__.py
@@ -1,19 +1,24 @@
from sqlalchemy import BigInteger, Boolean, Float, Integer, LargeBinary, SmallInteger, String, Text
from sqlalchemy import BigInteger, Boolean, Float, Integer, SmallInteger, String, Text, desc, or_

from keylime.models.base.basic_model import BasicModel
from keylime.models.base.da import da_manager
from keylime.models.base.db import db_manager
from keylime.models.base.persistable_model import PersistableModel
from keylime.models.base.types.binary import Binary
from keylime.models.base.types.certificate import Certificate
from keylime.models.base.types.dictionary import Dictionary
from keylime.models.base.types.list import List
from keylime.models.base.types.nonce import Nonce
from keylime.models.base.types.one_of import OneOf
from keylime.models.base.types.timestamp import Timestamp

__all__ = [
"desc",
"or_",
"BigInteger",
"Boolean",
"Float",
"Integer",
"LargeBinary",
"SmallInteger",
"String",
"Text",
@@ -23,5 +28,9 @@
"PersistableModel",
"Certificate",
"Dictionary",
"List",
"OneOf",
"Binary",
"Nonce",
"Timestamp",
]
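With the expanded exports, model modules can pull the new data types and the SQLAlchemy ordering/boolean helpers from a single place; for example (which names a given model actually needs depends on its schema):

from keylime.models.base import Nonce, OneOf, Timestamp, desc, or_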
56 changes: 56 additions & 0 deletions keylime/models/base/associations.py
@@ -129,6 +129,14 @@ def other_model(self) -> type["BasicModel"]:

@property
def inverse_of(self) -> Optional[str]:
if not self._inverse_of:
# Get all associations which point from the associated model back to the parent model
candidates = [assoc for assoc in self.other_model.associations.values() if assoc.other_model == self]

# If there is only one such association in the associated model, the inverse association is unambiguous
if len(candidates) == 1:
self._inverse_of = candidates[0].name

return self._inverse_of

@property
@@ -144,6 +152,54 @@ def preload(self) -> bool:
pass


class EmbeddedAssociation(ModelAssociation):
@property
def preload(self) -> bool:
return True


class EmbedsOneAssociation(EmbeddedAssociation):
def __get__(
self, parent_record: "BasicModel", _objtype: Optional[type["BasicModel"]] = None
) -> Union["BasicModel", "EmbedsOneAssociation", None]:
return self._get_one(parent_record)

def __set__(self, parent_record: "BasicModel", other_record: "BasicModel") -> None:
self._set_one(parent_record, other_record)

if TYPE_CHECKING:

def _get_one(self, parent_record: "BasicModel") -> Union["BasicModel", "EmbedsOneAssociation", None]:
...


class EmbedsManyAssociation(EmbeddedAssociation):
def __get__(
self, parent_record: "BasicModel", _objtype: Optional[type["BasicModel"]] = None
) -> Union["AssociatedRecordSet", "EmbedsManyAssociation", None]:
return self._get_many(parent_record)

if TYPE_CHECKING:

def _get_many(self, parent_record: "BasicModel") -> Union[AssociatedRecordSet, "EmbedsManyAssociation", None]:
...


class EmbeddedInAssociation(EmbeddedAssociation):
def __get__(
self, parent_record: "BasicModel", _objtype: Optional[type["BasicModel"]] = None
) -> Union["BasicModel", "EmbedsOneAssociation", None]:
return self._get_one(parent_record)

def __set__(self, parent_record: "BasicModel", other_record: "BasicModel") -> None:
self._set_one(parent_record, other_record)

if TYPE_CHECKING:

def _get_one(self, parent_record: "BasicModel") -> Union["BasicModel", "EmbedsOneAssociation", None]:
...


class EntityAssociation(ModelAssociation):
"""EntityAssociation extends ModelAssociation to provide additional functionality common to associations which
map to a relationship between database entities.
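The inverse_of inference added at the top of this file reduces to: gather every association on the associated model that points back at the parent and accept it as the inverse only when it is unique. A self-contained sketch of that rule with simplified stand-in classes (not the Keylime types):

from typing import Optional


class Assoc:
    """Minimal stand-in for a model association: a named link to another model class."""

    def __init__(self, name: str, other_model: type) -> None:
        self.name = name
        self.other_model = other_model


def infer_inverse(parent_model: type, other_model_assocs: dict[str, Assoc]) -> Optional[str]:
    # Candidate inverses are associations on the other model that point back at the parent
    candidates = [a for a in other_model_assocs.values() if a.other_model is parent_model]
    # Only unambiguous when exactly one such association exists
    return candidates[0].name if len(candidates) == 1 else None


# Example: an Agent model embeds many Attestations and the Attestation model holds a
# single link back to Agent, so the inverse resolves to "agent" without being declared.
class Agent: ...

assert infer_inverse(Agent, {"agent": Assoc("agent", Agent)}) == "agent"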
87 changes: 77 additions & 10 deletions keylime/models/base/basic_model.py
@@ -249,22 +249,25 @@ def _init_from_dict(self, data: dict, _process_associations: bool) -> None:

def __setattr__(self, name: str, value: Any) -> None:
if (
name not in dir(self)
not name.startswith("_")
and name not in dir(self)
and name not in type(self).INST_ATTRS
and name not in type(self).fields.keys()
and name not in type(self).associations.keys()
):
msg = f"model '{type(self).__name__}' does not define a field or attribute with name '{name}'"
msg = f"cannot set '{name}' as model '{type(self).__name__}' defines no field or attribute with that name"
raise AttributeError(msg)

super().__setattr__(name, value)

def __getattribute__(self, name: str) -> Any:
# TODO: Uncomment
""" def __getattribute__(self, name: str) -> Any:
try:
return super().__getattribute__(name)
except AttributeError:
msg = f"model '{type(self).__name__}' does not define a field or attribute with name '{name}'"
msg = f"cannot get '{name}' as model '{type(self).__name__}' defines no field or attribute with that name"
raise AttributeError(msg) from None
"""

def __repr__(self) -> str:
"""Returns a code-like string representation of the model instance
Expand Down Expand Up @@ -435,16 +438,66 @@ def validate_format(self, field: str, format: Union[str, Pattern], msg: Optional
if not re.fullmatch(format, value):
self._add_error(field, msg or "does not have the correct format")

def validate_snake_case(self, fields: Union[str, Sequence[str]], msg: Optional[str] = None):
msg = msg or "can only contain lowercase letters, numbers and underscores"

for field in fields:
self.validate_format(field, "[a-z0-9_]+", msg)

def _render_field(self, name: str, data: dict[str, Any]) -> dict[str, Any]:
field = self.__class__.fields[name]
value = getattr(self, name, None)

if field.render:
data[name] = field.data_type.render(value)

return data

def _render_embedded_record(self, name: str, data: dict[str, Any]) -> dict[str, Any]:
value = getattr(self, name, None)
data[name] = value.render() if value is not None else None
return data

def _render_embedded_record_set(self, name: str, data: dict[str, Any]) -> dict[str, Any]:
value = getattr(self, name, None)

if value is None:
data[name] = None
else:
data[name] = []

for record in value:
data[name].append(record.render())

return data

def render(self, only: Optional[Sequence[str]] = None) -> dict[str, Any]:
data = {}

for name, field in self.__class__.fields.items():
if only and name not in only:
# Locate record members by iterating through schema helper calls so that fields and embedded records are
# rendered in the order in which they were defined
for _, args, kwargs in self.__class__.schema_helpers_used:
# Get member name from kwargs passed to schema helper or first positional argument
first_arg = args[0] if len(args) > 0 else None
name = kwargs.get("name") or first_arg

# If the "name" obtained does not look like a name, skip
if not name or not isinstance(name, str):
continue

value = getattr(self, name)
# If member name not present in provided allowlist, skip
if only and name not in only:
continue

data[name] = field.data_type.render(value)
# If member name is the name of a field, render its value according to the field's data type
if name in self.__class__.fields.keys():
self._render_field(name, data)
# If member name is the name of an embedded record, render the full record
elif name in self.__class__.embeds_one_associations.keys():
self._render_embedded_record(name, data)
# If member name is the name of a set of embedded records, render as a list of records
elif name in self.__class__.embeds_many_associations.keys():
self._render_embedded_record_set(name, data)

return data

@@ -461,9 +514,23 @@ def values(self) -> Mapping[str, Any]:
return MappingProxyType({**self.committed, **self.changes})

@property
def errors(self) -> Mapping[str, Sequence[str]]:
def errors(self) -> Mapping[str, Sequence | Mapping]:
errors = {field: errors for field, errors in self._errors.items() if len(errors) > 0}
return MappingProxyType(errors)

embeds = {**self.__class__.embeds_one_associations, **self.__class__.embeds_many_associations}

for name, embed in self.__class__.embeds_one_associations.items():
(record,) = embed.get_record_set(self) or (None,)

if record and record.errors:
errors.update({f"{name}.{fname}": error for fname, error in record.errors.items()})

for name, embed in self.__class__.embeds_many_associations.items():
for index, record in enumerate(embed.get_record_set(self)):
if record.errors:
errors.update({f"{name}[{index}].{fname}": error for fname, error in record.errors.items()})

return errors

@property
def changes_valid(self) -> bool:
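The net effect of the updated errors property is that validation errors raised inside embedded records surface on the parent under dotted and indexed keys, giving callers one flat mapping. An illustration of the expected shape (the model and field names here are hypothetical):

# Hypothetical flattened error mapping for a parent record that embeds one
# "capabilities" record and many "attestations":
expected_errors = {
    "accept_attestations": ["must be a boolean"],       # error on the parent's own field
    "attestations[0].nonce": ["has expired"],            # error in the first embedded record
    "capabilities.component_version": ["is required"],   # error in an embeds_one record
}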