Skip to content

Commit

Permalink
types: fix type hints
Browse files Browse the repository at this point in the history
Signed-off-by: Erik Larsson <who+github@cnackers.org>
  • Loading branch information
whooo committed Jan 19, 2024
1 parent b1da22e commit cfe4af4
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 116 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ exclude = '''
[tool.mypy]
mypy_path = "mypy_stubs"
exclude = [
'src/tpm2_pytss/internal/templates.py',
'src/tpm2_pytss/encoding.py',
'src/tpm2_pytss/policy.py',
'src/tpm2_pytss/ESAPI.py',
'src/tpm2_pytss/FAPI.py',
'src/tpm2_pytss/types.py',
'src/tpm2_pytss/internal/crypto.py',
'src/tpm2_pytss/fapi_info.py',
]
5 changes: 3 additions & 2 deletions src/tpm2_pytss/_libtpm2_pytss/ffi.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Callable, Iterable, Any, Dict
from typing import Optional, Callable, Iterable, Any, Dict, Tuple

error: type[Exception]
class CData:
Expand All @@ -13,7 +13,7 @@ class CType:
kind: str
cname: str
item: "CType"
fields: Iterable[str]
fields: Iterable[Tuple[str, Any]]

NULL: CData
def gc(cdata: CData, destructor: Callable[[CData], None], size: int = 0)-> CData: ...
Expand All @@ -28,3 +28,4 @@ def from_handle(handle: CData) -> Any: ...
def new_handle(python_object: Any) -> CData: ...
def cast(ctype: str, value: CData) -> CData: ...
def memmove(dest: CData | bytes, src: CData | bytes, n: int) -> None: ...
def addressof(cdata: CData, *fields_or_indexes: str | int) -> CData: ...
14 changes: 10 additions & 4 deletions src/tpm2_pytss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
SupportsIndex,
)

try:
from typing import Self
except ImportError:
# assume mypy is running on python 3.11+
pass


if TYPE_CHECKING:
from .ESAPI import ESAPI
Expand All @@ -33,14 +39,14 @@ class TPM_FRIENDLY_INT(int):
_FIXUP_MAP: Dict[str, str] = {}

@classmethod
def parse(cls, value: str) -> int:
def parse(cls, value: str) -> "Self":
# If it's a string initializer value, see if it matches anything in the list
if isinstance(value, str):
try:
x = _CLASS_INT_ATTRS_from_string(cls, value, cls._FIXUP_MAP)
if not isinstance(x, int):
raise KeyError(f'Expected int got: "{type(x)}"')
return x
return cls(x)
except KeyError:
raise ValueError(
f'Could not convert friendly name to value, got: "{value}"'
Expand Down Expand Up @@ -305,7 +311,7 @@ class TPMA_FRIENDLY_INTLIST(TPM_FRIENDLY_INT):
_MASKS: Tuple[Tuple[int, int, str], ...] = tuple()

@classmethod
def parse(cls, value: str) -> int:
def parse(cls, value: str) -> "Self":
""" Converts a string of | separated constant values into it's integer value.
Given a pipe "|" separated list of string constant values that represent the
Expand Down Expand Up @@ -361,7 +367,7 @@ def parse(cls, value: str) -> int:
f'Could not convert friendly name to value, got: "{k}"'
)

return intvalue
return cls(intvalue)

def __str__(self) -> str:
"""Given a constant, return the string bitwise representation.
Expand Down
26 changes: 17 additions & 9 deletions src/tpm2_pytss/internal/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from cryptography.hazmat.primitives.ciphers import modes, Cipher, CipherAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature
from typing import Tuple, Type, Union
from typing import Tuple, Type, Union, Optional
import secrets
import sys

Expand Down Expand Up @@ -140,7 +140,9 @@ def key_from_encoding(data, password=None):
raise ValueError("Unsupported key format")


def _public_from_encoding(data, obj, password=None):
def _public_from_encoding(
data: bytes, obj: "TPMT_PUBLIC", password: Optional[bytes] = None
) -> None:
key = key_from_encoding(data, password)
nums = key.public_numbers()
if isinstance(key, rsa.RSAPublicKey):
Expand Down Expand Up @@ -183,7 +185,9 @@ def private_key_from_encoding(data, password=None):
raise ValueError("Unsupported key format")


def _private_from_encoding(data, obj, password=None):
def _private_from_encoding(
data: bytes, obj: "TPMT_SENSITIVE", password: Optional[bytes] = None
):
key = private_key_from_encoding(data, password)
nums = key.private_numbers()
if isinstance(key, rsa.RSAPrivateKey):
Expand Down Expand Up @@ -274,7 +278,9 @@ def _generate_d(p, q, e, n):
return d


def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC"):
def private_to_key(
private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC"
) -> Union[rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey]:
key = None
if private.sensitiveType == TPM2_ALG.RSA:

Expand Down Expand Up @@ -310,7 +316,7 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC")
return key


def _public_to_pem(obj, encoding="pem"):
def _public_to_pem(obj: "TPMT_PUBLIC", encoding: str = "pem") -> bytes:
encoding = encoding.lower()
key = public_to_key(obj)
if encoding == "pem":
Expand All @@ -323,7 +329,7 @@ def _public_to_pem(obj, encoding="pem"):
raise ValueError(f"unsupported encoding: {encoding}")


def _getname(obj):
def _getname(obj: "TPMT_PUBLIC") -> bytes:
dt = _get_digest(obj.nameAlg)
if dt is None:
raise ValueError(f"unsupported digest algorithm: {obj.nameAlg}")
Expand Down Expand Up @@ -391,7 +397,7 @@ def _symdef_to_crypt(symdef: "TPMT_SYM_DEF"):
return (alg, mode, bits)


def _calculate_sym_unique(nameAlg, secret, seed):
def _calculate_sym_unique(nameAlg: TPM2_ALG, secret: bytes, seed: bytes) -> bytes:
dt = _get_digest(nameAlg)
if dt is None:
raise ValueError(f"unsupported digest algorithm: {nameAlg}")
Expand All @@ -409,7 +415,7 @@ def _get_digest_size(alg: TPM2_ALG) -> int:
return dt.digest_size


def _get_signature_bytes(sig):
def _get_signature_bytes(sig: "TPMT_SIGNATURE") -> bytes:
if sig.sigAlg in (TPM2_ALG.RSAPSS, TPM2_ALG.RSASSA):
rb = bytes(sig.signature.rsapss.sig)
elif sig.sigAlg == TPM2_ALG.ECDSA:
Expand Down Expand Up @@ -476,7 +482,9 @@ def verify_signature_hmac(signature, key, data):
h.verify(sig)


def _verify_signature(signature, key, data):
def _verify_signature(
signature: "TPMT_SIGNATURE", key: "TPMT_PUBLIC", data: bytes
) -> None:
if hasattr(key, "publicArea"):
key = key.publicArea
kt = getattr(key, "type", None)
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _cpointer_to_ctype(x: ffi.CData) -> ffi.CType:


def _fixup_cdata_kwargs(
this: Any, _cdata: ffi.CData, kwargs: Dict[str, Any]
this: Any, _cdata: Any, kwargs: Dict[str, Any]
) -> Tuple[ffi.CData, Dict[str, Any]]:

# folks may call this routine without a keyword argument which means it may
Expand Down
Loading

0 comments on commit cfe4af4

Please sign in to comment.