diff --git a/tpm2/constants.go b/tpm2/constants.go index 1b4a26b5..e72cd096 100644 --- a/tpm2/constants.go +++ b/tpm2/constants.go @@ -400,7 +400,9 @@ const ( TagAttestCertify tpmutil.Tag = 0x8017 TagAttestQuote tpmutil.Tag = 0x8018 TagAttestCreation tpmutil.Tag = 0x801a + TagAuthSecret tpmutil.Tag = 0x8023 TagHashCheck tpmutil.Tag = 0x8024 + TagAuthSigned tpmutil.Tag = 0x8025 ) // StartupType instructs the TPM on how to handle its state during Shutdown or @@ -470,6 +472,7 @@ const ( CmdSequenceUpdate tpmutil.Command = 0x0000015C CmdSign tpmutil.Command = 0x0000015D CmdUnseal tpmutil.Command = 0x0000015E + CmdPolicySigned tpmutil.Command = 0x00000160 CmdContextLoad tpmutil.Command = 0x00000161 CmdContextSave tpmutil.Command = 0x00000162 CmdECDHKeyGen tpmutil.Command = 0x00000163 diff --git a/tpm2/structures.go b/tpm2/structures.go index da81dc77..6df9f7f0 100644 --- a/tpm2/structures.go +++ b/tpm2/structures.go @@ -573,6 +573,27 @@ type Signature struct { ECC *SignatureECC } +// Encode serializes a Signature structure in TPM wire format. +func (s Signature) Encode() ([]byte, error) { + head, err := tpmutil.Pack(s.Alg) + if err != nil { + return nil, fmt.Errorf("encoding Alg: %v", err) + } + var signature []byte + switch s.Alg { + case AlgRSASSA, AlgRSAPSS: + if signature, err = tpmutil.Pack(s.RSA); err != nil { + return nil, fmt.Errorf("encoding RSA: %v", err) + } + case AlgECDSA: + signature, err = tpmutil.Pack(s.ECC.HashAlg, tpmutil.U16Bytes(s.ECC.R.Bytes()), tpmutil.U16Bytes(s.ECC.S.Bytes())) + if err != nil { + return nil, fmt.Errorf("encoding ECC: %v", err) + } + } + return concat(head, signature) +} + // DecodeSignature decodes a serialized TPMT_SIGNATURE structure. func DecodeSignature(in *bytes.Buffer) (*Signature, error) { var sig Signature diff --git a/tpm2/test/tpm2_test.go b/tpm2/test/tpm2_test.go index 3a845a52..4ac45864 100644 --- a/tpm2/test/tpm2_test.go +++ b/tpm2/test/tpm2_test.go @@ -23,10 +23,13 @@ import ( "crypto/rsa" "crypto/sha1" "crypto/sha256" + "encoding/binary" "flag" "fmt" "hash" "io" + "math" + "math/big" "reflect" "strings" "testing" @@ -75,8 +78,34 @@ var ( ExponentRaw: 1<<16 + 1, }, } - defaultPassword = "\x01\x02\x03\x04" - emptyPassword = "" + defaultPassword = "\x01\x02\x03\x04" + emptyPassword = "" + defaultRsaSignerParams = Public{ + Type: AlgRSA, + NameAlg: AlgSHA256, + Attributes: FlagSign | FlagSensitiveDataOrigin | FlagUserWithAuth, + RSAParameters: &RSAParams{ + Sign: &SigScheme{ + Alg: AlgRSASSA, + Hash: AlgSHA256, + }, + KeyBits: 2048, + }, + } + defaultEccSignerParams = Public{ + Type: AlgECC, + NameAlg: AlgSHA256, + Attributes: FlagSign | FlagSensitiveDataOrigin | FlagUserWithAuth, + ECCParameters: &ECCParams{ + Sign: &SigScheme{ + Alg: AlgECDSA, + Hash: AlgSHA256, + }, + CurveID: CurveNISTP256, + }, + } + nullTicketSigned = Ticket{Type: TagAuthSigned, Hierarchy: HandleNull} + nullTicketSecret = Ticket{Type: TagAuthSecret, Hierarchy: HandleNull} ) func min(a, b int) int { @@ -1179,6 +1208,69 @@ func TestEncodeDecodePublicDefaultRSAExponent(t *testing.T) { } } +func TestEncodeDecodeSignature(t *testing.T) { + randRSASig := func() []byte { + // Key size 2048 bits + var size uint16 = 256 + sizeU16 := make([]byte, 2) + binary.BigEndian.PutUint16(sizeU16, size) + key := make([]byte, size) + rand.Read(key) + return append(sizeU16, key...) + } + + run := func(t *testing.T, s Signature) { + e, err := s.Encode() + if err != nil { + t.Fatalf("Signature{%+v}.Encode() returned error: %v", s, err) + } + d, err := DecodeSignature(bytes.NewBuffer(e)) + if err != nil { + t.Fatalf("DecodeSignature{%v} returned error: %v", e, err) + } + if !reflect.DeepEqual(s, *d) { + t.Errorf("got decoded value:\n%v\nwant:\n%v", d, s) + } + } + t.Run("RSASSA", func(t *testing.T) { + run(t, Signature{ + Alg: AlgRSASSA, + RSA: &SignatureRSA{ + HashAlg: AlgSHA256, + Signature: randRSASig(), + }, + }) + }) + t.Run("RSAPSS", func(t *testing.T) { + run(t, Signature{ + Alg: AlgRSAPSS, + RSA: &SignatureRSA{ + HashAlg: AlgSHA256, + Signature: randRSASig(), + }, + }) + }) + t.Run("ECDSA", func(t *testing.T) { + // Key size 256 bits + size := 32 + randBytes := make([]byte, size) + rand.Read(randBytes) + r := big.NewInt(0).SetBytes(randBytes) + + rand.Read(randBytes) + s := big.NewInt(0).SetBytes(randBytes) + + run(t, Signature{ + Alg: AlgECDSA, + ECC: &SignatureECC{ + HashAlg: AlgSHA256, + R: r, + S: s, + }, + }) + }) +} + func TestCreateKeyWithSensitive(t *testing.T) { rw := openTPM(t) defer rw.Close() @@ -1426,15 +1518,102 @@ func TestPolicySecret(t *testing.T) { rw := openTPM(t) defer rw.Close() + expirations := []int32{math.MinInt32, math.MinInt32 + 1, -1, 0, 1, math.MaxInt32} + for _, expiration := range expirations { + t.Run(t.Name()+fmt.Sprint(expiration), func(t *testing.T) { + _, tkt := testPolicySecret(t, rw, expiration) + // Part 3: policyTicket is produced if the command succeeds and expiration in + // the command was non-zero. + // If expiration is non-negative, a NULL Ticket is returned. + if expiration < 0 && len(tkt.Digest) == 0 { + t.Fatalf("Got empty ticket digest, expected ticket with auth expiry") + } else if expiration >= 0 && !reflect.DeepEqual(*tkt, nullTicketSecret) { + t.Fatalf("Got ticket with non-empty digest, expected NULL ticket") + } + }) + } +} + +func testPolicySecret(t *testing.T, rw io.ReadWriter, expiration int32) ([]byte, *Ticket) { sessHandle, _, err := StartAuthSession(rw, HandleNull, HandleNull, make([]byte, 16), nil, SessionPolicy, AlgNull, AlgSHA256) if err != nil { t.Fatalf("StartAuthSession() failed: %v", err) } defer FlushContext(rw, sessHandle) - if _, err := PolicySecret(rw, HandleEndorsement, AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession}, sessHandle, nil, nil, nil, 0); err != nil { + timeout, tkt, err := PolicySecret(rw, HandleEndorsement, AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession}, sessHandle, nil, nil, nil, expiration) + if err != nil { t.Fatalf("PolicySecret() failed: %v", err) } + return timeout, tkt +} + +func TestPolicySigned(t *testing.T) { + rw := openTPM(t) + defer rw.Close() + + signers := map[string]Public{ + "RSA": defaultRsaSignerParams, + "ECC": defaultEccSignerParams, + } + + expirations := []int32{math.MinInt32, math.MinInt32 + 1, -1, 0, 1, math.MaxInt32} + for _, expiration := range expirations { + for signerType, params := range signers { + t.Run(t.Name()+signerType+fmt.Sprint(expiration), func(t *testing.T) { + _, tkt := testPolicySigned(t, rw, expiration, params) + // Part 3: policyTicket is produced if the command succeeds and expiration in + // the command was non-zero. + // If expiration is non-negative, a NULL Ticket is returned. + if expiration < 0 && len(tkt.Digest) == 0 { + t.Fatalf("Got empty ticket digest, expected ticket with auth expiry") + } else if expiration >= 0 && !reflect.DeepEqual(*tkt, nullTicketSigned) { + t.Fatalf("Got ticket with non-empty digest, expected NULL ticket") + } + }) + } + } +} + +func testPolicySigned(t *testing.T, rw io.ReadWriter, expiration int32, signerParams Public) ([]byte, *Ticket) { + handle, _, err := CreatePrimary(rw, HandleOwner, PCRSelection{}, emptyPassword, emptyPassword, signerParams) + if err != nil { + t.Fatalf("CreatePrimary() failed: %s", err) + } + defer FlushContext(rw, handle) + + sessHandle, nonce, err := StartAuthSession(rw, HandleNull, HandleNull, make([]byte, 16), nil, SessionPolicy, AlgNull, AlgSHA256) + if err != nil { + t.Fatalf("StartAuthSession() failed: %v", err) + } + defer FlushContext(rw, sessHandle) + + // Sign the hash of the command parameters, as described in the TPM 2.0 spec, Part 3, Section 23.3. + // We only use expiration here. + expBytes := make([]byte, 4) + binary.BigEndian.PutUint32(expBytes, uint32(expiration)) + + // TPM2.0 spec, Revision 1.38, Part 3 nonce must be present if expiration is non-zero. + // aHash ≔ HauthAlg(nonceTPM || expiration || cpHashA || policyRef) + toDigest := append(nonce, expBytes...) + + digest := sha256.Sum256(toDigest) + + sig, err := Sign(rw, handle, emptyPassword, digest[:], nil, nil) + if err != nil { + t.Fatalf("Sign failed: %s", err) + } + + signature, err := sig.Encode() + if err != nil { + t.Fatalf("Encode() failed: %v", err) + } + + timeout, tkt, err := PolicySigned(rw, handle, sessHandle, nonce, nil, nil, expiration, signature) + if err != nil { + t.Fatalf("PolicySigned() failed: %v", err) + } + return timeout, tkt } func TestQuote(t *testing.T) { diff --git a/tpm2/tpm2.go b/tpm2/tpm2.go index 30579116..c09b810b 100644 --- a/tpm2/tpm2.go +++ b/tpm2/tpm2.go @@ -747,38 +747,71 @@ func encodePolicySecret(entityHandle tpmutil.Handle, entityAuth AuthCommand, pol return concat(handles, auth, params) } -func decodePolicySecret(in []byte) (*Ticket, error) { +func decodePolicySecret(in []byte) ([]byte, *Ticket, error) { buf := bytes.NewBuffer(in) var paramSize uint32 var timeout tpmutil.U16Bytes if err := tpmutil.UnpackBuf(buf, ¶mSize, &timeout); err != nil { - return nil, fmt.Errorf("decoding timeout: %v", err) + return nil, nil, fmt.Errorf("decoding timeout: %v", err) } var t Ticket if err := tpmutil.UnpackBuf(buf, &t); err != nil { - return nil, fmt.Errorf("decoding ticket: %v", err) + return nil, nil, fmt.Errorf("decoding ticket: %v", err) } - return &t, nil + return timeout, &t, nil } // PolicySecret sets a secret authorization requirement on the provided entity. -// If expiry is non-zero, the authorization is valid for expiry seconds. -func PolicySecret(rw io.ReadWriter, entityHandle tpmutil.Handle, entityAuth AuthCommand, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef []byte, expiry int32) (*Ticket, error) { +func PolicySecret(rw io.ReadWriter, entityHandle tpmutil.Handle, entityAuth AuthCommand, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef []byte, expiry int32) ([]byte, *Ticket, error) { Cmd, err := encodePolicySecret(entityHandle, entityAuth, policyHandle, policyNonce, cpHash, policyRef, expiry) if err != nil { - return nil, err + return nil, nil, err } resp, err := runCommand(rw, TagSessions, CmdPolicySecret, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + return decodePolicySecret(resp) +} + +func encodePolicySigned(validationKeyHandle tpmutil.Handle, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef tpmutil.U16Bytes, expiry int32, auth []byte) ([]byte, error) { + handles, err := tpmutil.Pack(validationKeyHandle, policyHandle) if err != nil { return nil, err } + params, err := tpmutil.Pack(policyNonce, cpHash, policyRef, expiry, auth) + if err != nil { + return nil, err + } + return concat(handles, params) +} + +func decodePolicySigned(in []byte) ([]byte, *Ticket, error) { + buf := bytes.NewBuffer(in) - // Tickets are only provided if expiry is set. - if expiry != 0 { - return decodePolicySecret(resp) + var timeout tpmutil.U16Bytes + if err := tpmutil.UnpackBuf(buf, &timeout); err != nil { + return nil, nil, fmt.Errorf("decoding timeout: %v", err) + } + var t Ticket + if err := tpmutil.UnpackBuf(buf, &t); err != nil { + return nil, nil, fmt.Errorf("decoding ticket: %v", err) + } + return timeout, &t, nil +} + +// PolicySigned sets a signed authorization requirement on the provided policy. +func PolicySigned(rw io.ReadWriter, validationKeyHandle tpmutil.Handle, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef []byte, expiry int32, signedAuth []byte) ([]byte, *Ticket, error) { + Cmd, err := encodePolicySigned(validationKeyHandle, policyHandle, policyNonce, cpHash, policyRef, expiry, signedAuth) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagNoSessions, CmdPolicySigned, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err } - return nil, nil + return decodePolicySigned(resp) } func encodePolicyPCR(session tpmutil.Handle, expectedDigest tpmutil.U16Bytes, sel PCRSelection) ([]byte, error) {