diff --git a/pyproject.toml b/pyproject.toml index 81b63c7a..fdf60041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,12 +27,10 @@ exclude = ''' [tool.mypy] mypy_path = "mypy_stubs" exclude = [ - 'src/tpm2_pytss/internal/templates.py', 'src/tpm2_pytss/encoding.py', 'src/tpm2_pytss/policy.py', 'src/tpm2_pytss/ESAPI.py', 'src/tpm2_pytss/FAPI.py', - 'src/tpm2_pytss/types.py', 'src/tpm2_pytss/internal/crypto.py', 'src/tpm2_pytss/fapi_info.py', ] \ No newline at end of file diff --git a/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi b/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi index 8eb7ee65..9ef894b7 100644 --- a/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi +++ b/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi @@ -1,4 +1,4 @@ -from typing import Optional, Callable, Iterable, Any, Dict +from typing import Optional, Callable, Iterable, Any, Dict, Tuple error: type[Exception] class CData: @@ -13,7 +13,7 @@ class CType: kind: str cname: str item: "CType" - fields: Iterable[str] + fields: Iterable[Tuple[str, Any]] NULL: CData def gc(cdata: CData, destructor: Callable[[CData], None], size: int = 0)-> CData: ... @@ -28,3 +28,4 @@ def from_handle(handle: CData) -> Any: ... def new_handle(python_object: Any) -> CData: ... def cast(ctype: str, value: CData) -> CData: ... def memmove(dest: CData | bytes, src: CData | bytes, n: int) -> None: ... +def addressof(cdata: CData, *fields_or_indexes: str | int) -> CData: ... diff --git a/src/tpm2_pytss/constants.py b/src/tpm2_pytss/constants.py index 66dfcd58..870f309e 100644 --- a/src/tpm2_pytss/constants.py +++ b/src/tpm2_pytss/constants.py @@ -23,6 +23,12 @@ SupportsIndex, ) +try: + from typing import Self +except ImportError: + # assume mypy is running on python 3.11+ + pass + if TYPE_CHECKING: from .ESAPI import ESAPI @@ -33,14 +39,14 @@ class TPM_FRIENDLY_INT(int): _FIXUP_MAP: Dict[str, str] = {} @classmethod - def parse(cls, value: str) -> int: + def parse(cls, value: str) -> "Self": # If it's a string initializer value, see if it matches anything in the list if isinstance(value, str): try: x = _CLASS_INT_ATTRS_from_string(cls, value, cls._FIXUP_MAP) if not isinstance(x, int): raise KeyError(f'Expected int got: "{type(x)}"') - return x + return cls(x) except KeyError: raise ValueError( f'Could not convert friendly name to value, got: "{value}"' @@ -305,7 +311,7 @@ class TPMA_FRIENDLY_INTLIST(TPM_FRIENDLY_INT): _MASKS: Tuple[Tuple[int, int, str], ...] = tuple() @classmethod - def parse(cls, value: str) -> int: + def parse(cls, value: str) -> "Self": """ Converts a string of | separated constant values into it's integer value. Given a pipe "|" separated list of string constant values that represent the @@ -361,7 +367,7 @@ def parse(cls, value: str) -> int: f'Could not convert friendly name to value, got: "{k}"' ) - return intvalue + return cls(intvalue) def __str__(self) -> str: """Given a constant, return the string bitwise representation. diff --git a/src/tpm2_pytss/internal/crypto.py b/src/tpm2_pytss/internal/crypto.py index 120d0b30..294e6f82 100644 --- a/src/tpm2_pytss/internal/crypto.py +++ b/src/tpm2_pytss/internal/crypto.py @@ -23,7 +23,7 @@ from cryptography.hazmat.primitives.ciphers import modes, Cipher, CipherAlgorithm from cryptography.hazmat.backends import default_backend from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature -from typing import Tuple, Type, Union +from typing import Tuple, Type, Union, Optional import secrets import sys @@ -140,7 +140,9 @@ def key_from_encoding(data, password=None): raise ValueError("Unsupported key format") -def _public_from_encoding(data, obj, password=None): +def _public_from_encoding( + data: bytes, obj: "TPMT_PUBLIC", password: Optional[bytes] = None +) -> None: key = key_from_encoding(data, password) nums = key.public_numbers() if isinstance(key, rsa.RSAPublicKey): @@ -183,7 +185,9 @@ def private_key_from_encoding(data, password=None): raise ValueError("Unsupported key format") -def _private_from_encoding(data, obj, password=None): +def _private_from_encoding( + data: bytes, obj: "TPMT_SENSITIVE", password: Optional[bytes] = None +): key = private_key_from_encoding(data, password) nums = key.private_numbers() if isinstance(key, rsa.RSAPrivateKey): @@ -274,7 +278,9 @@ def _generate_d(p, q, e, n): return d -def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC"): +def private_to_key( + private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC" +) -> Union[rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey]: key = None if private.sensitiveType == TPM2_ALG.RSA: @@ -310,7 +316,7 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC") return key -def _public_to_pem(obj, encoding="pem"): +def _public_to_pem(obj: "TPMT_PUBLIC", encoding: str = "pem") -> bytes: encoding = encoding.lower() key = public_to_key(obj) if encoding == "pem": @@ -323,7 +329,7 @@ def _public_to_pem(obj, encoding="pem"): raise ValueError(f"unsupported encoding: {encoding}") -def _getname(obj): +def _getname(obj: "TPMT_PUBLIC") -> bytes: dt = _get_digest(obj.nameAlg) if dt is None: raise ValueError(f"unsupported digest algorithm: {obj.nameAlg}") @@ -391,7 +397,7 @@ def _symdef_to_crypt(symdef: "TPMT_SYM_DEF"): return (alg, mode, bits) -def _calculate_sym_unique(nameAlg, secret, seed): +def _calculate_sym_unique(nameAlg: TPM2_ALG, secret: bytes, seed: bytes) -> bytes: dt = _get_digest(nameAlg) if dt is None: raise ValueError(f"unsupported digest algorithm: {nameAlg}") @@ -409,7 +415,7 @@ def _get_digest_size(alg: TPM2_ALG) -> int: return dt.digest_size -def _get_signature_bytes(sig): +def _get_signature_bytes(sig: "TPMT_SIGNATURE") -> bytes: if sig.sigAlg in (TPM2_ALG.RSAPSS, TPM2_ALG.RSASSA): rb = bytes(sig.signature.rsapss.sig) elif sig.sigAlg == TPM2_ALG.ECDSA: @@ -476,7 +482,9 @@ def verify_signature_hmac(signature, key, data): h.verify(sig) -def _verify_signature(signature, key, data): +def _verify_signature( + signature: "TPMT_SIGNATURE", key: "TPMT_PUBLIC", data: bytes +) -> None: if hasattr(key, "publicArea"): key = key.publicArea kt = getattr(key, "type", None) diff --git a/src/tpm2_pytss/internal/utils.py b/src/tpm2_pytss/internal/utils.py index 7ba2dc56..680ae0b8 100644 --- a/src/tpm2_pytss/internal/utils.py +++ b/src/tpm2_pytss/internal/utils.py @@ -254,7 +254,7 @@ def _cpointer_to_ctype(x: ffi.CData) -> ffi.CType: def _fixup_cdata_kwargs( - this: Any, _cdata: ffi.CData, kwargs: Dict[str, Any] + this: Any, _cdata: Any, kwargs: Dict[str, Any] ) -> Tuple[ffi.CData, Dict[str, Any]]: # folks may call this routine without a keyword argument which means it may diff --git a/src/tpm2_pytss/types.py b/src/tpm2_pytss/types.py index 73ea0d79..33fb6cd2 100644 --- a/src/tpm2_pytss/types.py +++ b/src/tpm2_pytss/types.py @@ -40,12 +40,12 @@ TPM2_SE, TPM2_HR, ) -from typing import Union, Tuple, Optional, Dict, Any +from typing import Union, Tuple, Optional, Any, Iterable, List try: - # assume mypy is running on python 3.11+ from typing import Self except ImportError: + # assume mypy is running on python 3.11+ pass import sys @@ -77,7 +77,7 @@ class TPM2_HANDLE(int): class TPM_OBJECT(object): """ Abstract Base class for all TPM Objects. Not suitable for direct instantiation.""" - def __init__(self, _cdata: Optional[ffi.CData] = None, **kwargs: Any): + def __init__(self, _cdata: Optional[Any] = None, **kwargs: Any): # Rather than trying to mock the FFI interface, just avoid it and return # the base object. This is really only needed for documentation, and it @@ -120,7 +120,7 @@ def __init__(self, _cdata: Optional[ffi.CData] = None, **kwargs: Any): v = subobj TPM_OBJECT.__setattr__(self, k, v) - def __getattribute__(self, key): + def __getattribute__(self, key: str) -> Any: try: # go through object to avoid invoking THIS objects __getattribute__ call # and thus infinite recursion @@ -143,7 +143,7 @@ def __getattribute__(self, key): obj = _convert_to_python_native(globals(), x, parent=self._cdata) return obj - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: _value = value _cdata = object.__getattribute__(self, "_cdata") @@ -205,8 +205,8 @@ def __setattr__(self, key, value): # recurse so we can get handling of setattr with Python wrapped data setattr(self, key, value) - def __dir__(self): - return object.__dir__(self) + dir(self._cdata) + def __dir__(self) -> Iterable[str]: + return list(object.__dir__(self)) + dir(self._cdata) def marshal(self) -> bytes: """Marshal instance into bytes. @@ -273,14 +273,14 @@ def __init__(self, _cdata: Optional[Union[ffi.CData, bytes]] = None, **kwargs: A super().__init__(_cdata=_cdata) @classmethod - def _get_bytefield(cls): + def _get_bytefield(cls) -> Optional[str]: tipe = ffi.typeof(f"{cls.__name__}") for f in tipe.fields: if f[0] != "size": return f[0] return None - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: if key == "size": raise AttributeError(f"{key} is read only") @@ -294,7 +294,7 @@ def __setattr__(self, key, value): else: super().__setattr__(key, value) - def __getattribute__(self, key): + def __getattribute__(self, key: str) -> Any: _bytefield = type(self)._get_bytefield() if key == _bytefield: b = getattr(self._cdata, _bytefield) @@ -302,11 +302,13 @@ def __getattribute__(self, key): return memoryview(ffi.buffer(rb, self._cdata.size)) return super().__getattribute__(key) - def __len__(self): - return self._cdata.size + def __len__(self) -> int: + return int(self._cdata.size) - def __getitem__(self, index): + def __getitem__(self, index: slice) -> Any: _bytefield = type(self)._get_bytefield() + if _bytefield is None: + raise RuntimeError("unable to find byte field") buf = getattr(self, _bytefield) if isinstance(index, int): if index >= self._cdata.size: @@ -317,8 +319,10 @@ def __getitem__(self, index): else: raise TypeError("index must an int or a slice") - def __bytes__(self): + def __bytes__(self) -> bytes: _bytefield = type(self)._get_bytefield() + if _bytefield is None: + raise RuntimeError("unable to find byte field") buf = getattr(self, _bytefield) return bytes(buf) @@ -337,7 +341,7 @@ def __str__(self) -> str: b = self.__bytes__() return binascii.hexlify(b).decode() - def __eq__(self, value): + def __eq__(self, value: object) -> bool: b = self.__bytes__() return b == value @@ -352,14 +356,14 @@ class TPML_Iterator(object): do_something(alg) """ - def __init__(self, tpml): + def __init__(self, tpml: "TPML_OBJECT"): self._tpml = tpml self._index = 0 - def __iter__(self): + def __iter__(self) -> "Self": return self - def __next__(self): + def __next__(self) -> Any: if self._index > self._tpml.count - 1: raise StopIteration @@ -373,7 +377,7 @@ class TPML_OBJECT(TPM_OBJECT): """ Abstract Base class for all TPML Objects. A TPML object is an object that contains a list of objects. This is not suitable for direct instantiation.""" - def __init__(self, _cdata: Optional[ffi.CData] = None, **kwargs: Dict[str, Any]): + def __init__(self, _cdata: Optional[Any] = None, **kwargs: Any): _cdata, kwargs = _fixup_cdata_kwargs(self, _cdata, kwargs) super().__init__(_cdata=_cdata) @@ -417,7 +421,7 @@ def __init__(self, _cdata: Optional[ffi.CData] = None, **kwargs: Dict[str, Any]) self._cdata.count = len(kwargs[key]) - def __getattribute__(self, key): + def __getattribute__(self, key: str) -> Any: try: # Can the parent handle it? @@ -455,10 +459,10 @@ def __getattribute__(self, key): return l - def __getitem__(self, item): + def __getitem__(self, item: Union[int, slice]) -> Any: item_was_int = isinstance(item, int) try: - return object.__getitem__(self, item) + return getattr(object, "__getitem__")(self, item) except AttributeError: pass @@ -502,11 +506,11 @@ def __getitem__(self, item): return objects[0] if item_was_int else objects - def __len__(self): + def __len__(self) -> int: - return self._cdata.count + return int(self._cdata.count) - def __setitem__(self, key, value): + def __setitem__(self, key: Union[int, slice], value: Any) -> None: if not isinstance(key, (int, slice)): raise TypeError(f"list indices must be integers or slices, not {type(key)}") @@ -543,7 +547,7 @@ def __setitem__(self, key, value): if key.stop > self._cdata.count: self._cdata.count = key.stop - def __iter__(self): + def __iter__(self) -> TPML_Iterator: return TPML_Iterator(self) @@ -563,29 +567,29 @@ class TPM2B_NAME(TPM2B_SIMPLE_OBJECT): pass -def _handle_sym_common(objstr, default_mode="null"): +def _handle_sym_common(objstr: str, default_mode: str = "null") -> Tuple[int, TPM2_ALG]: if objstr is None or len(objstr) == 0: objstr = "128" - bits = objstr[:3] + bitstr = objstr[:3] expected = ["128", "192", "256"] - if bits not in expected: - raise ValueError(f'Expected bits to be one of {expected}, got: "{bits}"') + if bitstr not in expected: + raise ValueError(f'Expected bits to be one of {expected}, got: "{bitstr}"') - bits = int(bits) + bits = int(bitstr) # go past bits objstr = objstr[3:] if len(objstr) == 0: - mode = default_mode + modestr = default_mode else: expected = ["cfb", "cbc", "ofb", "ctr", "ecb"] if objstr not in expected: raise ValueError(f'Expected mode to be one of {expected}, got: "{objstr}"') - mode = objstr + modestr = objstr - mode = TPM2_ALG.parse(mode) + mode = TPM2_ALG.parse(modestr) return (bits, mode) @@ -593,7 +597,7 @@ def _handle_sym_common(objstr, default_mode="null"): class TPMT_SYM_DEF(TPM_OBJECT): @classmethod def parse( - cls, alg: str, is_restricted: bool = False, is_rsapss: bool = False + cls, alg: Optional[str], is_restricted: bool = False, is_rsapss: bool = False ) -> "TPMT_SYM_DEF": """Builds a TPMT_SYM_DEF from a tpm2-tools like specifier strings. @@ -655,7 +659,7 @@ class TPMT_SYM_DEF_OBJECT(TPMT_SYM_DEF): class TPMT_PUBLIC(TPM_OBJECT): @staticmethod - def _handle_rsa(objstr, templ): + def _handle_rsa(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.RSA if objstr is None or objstr == "": @@ -673,7 +677,7 @@ def _handle_rsa(objstr, templ): return True @staticmethod - def _handle_ecc(objstr, templ): + def _handle_ecc(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.ECC if objstr is None or objstr == "": @@ -689,7 +693,7 @@ def _handle_ecc(objstr, templ): return True @staticmethod - def _handle_aes(objstr, templ): + def _handle_aes(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.SYMCIPHER templ.parameters.symDetail.sym.algorithm = TPM2_ALG.AES @@ -699,7 +703,7 @@ def _handle_aes(objstr, templ): return False @staticmethod - def _handle_camellia(objstr, templ): + def _handle_camellia(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.SYMCIPHER templ.parameters.symDetail.sym.algorithm = TPM2_ALG.CAMELLIA @@ -710,7 +714,7 @@ def _handle_camellia(objstr, templ): return False @staticmethod - def _handle_sm4(objstr, templ): + def _handle_sm4(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.SYMCIPHER templ.parameters.symDetail.sym.algorithm = TPM2_ALG.SM4 @@ -723,28 +727,28 @@ def _handle_sm4(objstr, templ): return False @staticmethod - def _handle_xor(_, templ): + def _handle_xor(_: Any, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.KEYEDHASH templ.parameters.keyedHashDetail.scheme.scheme = TPM2_ALG.XOR return True @staticmethod - def _handle_hmac(_, templ): + def _handle_hmac(_: Any, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.KEYEDHASH templ.parameters.keyedHashDetail.scheme.scheme = TPM2_ALG.HMAC return True @staticmethod - def _handle_keyedhash(_, templ): + def _handle_keyedhash(_: Any, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.KEYEDHASH templ.parameters.keyedHashDetail.scheme.scheme = TPM2_ALG.NULL return False @staticmethod - def _error_on_conflicting_sign_attrs(templ): + def _error_on_conflicting_sign_attrs(templ: "TPMT_PUBLIC") -> None: """ If the scheme is set, both the encrypt and decrypt attributes cannot be set, check to see if this is the case, and turn down: @@ -770,7 +774,7 @@ def _error_on_conflicting_sign_attrs(templ): ) @staticmethod - def _handle_scheme_rsa(scheme, templ): + def _handle_scheme_rsa(scheme: Optional[str], templ: "TPMT_PUBLIC") -> bool: if scheme is None or len(scheme) == 0: scheme = "null" @@ -807,7 +811,7 @@ def _handle_scheme_rsa(scheme, templ): return True @staticmethod - def _handle_scheme_ecc(scheme, templ): + def _handle_scheme_ecc(scheme: Optional[str], templ: "TPMT_PUBLIC") -> bool: if scheme is None or len(scheme) == 0: scheme = "null" @@ -847,7 +851,7 @@ def _handle_scheme_ecc(scheme, templ): return True @staticmethod - def _handle_scheme_keyedhash(scheme, templ): + def _handle_scheme_keyedhash(scheme: Optional[str], templ: "TPMT_PUBLIC") -> None: if scheme is None or scheme == "": scheme = "sha256" @@ -866,7 +870,7 @@ def _handle_scheme_keyedhash(scheme, templ): ) @staticmethod - def _handle_scheme(scheme, templ): + def _handle_scheme(scheme: Optional[str], templ: "TPMT_PUBLIC") -> None: if templ.type == TPM2_ALG.RSA: TPMT_PUBLIC._handle_scheme_rsa(scheme, templ) elif templ.type == TPM2_ALG.ECC: @@ -880,7 +884,7 @@ def _handle_scheme(scheme, templ): ) @staticmethod - def _handle_asymdetail(detail, templ): + def _handle_asymdetail(detail: Optional[str], templ: "TPMT_PUBLIC") -> None: if templ.type == TPM2_ALG.KEYEDHASH: if detail is not None: @@ -908,7 +912,7 @@ def parse( TPMA_OBJECT, int, str ] = TPMA_OBJECT.DEFAULT_TPM2_TOOLS_CREATE_ATTRS, nameAlg: Union[TPM2_ALG, int, str] = "sha256", - authPolicy: bytes = None, + authPolicy: Optional[bytes] = None, ) -> "TPMT_PUBLIC": """Builds a TPMT_PUBLIC from a tpm2-tools like specifier strings. @@ -975,9 +979,9 @@ def parse( keep_processing = False prefix = tuple(filter(lambda x: objstr.startswith(x), expected)) if len(prefix) == 1: - prefix = prefix[0] - keep_processing = getattr(TPMT_PUBLIC, f"_handle_{prefix}")( - objstr[len(prefix) :], templ + prefixstr = prefix[0] + keep_processing = getattr(TPMT_PUBLIC, f"_handle_{prefixstr}")( + objstr[len(prefixstr) :], templ ) else: raise ValueError( @@ -1010,9 +1014,9 @@ def from_pem( objectAttributes: Union[TPMA_OBJECT, int] = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - symmetric: TPMT_SYM_DEF_OBJECT = None, - scheme: TPMT_ASYM_SCHEME = None, - password: bytes = None, + symmetric: Optional[TPMT_SYM_DEF_OBJECT] = None, + scheme: Optional[TPMT_ASYM_SCHEME] = None, + password: Optional[bytes] = None, ) -> "TPMT_PUBLIC": """Decode the public part from standard key encodings. @@ -1193,6 +1197,8 @@ class TPM2B_MAX_NV_BUFFER(TPM2B_SIMPLE_OBJECT): class TPM2B_NV_PUBLIC(TPM_OBJECT): + nvPublic: "TPMS_NV_PUBLIC" + def get_name(self) -> TPM2B_NAME: """Get the TPM name of the NV public area. @@ -1221,6 +1227,8 @@ class TPM2B_PRIVATE_VENDOR_SPECIFIC(TPM2B_SIMPLE_OBJECT): class TPM2B_PUBLIC(TPM_OBJECT): + publicArea: TPMT_PUBLIC + @classmethod def from_pem( cls, @@ -1229,9 +1237,9 @@ def from_pem( objectAttributes: Union[TPMA_OBJECT, int] = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - symmetric: TPMT_SYM_DEF_OBJECT = None, - scheme: TPMT_ASYM_SCHEME = None, - password: bytes = None, + symmetric: Optional[TPMT_SYM_DEF_OBJECT] = None, + scheme: Optional[TPMT_ASYM_SCHEME] = None, + password: Optional[bytes] = None, ) -> "TPM2B_PUBLIC": """Decode the public part from standard key encodings. @@ -1341,10 +1349,12 @@ def get_name(self) -> TPM2B_NAME: @classmethod def parse( cls, - alg="rsa", - objectAttributes=TPMA_OBJECT.DEFAULT_TPM2_TOOLS_CREATE_ATTRS, - nameAlg="sha256", - authPolicy=None, + alg: str = "rsa", + objectAttributes: TPMA_OBJECT = TPMA_OBJECT( + TPMA_OBJECT.DEFAULT_TPM2_TOOLS_CREATE_ATTRS + ), + nameAlg: Union[TPM2_ALG, int, str] = "sha256", + authPolicy: Optional[bytes] = None, ) -> "TPM2B_PUBLIC": """Builds a TPM2B_PUBLIC from a tpm2-tools like specifier strings. @@ -1399,6 +1409,8 @@ class TPMT_KEYEDHASH_SCHEME(TPM_OBJECT): class TPM2B_SENSITIVE(TPM_OBJECT): + sensitiveArea: "TPMT_SENSITIVE" + @classmethod def from_pem( cls, data: bytes, password: Optional[bytes] = None @@ -1434,8 +1446,8 @@ def keyedhash_from_secret( objectAttributes: Union[TPMA_OBJECT, int] = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - scheme: TPMT_KEYEDHASH_SCHEME = None, - seed: bytes = None, + scheme: Optional[TPMT_KEYEDHASH_SCHEME] = None, + seed: Optional[bytes] = None, ) -> Tuple["TPM2B_SENSITIVE", TPM2B_PUBLIC]: """Generate the private and public part for a keyed hash object from a secret. @@ -1461,7 +1473,7 @@ def keyedhash_from_secret( (sens, pub) = TPM2B_SENSITIVE.keyedhash_from_secret(secret, scheme=scheme) """ sa, pa = TPMT_SENSITIVE.keyedhash_from_secret( - secret, nameAlg, objectAttributes, scheme, seed + secret, TPM2_ALG(nameAlg), TPMA_OBJECT(objectAttributes), scheme, seed ) priv = TPM2B_SENSITIVE(sensitiveArea=sa) pub = TPM2B_PUBLIC(publicArea=pa) @@ -1477,7 +1489,7 @@ def symcipher_from_secret( objectAttributes: Union[TPMA_OBJECT, int] = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - seed: bytes = None, + seed: Optional[bytes] = None, ) -> Tuple["TPM2B_SENSITIVE", TPM2B_PUBLIC]: """Generate the private and public part for a symcipher object from a secret. @@ -1499,13 +1511,18 @@ def symcipher_from_secret( sens, pub = TPM2B_SENSITIVE.symcipher_from_secret(secret) """ sa, pa = TPMT_SENSITIVE.symcipher_from_secret( - secret, algorithm, mode, nameAlg, objectAttributes, seed + secret, + TPM2_ALG(algorithm), + TPM2_ALG(mode), + TPM2_ALG(nameAlg), + TPMA_OBJECT(objectAttributes), + seed, ) priv = TPM2B_SENSITIVE(sensitiveArea=sa) pub = TPM2B_PUBLIC(publicArea=pa) return (priv, pub) - def to_pem(self, public: TPMT_PUBLIC, password=None) -> bytes: + def to_pem(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as PEM encoded ASN.1. Args: @@ -1550,7 +1567,7 @@ def to_der(self, public: TPMT_PUBLIC) -> bytes: return self.sensitiveArea.to_der(public) - def to_ssh(self, public: TPMT_PUBLIC, password: bytes = None) -> bytes: + def to_ssh(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as OPENSSH PEM format. Args: @@ -1723,9 +1740,9 @@ def parse(selections: str) -> "TPML_PCR_SELECTION": f"got {len(selectors)}" ) - selections = [TPMS_PCR_SELECTION.parse(x) for x in selectors] + parsed_selections = [TPMS_PCR_SELECTION.parse(x) for x in selectors] - return TPML_PCR_SELECTION(selections) + return TPML_PCR_SELECTION(parsed_selections) class TPML_TAGGED_PCR_PROPERTY(TPML_OBJECT): @@ -1824,7 +1841,11 @@ def from_tools(cls, data: bytes) -> "TPMS_CONTEXT": ctx.contextBlob, _ = TPM2B_CONTEXT_DATA.unmarshal(data[24:]) return ctx - def to_tools(self, session_type: TPM2_SE = None, auth_hash: TPM2_ALG = None): + def to_tools( + self, + session_type: Optional[TPM2_SE] = None, + auth_hash: Optional[TPM2_ALG] = None, + ) -> bytes: """Marshal the context into a tpm2-tools context blob. Args: @@ -1850,12 +1871,11 @@ def to_tools(self, session_type: TPM2_SE = None, auth_hash: TPM2_ALG = None): ) version = 1 - if session_type is not None: - version = 2 data = b"" - if version == 2: + if isinstance(session_type, TPM2_SE) and isinstance(auth_hash, TPM2_ALG): + version = 2 data = int(0xBADCC0DE).to_bytes(4, "big") + version.to_bytes(4, "big") data = data + session_type.to_bytes(1, "big") data = data + auth_hash.to_bytes(2, "big") @@ -1925,7 +1945,9 @@ class TPMS_PCR_SELECT(TPM_OBJECT): class TPMS_PCR_SELECTION(TPM_OBJECT): - def __init__(self, pcrs=None, **kwargs): + def __init__( + self, pcrs: Optional[Union[str, List[str], List[int]]] = None, **kwargs: Any + ): super().__init__(**kwargs) if not pcrs: @@ -1943,6 +1965,8 @@ def __init__(self, pcrs=None, **kwargs): return for pcr in pcrs: + if isinstance(pcr, str): + pcr = int(pcr) if pcr < 0 or pcr > lib.TPM2_PCR_LAST: raise ValueError(f"PCR Index out of range, got {pcr}") self._cdata.pcrSelect[pcr // 8] |= 1 << (pcr % 8) @@ -1987,6 +2011,7 @@ def parse(selection: str) -> "TPMS_PCR_SELECTION": except ValueError: halg = TPM2_ALG.parse(hunks[0]) + pcrs: Union[Iterable[int], str] if hunks[1] != "all": try: pcrs = [int(x.strip(), 0) for x in hunks[1].split(",")] @@ -2100,7 +2125,7 @@ class TPMU_PUBLIC_ID(TPM_OBJECT): class TPMT_SENSITIVE(TPM_OBJECT): @classmethod - def from_pem(cls, data, password: Optional[bytes] = None): + def from_pem(cls, data: bytes, password: Optional[bytes] = None) -> "Self": """Decode the private part from standard key encodings. Currently supports PEM, DER and SSH encoded private keys. @@ -2119,14 +2144,14 @@ def from_pem(cls, data, password: Optional[bytes] = None): @classmethod def keyedhash_from_secret( cls, - secret, - nameAlg=TPM2_ALG.SHA256, - objectAttributes=( + secret: bytes, + nameAlg: TPM2_ALG = TPM2_ALG(TPM2_ALG.SHA256), + objectAttributes: TPMA_OBJECT = TPMA_OBJECT( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), scheme: Optional[TPMT_KEYEDHASH_SCHEME] = None, seed: Optional[bytes] = None, - ): + ) -> Tuple["TPMT_SENSITIVE", TPMT_PUBLIC]: """Generate the private and public part for a keyed hash object from a secret. Args: @@ -2162,15 +2187,15 @@ def keyedhash_from_secret( @classmethod def symcipher_from_secret( cls, - secret, - algorithm=TPM2_ALG.AES, - mode=TPM2_ALG.CFB, - nameAlg=TPM2_ALG.SHA256, - objectAttributes=( + secret: bytes, + algorithm: TPM2_ALG = TPM2_ALG(TPM2_ALG.AES), + mode: TPM2_ALG = TPM2_ALG(TPM2_ALG.CFB), + nameAlg: TPM2_ALG = TPM2_ALG(TPM2_ALG.SHA256), + objectAttributes: TPMA_OBJECT = TPMA_OBJECT( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), seed: Optional[bytes] = None, - ): + ) -> Tuple["TPMT_SENSITIVE", TPMT_PUBLIC]: """ Generate the private and public part for a symcipher object from a secret. @@ -2213,11 +2238,11 @@ def symcipher_from_secret( def _serialize( self, - encoding: str, + encoding: serialization.Encoding, public: TPMT_PUBLIC, - format: str = serialization.PrivateFormat.TraditionalOpenSSL, - password: bytes = None, - ): + format: serialization.PrivateFormat = serialization.PrivateFormat.TraditionalOpenSSL, + password: Optional[bytes] = None, + ) -> bytes: k = private_to_key(self, public) enc_alg = ( @@ -2232,7 +2257,7 @@ def _serialize( return data - def to_pem(self, public: TPMT_PUBLIC, password: bytes = None): + def to_pem(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as PEM encoded ASN.1. public(TPMT_PUBLIC): The corresponding public key. @@ -2244,7 +2269,7 @@ def to_pem(self, public: TPMT_PUBLIC, password: bytes = None): return self._serialize(serialization.Encoding.PEM, public, password=password) - def to_der(self, public: TPMT_PUBLIC): + def to_der(self, public: TPMT_PUBLIC) -> bytes: """Encode the key as DER encoded ASN.1. public(TPMT_PUBLIC): The corresponding public key. @@ -2255,7 +2280,7 @@ def to_der(self, public: TPMT_PUBLIC): return self._serialize(serialization.Encoding.DER, public) - def to_ssh(self, public: TPMT_PUBLIC, password: bytes = None): + def to_ssh(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as SSH format. public(TPMT_PUBLIC): The corresponding public key. @@ -2318,7 +2343,9 @@ class TPMU_SIGNATURE(TPM_OBJECT): class TPMT_SIGNATURE(TPM_OBJECT): - def verify_signature(self, key, data): + def verify_signature( + self, key: Union[TPMT_PUBLIC, TPM2B_PUBLIC], data: bytes + ) -> None: """ Verify a TPM generated signature against a key. @@ -2331,7 +2358,7 @@ def verify_signature(self, key, data): """ _verify_signature(self, key, data) - def __bytes__(self): + def __bytes__(self) -> bytes: """Return the underlying bytes for the signature. For RSA and HMAC signatures return the signature bytes, for ECDSA return a ASN.1 encoded signature. diff --git a/src/tpm2_pytss/utils.py b/src/tpm2_pytss/utils.py index 1e1c8a2d..812a9e28 100644 --- a/src/tpm2_pytss/utils.py +++ b/src/tpm2_pytss/utils.py @@ -32,7 +32,7 @@ def make_credential( - public: TPM2B_PUBLIC, + public: Union[TPM2B_PUBLIC, TPMT_PUBLIC], credential: Union[bytes, TPM2B_DIGEST], name: Union[TPM2B_NAME, bytes], ) -> Tuple[TPM2B_ID_OBJECT, TPM2B_ENCRYPTED_SECRET]: