Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tpm2: Add HashToAlgorithm #226

Merged
merged 1 commit into from
Dec 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions tpm2/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@ import (
"github.com/google/go-tpm/tpmutil"
)

var hashMapping = map[Algorithm]crypto.Hash{
AlgSHA1: crypto.SHA1,
AlgSHA256: crypto.SHA256,
AlgSHA384: crypto.SHA384,
AlgSHA512: crypto.SHA512,
var hashInfo = []struct {
alg Algorithm
hash crypto.Hash
}{
{AlgSHA1, crypto.SHA1},
{AlgSHA256, crypto.SHA256},
{AlgSHA384, crypto.SHA384},
{AlgSHA512, crypto.SHA512},
{AlgSHA3_256, crypto.SHA3_256},
{AlgSHA3_384, crypto.SHA3_384},
{AlgSHA3_512, crypto.SHA3_512},
}

// MAX_DIGEST_BUFFER is the maximum size of []byte request or response fields.
Expand All @@ -43,6 +49,16 @@ const maxDigestBuffer = 1024
// Algorithm represents a TPM_ALG_ID value.
type Algorithm uint16

// HashToAlgorithm looks up the TPM2 algorithm corresponding to the provided crypto.Hash
func HashToAlgorithm(hash crypto.Hash) (Algorithm, error) {
for _, info := range hashInfo {
if info.hash == hash {
return info.alg, nil
}
}
return AlgUnknown, fmt.Errorf("go hash algorithm #%d has no TPM2 algorithm", hash)
}

// IsNull returns true if a is AlgNull or zero (unset).
func (a Algorithm) IsNull() bool {
return a == AlgNull || a == AlgUnknown
Expand All @@ -61,14 +77,15 @@ func (a Algorithm) UsesHash() bool {
// Hash returns a crypto.Hash based on the given TPM_ALG_ID.
// An error is returned if the given algorithm is not a hash algorithm or is not available.
func (a Algorithm) Hash() (crypto.Hash, error) {
hash, ok := hashMapping[a]
if !ok {
return crypto.Hash(0), fmt.Errorf("hash algorithm not supported: 0x%x", a)
}
if !hash.Available() {
return crypto.Hash(0), fmt.Errorf("go hash algorithm #%d not available", hash)
for _, info := range hashInfo {
if info.alg == a {
if !info.hash.Available() {
return crypto.Hash(0), fmt.Errorf("go hash algorithm #%d not available", info.hash)
}
return info.hash, nil
}
}
return hash, nil
return crypto.Hash(0), fmt.Errorf("hash algorithm not supported: 0x%x", a)
}

// Supported Algorithms.
Expand All @@ -94,6 +111,9 @@ const (
AlgKDF2 Algorithm = 0x0021
AlgECC Algorithm = 0x0023
AlgSymCipher Algorithm = 0x0025
AlgSHA3_256 Algorithm = 0x0027
AlgSHA3_384 Algorithm = 0x0028
AlgSHA3_512 Algorithm = 0x0029
AlgCTR Algorithm = 0x0040
AlgOFB Algorithm = 0x0041
AlgCBC Algorithm = 0x0042
Expand Down
6 changes: 3 additions & 3 deletions tpm2/structures.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,9 @@ func decodeHashValue(in *bytes.Buffer) (*HashValue, error) {
if err := tpmutil.UnpackBuf(in, &hv.Alg); err != nil {
return nil, fmt.Errorf("decoding Alg: %v", err)
}
hfn, ok := hashMapping[hv.Alg]
if !ok {
return nil, fmt.Errorf("hash algorithm not supported: 0x%x", hv.Alg)
hfn, err := hv.Alg.Hash()
if err != nil {
return nil, err
}
hv.Value = make(tpmutil.U16Bytes, hfn.Size())
if _, err := in.Read(hv.Value); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion tpm2/tpm2.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func encodeTPMLPCRSelection(sel ...PCRSelection) ([]byte, error) {
// s[i].PCRs parameter is indexes of PCRs, convert that to set bits.
for _, n := range s.PCRs {
if n >= 8*sizeOfPCRSelect {
return nil, fmt.Errorf("PCR index %d is out of range (exceeds maximum value %d)", n, 8*sizeOfPCRSelect-1)
return nil, fmt.Errorf("PCR index %d is out of range (exceeds maximum value %d)", n, 8*sizeOfPCRSelect-1)
}
byteNum := n / 8
bytePos := byte(1 << byte(n%8))
Expand Down