diff --git a/tpm2/constants.go b/tpm2/constants.go index 4276a176..716ea471 100644 --- a/tpm2/constants.go +++ b/tpm2/constants.go @@ -28,11 +28,17 @@ import ( "github.com/google/go-tpm/tpmutil" ) -var hashMapping = map[Algorithm]crypto.Hash{ - AlgSHA1: crypto.SHA1, - AlgSHA256: crypto.SHA256, - AlgSHA384: crypto.SHA384, - AlgSHA512: crypto.SHA512, +var hashInfo = []struct { + alg Algorithm + hash crypto.Hash +}{ + {AlgSHA1, crypto.SHA1}, + {AlgSHA256, crypto.SHA256}, + {AlgSHA384, crypto.SHA384}, + {AlgSHA512, crypto.SHA512}, + {AlgSHA3_256, crypto.SHA3_256}, + {AlgSHA3_384, crypto.SHA3_384}, + {AlgSHA3_512, crypto.SHA3_512}, } // MAX_DIGEST_BUFFER is the maximum size of []byte request or response fields. @@ -43,6 +49,16 @@ const maxDigestBuffer = 1024 // Algorithm represents a TPM_ALG_ID value. type Algorithm uint16 +// HashToAlgorithm looks up the TPM2 algorithm corresponding to the provided crypto.Hash +func HashToAlgorithm(hash crypto.Hash) (Algorithm, error) { + for _, info := range hashInfo { + if info.hash == hash { + return info.alg, nil + } + } + return AlgUnknown, fmt.Errorf("go hash algorithm #%d has no TPM2 algorithm", hash) +} + // IsNull returns true if a is AlgNull or zero (unset). func (a Algorithm) IsNull() bool { return a == AlgNull || a == AlgUnknown @@ -61,14 +77,15 @@ func (a Algorithm) UsesHash() bool { // Hash returns a crypto.Hash based on the given TPM_ALG_ID. // An error is returned if the given algorithm is not a hash algorithm or is not available. func (a Algorithm) Hash() (crypto.Hash, error) { - hash, ok := hashMapping[a] - if !ok { - return crypto.Hash(0), fmt.Errorf("hash algorithm not supported: 0x%x", a) - } - if !hash.Available() { - return crypto.Hash(0), fmt.Errorf("go hash algorithm #%d not available", hash) + for _, info := range hashInfo { + if info.alg == a { + if !info.hash.Available() { + return crypto.Hash(0), fmt.Errorf("go hash algorithm #%d not available", info.hash) + } + return info.hash, nil + } } - return hash, nil + return crypto.Hash(0), fmt.Errorf("hash algorithm not supported: 0x%x", a) } // Supported Algorithms. @@ -94,6 +111,9 @@ const ( AlgKDF2 Algorithm = 0x0021 AlgECC Algorithm = 0x0023 AlgSymCipher Algorithm = 0x0025 + AlgSHA3_256 Algorithm = 0x0027 + AlgSHA3_384 Algorithm = 0x0028 + AlgSHA3_512 Algorithm = 0x0029 AlgCTR Algorithm = 0x0040 AlgOFB Algorithm = 0x0041 AlgCBC Algorithm = 0x0042 diff --git a/tpm2/structures.go b/tpm2/structures.go index c90b6837..4040a23a 100644 --- a/tpm2/structures.go +++ b/tpm2/structures.go @@ -961,9 +961,9 @@ func decodeHashValue(in *bytes.Buffer) (*HashValue, error) { if err := tpmutil.UnpackBuf(in, &hv.Alg); err != nil { return nil, fmt.Errorf("decoding Alg: %v", err) } - hfn, ok := hashMapping[hv.Alg] - if !ok { - return nil, fmt.Errorf("hash algorithm not supported: 0x%x", hv.Alg) + hfn, err := hv.Alg.Hash() + if err != nil { + return nil, err } hv.Value = make(tpmutil.U16Bytes, hfn.Size()) if _, err := in.Read(hv.Value); err != nil { diff --git a/tpm2/tpm2.go b/tpm2/tpm2.go index edc970b5..00926001 100644 --- a/tpm2/tpm2.go +++ b/tpm2/tpm2.go @@ -74,7 +74,7 @@ func encodeTPMLPCRSelection(sel ...PCRSelection) ([]byte, error) { // s[i].PCRs parameter is indexes of PCRs, convert that to set bits. for _, n := range s.PCRs { if n >= 8*sizeOfPCRSelect { - return nil, fmt.Errorf("PCR index %d is out of range (exceeds maximum value %d)", n, 8*sizeOfPCRSelect-1) + return nil, fmt.Errorf("PCR index %d is out of range (exceeds maximum value %d)", n, 8*sizeOfPCRSelect-1) } byteNum := n / 8 bytePos := byte(1 << byte(n%8))