diff --git a/README.md b/README.md index 3063ecd..00434b2 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,8 @@ Server options: Voucher guid to extend for resale -resale-key path The path to a PEM-encoded x.509 public key for the next owner + -reuse-cred + Perform the Credential Reuse Protocol in TO2 -rv-bypass Skip TO1 -rv-delay seconds @@ -171,6 +173,10 @@ client error: transfer of ownership not successful exit status 2 ``` +To test repeatedly without the device credential changing, run the server with the `-reuse-cred` flag to enable the [Credential Reuse Protocol][Credential Reuse Protocol]. + +[Credential Reuse Protocol]: https://fidoalliance.org/specs/FDO/FIDO-Device-Onboard-PS-v1.1-20220419/FIDO-Device-Onboard-PS-v1.1-20220419.html#credreuse + ### Testing RV Blob Registration First, start a server in a separate console. @@ -230,7 +236,9 @@ $ go run ./examples/cmd server -http 127.0.0.1:9999 -db ./test.db external: 127.0.0.1:9999 ``` -Then DI, followed by TO1 and TO2 may be run. To use ASYMKEX* key exchange, the device key must be RSA. To specify the device key type, use `-di-key` when running DI. +Then DI, followed by TO1 and TO2 may be run. + +Because in the example the device key type and owner key type will always match and to use ASYMKEX\* key exchange the owner key must be RSA, the device key must also be RSA. To specify the device key type, use `-di-key` when running DI. ```console $ go run ./examples/cmd client -di http://127.0.0.1:9999 -di-key rsa2048 @@ -255,6 +263,7 @@ Next, initialize the device and perform transfer of ownership. ```console $ go run ./examples/cmd client -di http://127.0.0.1:9999 $ go run ./examples/cmd client +Success $ go run ./examples/cmd client -print blobcred[ ... diff --git a/cose/sign_test.go b/cose/sign_test.go index ebfbf3a..5365480 100644 --- a/cose/sign_test.go +++ b/cose/sign_test.go @@ -12,8 +12,10 @@ import ( "math/big" "testing" + "github.com/fido-device-onboard/go-fdo" "github.com/fido-device-onboard/go-fdo/cbor" "github.com/fido-device-onboard/go-fdo/cose" + "github.com/fido-device-onboard/go-fdo/protocol" ) func TestSignAndVerify(t *testing.T) { @@ -38,7 +40,7 @@ func TestSignAndVerify(t *testing.T) { cose.Label{Int64: 4}: []byte("11"), }, }, - Payload: cbor.NewByteWrap[[]byte]([]byte("This is the content.")), + Payload: cbor.NewByteWrap([]byte("This is the content.")), } externalAAD, _ := hex.DecodeString("11aa22bb33cc44dd55006699") @@ -73,7 +75,7 @@ func TestSignAndVerify(t *testing.T) { } s1 := cose.Sign1[[]byte, []byte]{ - Payload: cbor.NewByteWrap[[]byte]([]byte("This is the content.")), + Payload: cbor.NewByteWrap([]byte("This is the content.")), } if err := s1.Sign(key384, nil, nil, nil); err != nil { t.Fatalf("error signing: %v", err) @@ -101,3 +103,43 @@ func TestSignAndVerify(t *testing.T) { } }) } + +// Request 255: [101, 61, "cryptographic verification failed: TO2.ProveOVHdr payload signature verification failed", 1727891427, null] +func TestSomethingThatFailedSignatureVerificationOnceInCIForUnknownReasons(t *testing.T) { + // 18([h'a101390100', {256: h'b2d33efa8e5cea10ea364043bc381bc3', 257: [1, 1, h'30820122300d06092a864886f70d01010105000382010f003082010a0282010100d3e882bc85ebe378b5c043f5f51135f39531c5708fb0a455fb680eff25070502ad3f333de6e1bbaac4c133107f125c8056047d4c77dbdde178eb92b43432f249f7ca080be18b04662d03f4d28873b9569094d50b036d4b8b65eee101ec54b2f834a45e4e297464dc231c74e643ec99fa84b49363d3aa7bb5e73aa96b0c74c886c132f997aea110b4f5b89451a52bfa651d50fcabfde7fb570a99f744f849afdc27732f5bdee138ea2d2ae0e95bc010eae36c9eee7286cc615844d7a84946d4b8c6653563004b528771734f30bff2af9c699d9cf23477663c231f936670aa64bbdd4ab4367a62ab34a5dfb44ca03d4ecc74c28e33803b3ca4a04c0271bbe6d1ad0203010001']}, h'8859017186186550ed5c309ac00d13f29b22912649fac98e806b746573745f64657669636583010159012630820122300d06092a864886f70d01010105000382010f003082010a0282010100d3e882bc85ebe378b5c043f5f51135f39531c5708fb0a455fb680eff25070502ad3f333de6e1bbaac4c133107f125c8056047d4c77dbdde178eb92b43432f249f7ca080be18b04662d03f4d28873b9569094d50b036d4b8b65eee101ec54b2f834a45e4e297464dc231c74e643ec99fa84b49363d3aa7bb5e73aa96b0c74c886c132f997aea110b4f5b89451a52bfa651d50fcabfde7fb570a99f744f849afdc27732f5bdee138ea2d2ae0e95bc010eae36c9eee7286cc615844d7a84946d4b8c6653563004b528771734f30bff2af9c699d9cf23477663c231f936670aa64bbdd4ab4367a62ab34a5dfb44ca03d4ecc74c28e33803b3ca4a04c0271bbe6d1ad0203010001822f58209f17599e0a16082abaf313f448add12acd14c981a3dfa786d240c842113d974000820558206552c303917e65450b187727bb6df531c819421e7148790c045b52dcc1dfbc9d509b5473c6ed93fd29c5507fafc0b5824082390100405820829da9590248b6b8f9b559bb2b5bac3ce88984963fc0fae842a5b5f07c3b15e5822f58206ebb2e1467c7162bb953c36092ad805207a8474ccd18b06267198b184c34eaf419ffff', h'881e74d84932c8986341f8423801f43aab92a813f53ee9902cc5d2ebf48f4ea23ca84fe52f709b1c86b6a17295b605b5d5d1e876069cc0bb7fd9115f16f6e7aceb43c4997053161ca1117110e24ea83afb9bf2092dc1e921dac0ecd533fd33b1e6f6e48a04d085d8a3b9552c6a447f39249509de11d2a52f09b13736d0fee2afe63af26ac6a56b615ed7f937b6b087a3d1105c0e07326cd76c8974e12f75c6dc91b18ec08cdded88b9b32b803becb37757210682c9d975be507c8364ad4ae99e5a903db04ab5f94baa039168d070f641f3685437f32972cb79d4f92fcdc47045d9cdcb9385de1dce1421d3cbf09cd73d34775775e4300c7454ada07c92d38613']) + data, _ := hex.DecodeString("D28445A101390100A219010050B2D33EFA8E5CEA10EA364043BC381BC319010183010159012630820122300D06092A864886F70D01010105000382010F003082010A0282010100D3E882BC85EBE378B5C043F5F51135F39531C5708FB0A455FB680EFF25070502AD3F333DE6E1BBAAC4C133107F125C8056047D4C77DBDDE178EB92B43432F249F7CA080BE18B04662D03F4D28873B9569094D50B036D4B8B65EEE101EC54B2F834A45E4E297464DC231C74E643EC99FA84B49363D3AA7BB5E73AA96B0C74C886C132F997AEA110B4F5B89451A52BFA651D50FCABFDE7FB570A99F744F849AFDC27732F5BDEE138EA2D2AE0E95BC010EAE36C9EEE7286CC615844D7A84946D4B8C6653563004B528771734F30BFF2AF9C699D9CF23477663C231F936670AA64BBDD4AB4367A62AB34A5DFB44CA03D4ECC74C28E33803B3CA4A04C0271BBE6D1AD02030100015901F98859017186186550ED5C309AC00D13F29B22912649FAC98E806B746573745F64657669636583010159012630820122300D06092A864886F70D01010105000382010F003082010A0282010100D3E882BC85EBE378B5C043F5F51135F39531C5708FB0A455FB680EFF25070502AD3F333DE6E1BBAAC4C133107F125C8056047D4C77DBDDE178EB92B43432F249F7CA080BE18B04662D03F4D28873B9569094D50B036D4B8B65EEE101EC54B2F834A45E4E297464DC231C74E643EC99FA84B49363D3AA7BB5E73AA96B0C74C886C132F997AEA110B4F5B89451A52BFA651D50FCABFDE7FB570A99F744F849AFDC27732F5BDEE138EA2D2AE0E95BC010EAE36C9EEE7286CC615844D7A84946D4B8C6653563004B528771734F30BFF2AF9C699D9CF23477663C231F936670AA64BBDD4AB4367A62AB34A5DFB44CA03D4ECC74C28E33803B3CA4A04C0271BBE6D1AD0203010001822F58209F17599E0A16082ABAF313F448ADD12ACD14C981A3DFA786D240C842113D974000820558206552C303917E65450B187727BB6DF531C819421E7148790C045B52DCC1DFBC9D509B5473C6ED93FD29C5507FAFC0B5824082390100405820829DA9590248B6B8F9B559BB2B5BAC3CE88984963FC0FAE842A5B5F07C3B15E5822F58206EBB2E1467C7162BB953C36092AD805207A8474CCD18B06267198B184C34EAF419FFFF590100881E74D84932C8986341F8423801F43AAB92A813F53EE9902CC5D2EBF48F4EA23CA84FE52F709B1C86B6A17295B605B5D5D1E876069CC0BB7FD9115F16F6E7ACEB43C4997053161CA1117110E24EA83AFB9BF2092DC1E921DAC0ECD533FD33B1E6F6E48A04D085D8A3B9552C6A447F39249509DE11D2A52F09B13736D0FEE2AFE63AF26AC6A56B615ED7F937B6B087A3D1105C0E07326CD76C8974E12F75C6DC91B18EC08CDDED88B9B32B803BECB37757210682C9D975BE507C8364AD4AE99E5A903DB04AB5F94BAA039168D070F641F3685437F32972CB79D4F92FCDC47045D9CDCB9385DE1DCE1421D3CBF09CD73D34775775E4300C7454ADA07C92D38613") + type ovhProof struct { + OVH cbor.Bstr[fdo.VoucherHeader] + NumOVEntries uint8 + OVHHmac protocol.Hmac + NonceTO2ProveOV protocol.Nonce + SigInfoB struct { + Type cose.SignatureAlgorithm + Info []byte + } + KeyExchangeA []byte + HelloDeviceHash protocol.Hash + MaxOwnerMessageSize uint16 + } + var proveOVHdr cose.Sign1Tag[ovhProof, []byte] + if err := cbor.Unmarshal(data, &proveOVHdr); err != nil { + t.Fatal(err) + } + + var ownerPubKey protocol.PublicKey + if ok, err := proveOVHdr.Unprotected.Parse(cose.Label{Int64: 257}, &ownerPubKey); err != nil { + t.Fatal(err) + } else if !ok { + t.Fatal("expected pub key in unprotected header") + } + + key, err := ownerPubKey.Public() + if err != nil { + t.Fatal(err) + } + if ok, err := proveOVHdr.Verify(key, nil, nil); err != nil { + t.Fatal(err) + } else if !ok { + t.Fatal("verification failed") + } +} diff --git a/examples/cmd/client.go b/examples/cmd/client.go index 8e4b945..424b200 100644 --- a/examples/cmd/client.go +++ b/examples/cmd/client.go @@ -205,17 +205,20 @@ func client() error { FileSep: ";", Bin: runtime.GOARCH, }, - KeyExchange: kex.Suite(kexSuite), - CipherSuite: kexCipherSuiteID, + KeyExchange: kex.Suite(kexSuite), + CipherSuite: kexCipherSuiteID, + AllowCredentialReuse: true, }) if rvOnly { return nil } if newDC == nil { - return fmt.Errorf("transfer of ownership not successful") + fmt.Println("Credential not updated (either due to failure of TO2 or the Credential Reuse Protocol") + return nil } // Store new credential + fmt.Println("Success") return updateCred(*newDC) } diff --git a/examples/cmd/server.go b/examples/cmd/server.go index 477fd75..d1696cd 100644 --- a/examples/cmd/server.go +++ b/examples/cmd/server.go @@ -56,6 +56,7 @@ var ( to0GUID string resaleGUID string resaleKey string + reuseCred bool rvBypass bool rvDelay int printOwnerPubKey string @@ -87,6 +88,7 @@ func init() { serverFlags.StringVar(&addr, "http", "localhost:8080", "The `addr`ess to listen on") serverFlags.StringVar(&resaleGUID, "resale-guid", "", "Voucher `guid` to extend for resale") serverFlags.StringVar(&resaleKey, "resale-key", "", "The `path` to a PEM-encoded x.509 public key for the next owner") + serverFlags.BoolVar(&reuseCred, "reuse-cred", false, "Perform the Credential Reuse Protocol in TO2") serverFlags.BoolVar(&insecureTLS, "insecure-tls", false, "Listen with a self-signed TLS certificate") serverFlags.BoolVar(&rvBypass, "rv-bypass", false, "Skip TO1") serverFlags.IntVar(&rvDelay, "rv-delay", 0, "Delay TO1 by N `seconds`") @@ -512,11 +514,12 @@ func newHandler(rvInfo [][]protocol.RvInstruction, state *sqlite.DB) (*transport RVBlobs: state, }, TO2Responder: &fdo.TO2Server{ - Session: state, - Vouchers: state, - OwnerKeys: state, - RvInfo: func(context.Context, fdo.Voucher) ([][]protocol.RvInstruction, error) { return rvInfo, nil }, - OwnerModules: ownerModules, + Session: state, + Vouchers: state, + OwnerKeys: state, + RvInfo: func(context.Context, fdo.Voucher) ([][]protocol.RvInstruction, error) { return rvInfo, nil }, + OwnerModules: ownerModules, + ReuseCredential: func(context.Context, fdo.Voucher) bool { return reuseCred }, }, }, nil } diff --git a/examples/plugins/plugins_test.go b/examples/plugins/plugins_test.go index a4f4743..f6cb26d 100644 --- a/examples/plugins/plugins_test.go +++ b/examples/plugins/plugins_test.go @@ -39,23 +39,26 @@ func TestDownloadOwnerPlugin(t *testing.T) { downloadOwnerCmd.Stderr = fdotest.TestingLog(t) downloadOwnerPlugin := &plugin.OwnerModule{Module: plugin.NewCommandPluginModule(downloadOwnerCmd)} - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "fdo.download": &fsim.Download{ - CreateTemp: func() (*os.File, error) { - return os.CreateTemp(".", "fdo.download_*") + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "fdo.download": &fsim.Download{ + CreateTemp: func() (*os.File, error) { + return os.CreateTemp(".", "fdo.download_*") + }, + NameToPath: func(name string) string { + return filepath.Join("testdata", "downloads", name) + }, + ErrorLog: fdotest.TestingLog(t), }, - NameToPath: func(name string) string { - return filepath.Join("testdata", "downloads", name) - }, - ErrorLog: fdotest.TestingLog(t), }, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - if !yield("fdo.download", downloadOwnerPlugin) { - return + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + if !yield("fdo.download", downloadOwnerPlugin) { + return + } } - } - }, nil) + }, + }) // Validate expected contents downloadContents, err := os.ReadFile("testdata/downloads/bigfile.test") @@ -83,19 +86,22 @@ func TestDownloadDevicePlugin(t *testing.T) { Module: plugin.NewCommandPluginModule(downloadDeviceCmd), } - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "fdo.download": downloadDevicePlugin, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - if !yield("fdo.download", &fsim.DownloadContents[*bytes.Reader]{ - Name: "bigfile.test", - Contents: bytes.NewReader(expected), - MustDownload: true, - }) { - return + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "fdo.download": downloadDevicePlugin, + }, + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + if !yield("fdo.download", &fsim.DownloadContents[*bytes.Reader]{ + Name: "bigfile.test", + Contents: bytes.NewReader(expected), + MustDownload: true, + }) { + return + } } - } - }, nil) + }, + }) // Validate expected contents downloadContents, err := os.ReadFile("testdata/downloads/bigfile.test") @@ -146,12 +152,15 @@ func TestDevmodPlugin(t *testing.T) { var got serviceinfo.Devmod - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "devmod": devmodPlugin, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - got = devmod - return func(yield func(string, serviceinfo.OwnerModule) bool) {} - }, nil) + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "devmod": devmodPlugin, + }, + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + got = devmod + return func(yield func(string, serviceinfo.OwnerModule) bool) {} + }, + }) if !reflect.DeepEqual(got, expected) { t.Errorf("devmod did not match expected\nwant %+v\ngot %+v", expected, got) diff --git a/fdo_test.go b/fdo_test.go index adff392..4b19785 100644 --- a/fdo_test.go +++ b/fdo_test.go @@ -25,7 +25,7 @@ import ( const mockModuleName = "fdotest.mock" func TestClient(t *testing.T) { - fdotest.RunClientTestSuite(t, nil, nil, nil, nil, nil) + fdotest.RunClientTestSuite(t, fdotest.Config{}) } func TestClientWithMockModule(t *testing.T) { @@ -47,13 +47,16 @@ func TestClientWithMockModule(t *testing.T) { }, } - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - mockModuleName: deviceModule, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - yield(mockModuleName, ownerModule) - } - }, nil) + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + mockModuleName: deviceModule, + }, + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + yield(mockModuleName, ownerModule) + } + }, + }) if !deviceModule.ActiveState { t.Error("device module should be active") @@ -82,19 +85,23 @@ func TestClientWithMockModuleAndAutoUnchunking(t *testing.T) { }, } - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - mockModuleName: deviceModule, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - yield(mockModuleName, ownerModule) - } - }, func(t *testing.T, err error) { - if err == nil { - t.Error("expected err to occur when not handling all message chunks") - } - if !strings.Contains(err.Error(), "device module did not read full body") { - t.Error("expected err to refer to device module not reading full message body") - } + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + mockModuleName: deviceModule, + }, + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + yield(mockModuleName, ownerModule) + } + }, + CustomExpect: func(t *testing.T, err error) { + if err == nil { + t.Error("expected err to occur when not handling all message chunks") + } + if !strings.Contains(err.Error(), "device module did not read full body") { + t.Error("expected err to refer to device module not reading full message body") + } + }, }) if !deviceModule.ActiveState { @@ -134,12 +141,15 @@ func TestClientWithCustomDevmod(t *testing.T) { }, } - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "devmod": customDevmod, - }, nil, func(t *testing.T, err error) { - if err == nil || !strings.Contains(err.Error(), "missing required devmod field: bin") { - t.Fatalf("expected invalid devmod error, got: %v", err) - } + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "devmod": customDevmod, + }, + CustomExpect: func(t *testing.T, err error) { + if err == nil || !strings.Contains(err.Error(), "missing required devmod field: bin") { + t.Fatalf("expected invalid devmod error, got: %v", err) + } + }, }) }) @@ -172,9 +182,11 @@ func TestClientWithCustomDevmod(t *testing.T) { }, } - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "devmod": customDevmod, - }, nil, nil) + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "devmod": customDevmod, + }, + }) }) } @@ -183,49 +195,52 @@ func TestClientWithPluginModule(t *testing.T) { devicePlugin.Routines = fdotest.ModuleNameOnlyRoutines(mockModuleName) ownerPlugins := make(chan *fdotest.MockPlugin, 1000) - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - mockModuleName: struct { - plugin.Module - serviceinfo.DeviceModule - }{ - Module: devicePlugin, - DeviceModule: &fdotest.MockDeviceModule{ - TransitionFunc: func(active bool) error { - if active { - _, _, err := devicePlugin.Start() - return err - } - return nil - }, - }, - }, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - var once sync.Once - ownerPlugin := new(fdotest.MockPlugin) - ownerPlugin.Routines = fdotest.ModuleNameOnlyRoutines(mockModuleName) - if !yield(mockModuleName, struct { + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + mockModuleName: struct { plugin.Module - serviceinfo.OwnerModule + serviceinfo.DeviceModule }{ - Module: ownerPlugin, - OwnerModule: &fdotest.MockOwnerModule{ - ProduceInfoFunc: func(ctx context.Context, producer *serviceinfo.Producer) (blockPeer, moduleDone bool, err error) { - once.Do(func() { _, _, err = ownerPlugin.Start() }) - if err != nil { - return false, false, err + Module: devicePlugin, + DeviceModule: &fdotest.MockDeviceModule{ + TransitionFunc: func(active bool) error { + if active { + _, _, err := devicePlugin.Start() + return err } - return false, true, producer.WriteChunk("active", []byte{0xf5}) + return nil }, }, - }) { - return - } - if slices.Contains(supportedMods, mockModuleName) { - ownerPlugins <- ownerPlugin + }, + }, + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + var once sync.Once + ownerPlugin := new(fdotest.MockPlugin) + ownerPlugin.Routines = fdotest.ModuleNameOnlyRoutines(mockModuleName) + if !yield(mockModuleName, struct { + plugin.Module + serviceinfo.OwnerModule + }{ + Module: ownerPlugin, + OwnerModule: &fdotest.MockOwnerModule{ + ProduceInfoFunc: func(ctx context.Context, producer *serviceinfo.Producer) (blockPeer, moduleDone bool, err error) { + once.Do(func() { _, _, err = ownerPlugin.Start() }) + if err != nil { + return false, false, err + } + return false, true, producer.WriteChunk("active", []byte{0xf5}) + }, + }, + }) { + return + } + if slices.Contains(supportedMods, mockModuleName) { + ownerPlugins <- ownerPlugin + } } - } - }, nil) + }, + }) close(ownerPlugins) ctx, cancel := context.WithTimeout(context.Background(), time.Second) diff --git a/fdotest/client.go b/fdotest/client.go index da02b5f..a052758 100644 --- a/fdotest/client.go +++ b/fdotest/client.go @@ -46,17 +46,28 @@ const timeout = 10 * time.Second // device. type OwnerModulesFunc func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] +// Config provides options to +type Config struct { + // If state is nil, then an in-memory implementation will be used. This is + // useful for only testing service info modules. + State AllServerState + TPM tpm.TPM + Reuse bool + + DeviceModules map[string]serviceinfo.DeviceModule + OwnerModules OwnerModulesFunc + + CustomExpect func(*testing.T, error) +} + // RunClientTestSuite is used to test different implementations of server state // methods at an almost end-to-end level (transport is mocked). // -// If state is nil, then an in-memory implementation will be used. This is -// useful for only testing service info modules. -// //nolint:gocyclo -func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, deviceModules map[string]serviceinfo.DeviceModule, ownerModules OwnerModulesFunc, customExpect func(*testing.T, error)) { +func RunClientTestSuite(t *testing.T, conf Config) { slog.SetDefault(slog.New(slog.NewTextHandler(TestingLog(t), &slog.HandlerOptions{Level: slog.LevelDebug}))) - if state == nil { + if conf.State == nil { stateless, err := token.NewService() if err != nil { t.Fatal(err) @@ -67,17 +78,17 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device t.Fatal(err) } - state = struct { + conf.State = struct { *token.Service *memory.State }{stateless, inMemory} } transport := &Transport{ - Tokens: state, + Tokens: conf.State, DIResponder: &fdo.DIServer[custom.DeviceMfgInfo]{ - Session: state, - Vouchers: state, + Session: conf.State, + Vouchers: conf.State, SignDeviceCertificate: func(info *custom.DeviceMfgInfo) ([]*x509.Certificate, error) { // Validate device info csr := x509.CertificateRequest(info.CertInfo) @@ -86,7 +97,7 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device } // Sign CSR - key, chain, err := state.ManufacturerKey(info.KeyType) + key, chain, err := conf.State.ManufacturerKey(info.KeyType) if err != nil { var unsupportedErr fdo.ErrUnsupportedKeyType if errors.As(err, &unsupportedErr) { @@ -118,32 +129,32 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device chain = append([]*x509.Certificate{cert}, chain...) return chain, nil }, - AutoExtend: state, + AutoExtend: conf.State, RvInfo: func(context.Context, *fdo.Voucher) ([][]protocol.RvInstruction, error) { return [][]protocol.RvInstruction{}, nil }, }, TO0Responder: &fdo.TO0Server{ - Session: state, - RVBlobs: state, + Session: conf.State, + RVBlobs: conf.State, }, TO1Responder: &fdo.TO1Server{ - Session: state, - RVBlobs: state, + Session: conf.State, + RVBlobs: conf.State, }, TO2Responder: &fdo.TO2Server{ - Session: state, - Vouchers: state, - OwnerKeys: state, + Session: conf.State, + Vouchers: conf.State, + OwnerKeys: conf.State, RvInfo: func(context.Context, fdo.Voucher) ([][]protocol.RvInstruction, error) { return [][]protocol.RvInstruction{}, nil }, OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - if ownerModules == nil { + if conf.OwnerModules == nil { return func(yield func(string, serviceinfo.OwnerModule) bool) {} } - mods := ownerModules(ctx, replacementGUID, info, chain, devmod, supportedMods) + mods := conf.OwnerModules(ctx, replacementGUID, info, chain, devmod, supportedMods) return func(yield func(string, serviceinfo.OwnerModule) bool) { for modName, mod := range mods { if slices.Contains(supportedMods, modName) { @@ -154,14 +165,15 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device } } }, - VerifyVoucher: func(context.Context, fdo.Voucher) error { return nil }, + ReuseCredential: func(context.Context, fdo.Voucher) bool { return conf.Reuse }, + VerifyVoucher: func(context.Context, fdo.Voucher) error { return nil }, }, T: t, } to0 := &fdo.TO0Client{ - Vouchers: state, - OwnerKeys: state, + Vouchers: conf.State, + OwnerKeys: conf.State, } for _, table := range []struct { @@ -204,14 +216,14 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device } hmacSha256 := hmac.New(sha256.New, secret) hmacSha384 := hmac.New(sha512.New384, secret) - if tpmc != nil { + if conf.TPM != nil { secret = []byte("TPM2") var err error - hmacSha256, err = tpm.NewHmac(tpmc, crypto.SHA256) + hmacSha256, err = tpm.NewHmac(conf.TPM, crypto.SHA256) if err != nil { t.Fatal(err) } - hmacSha384, err = tpm.NewHmac(tpmc, crypto.SHA384) + hmacSha384, err = tpm.NewHmac(conf.TPM, crypto.SHA384) if err != nil { t.Fatal(err) } @@ -222,8 +234,8 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device switch table.keyType { case protocol.Secp256r1KeyType: var err error - if tpmc != nil { - key, err = tpm.GenerateECKey(tpmc, elliptic.P256()) + if conf.TPM != nil { + key, err = tpm.GenerateECKey(conf.TPM, elliptic.P256()) } else { key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) } @@ -233,8 +245,8 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device case protocol.Secp384r1KeyType: var err error - if tpmc != nil { - key, err = tpm.GenerateECKey(tpmc, elliptic.P384()) + if conf.TPM != nil { + key, err = tpm.GenerateECKey(conf.TPM, elliptic.P384()) } else { key, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) } @@ -244,8 +256,8 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device case protocol.Rsa2048RestrKeyType: var err error - if tpmc != nil { - key, err = tpm.GenerateRSAKey(tpmc, 2048) + if conf.TPM != nil { + key, err = tpm.GenerateRSAKey(conf.TPM, 2048) } else { key, err = rsa.GenerateKey(rand.Reader, 2048) } @@ -255,8 +267,8 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device case protocol.RsaPkcsKeyType: var err error - if tpmc != nil { - key, err = tpm.GenerateRSAKey(tpmc, 2048) // Simulator does not support RSA3072 + if conf.TPM != nil { + key, err = tpm.GenerateRSAKey(conf.TPM, 2048) // Simulator does not support RSA3072 } else { key, err = rsa.GenerateKey(rand.Reader, 3072) } @@ -266,9 +278,9 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device case protocol.RsaPssKeyType: var err error - if tpmc != nil { + if conf.TPM != nil { sigAlg = x509.SHA256WithRSAPSS - key, err = tpm.GenerateRSAPSSKey(tpmc, 2048) // Simulator does not support RSA3072 + key, err = tpm.GenerateRSAPSSKey(conf.TPM, 2048) // Simulator does not support RSA3072 } else { key, err = rsa.GenerateKey(rand.Reader, 3072) } @@ -327,7 +339,7 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device t.Fatal(err) } - if tpmc != nil { + if conf.TPM != nil { t.Logf("Credential: %s", tpm.DeviceCredential{ DeviceCredential: *cred, DeviceKey: tpm.FdoDeviceKey, @@ -397,13 +409,14 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device FileSep: ";", Bin: runtime.GOARCH, }, - KeyExchange: table.keyExchange, - CipherSuite: table.cipherSuite, + KeyExchange: table.keyExchange, + CipherSuite: table.cipherSuite, + AllowCredentialReuse: conf.Reuse, }) if err != nil { t.Fatal(err) } - if tpmc != nil { + if conf.TPM != nil { t.Logf("New credential: %s", tpm.DeviceCredential{ DeviceCredential: *cred, DeviceKey: tpm.FdoDeviceKey, @@ -440,13 +453,14 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device FileSep: ";", Bin: runtime.GOARCH, }, - KeyExchange: table.keyExchange, - CipherSuite: table.cipherSuite, + KeyExchange: table.keyExchange, + CipherSuite: table.cipherSuite, + AllowCredentialReuse: conf.Reuse, }) if err != nil { t.Fatal(err) } - if tpmc != nil { + if conf.TPM != nil { t.Logf("New credential: %s", tpm.DeviceCredential{ DeviceCredential: *cred, DeviceKey: tpm.FdoDeviceKey, @@ -482,19 +496,20 @@ func RunClientTestSuite(t *testing.T, state AllServerState, tpmc tpm.TPM, device FileSep: ";", Bin: runtime.GOARCH, }, - DeviceModules: deviceModules, - KeyExchange: table.keyExchange, - CipherSuite: table.cipherSuite, + DeviceModules: conf.DeviceModules, + KeyExchange: table.keyExchange, + CipherSuite: table.cipherSuite, + AllowCredentialReuse: conf.Reuse, }) - if customExpect != nil { - customExpect(t, err) + if conf.CustomExpect != nil { + conf.CustomExpect(t, err) if err != nil { return } } else if err != nil { t.Fatal(err) } - if tpmc != nil { + if conf.TPM != nil { t.Logf("New credential: %s", tpm.DeviceCredential{ DeviceCredential: *cred, DeviceKey: tpm.FdoDeviceKey, diff --git a/fsim/fsim_test.go b/fsim/fsim_test.go index 81419ce..bc208cb 100644 --- a/fsim/fsim_test.go +++ b/fsim/fsim_test.go @@ -59,65 +59,68 @@ func TestClientWithDataModules(t *testing.T) { errc := make(chan error) go func() { errc <- srv.Serve(lis) }() - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "fdo.download": &fsim.Download{ - CreateTemp: func() (*os.File, error) { - return os.CreateTemp("testdata", "fdo.download_*") - }, - NameToPath: func(name string) string { - return filepath.Join("testdata", "downloads", name) - }, - ErrorLog: fdotest.TestingLog(t), - }, - "fdo.upload": &fsim.Upload{FS: fstest.MapFS{ - "bigfile.test": &fstest.MapFile{ - Data: data, - Mode: 0777, - }, - }}, - "fdo.wget": &fsim.Wget{ - CreateTemp: func() (*os.File, error) { - return os.CreateTemp("testdata", "fdo.wget_*") + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "fdo.download": &fsim.Download{ + CreateTemp: func() (*os.File, error) { + return os.CreateTemp("testdata", "fdo.download_*") + }, + NameToPath: func(name string) string { + return filepath.Join("testdata", "downloads", name) + }, + ErrorLog: fdotest.TestingLog(t), }, - NameToPath: func(name string) string { - return filepath.Join("testdata", "downloads", name) + "fdo.upload": &fsim.Upload{FS: fstest.MapFS{ + "bigfile.test": &fstest.MapFile{ + Data: data, + Mode: 0777, + }, + }}, + "fdo.wget": &fsim.Wget{ + CreateTemp: func() (*os.File, error) { + return os.CreateTemp("testdata", "fdo.wget_*") + }, + NameToPath: func(name string) string { + return filepath.Join("testdata", "downloads", name) + }, + Timeout: 10 * time.Second, }, - Timeout: 10 * time.Second, }, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - if !yield("fdo.download", &fsim.DownloadContents[*bytes.Reader]{ - Name: "bigfile.test", - Contents: bytes.NewReader(data), - MustDownload: true, - }) { - return - } + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + if !yield("fdo.download", &fsim.DownloadContents[*bytes.Reader]{ + Name: "bigfile.test", + Contents: bytes.NewReader(data), + MustDownload: true, + }) { + return + } - if !yield("fdo.upload", &fsim.UploadRequest{ - Dir: "testdata/uploads", - Name: "bigfile.test", - CreateTemp: func() (*os.File, error) { - return os.CreateTemp("testdata", "fdo.upload_*") - }, - }) { - return - } + if !yield("fdo.upload", &fsim.UploadRequest{ + Dir: "testdata/uploads", + Name: "bigfile.test", + CreateTemp: func() (*os.File, error) { + return os.CreateTemp("testdata", "fdo.upload_*") + }, + }) { + return + } - if !yield("fdo.wget", &fsim.WgetCommand{ - Name: "wget.test", - URL: &url.URL{ - Scheme: "http", - Host: lis.Addr().String(), - Path: "/file", - }, - Length: int64(len(data)), - Checksum: sum[:], - }) { - return + if !yield("fdo.wget", &fsim.WgetCommand{ + Name: "wget.test", + URL: &url.URL{ + Scheme: "http", + Host: lis.Addr().String(), + Path: "/file", + }, + Length: int64(len(data)), + Checksum: sum[:], + }) { + return + } } - } - }, nil) + }, + }) /// Validate contents downloadContents, err := os.ReadFile("testdata/downloads/bigfile.test") @@ -221,21 +224,24 @@ func TestClientWithMockDownloadOwner(t *testing.T) { }, } - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "fdo.download": &fsim.Download{ - CreateTemp: func() (*os.File, error) { - return os.CreateTemp("testdata", "fdo.download_*") - }, - NameToPath: func(name string) string { - return filepath.Join("testdata", "downloads", name) + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "fdo.download": &fsim.Download{ + CreateTemp: func() (*os.File, error) { + return os.CreateTemp("testdata", "fdo.download_*") + }, + NameToPath: func(name string) string { + return filepath.Join("testdata", "downloads", name) + }, + ErrorLog: fdotest.TestingLog(t), }, - ErrorLog: fdotest.TestingLog(t), }, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - yield("fdo.download", ownerModule) - } - }, nil) + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + yield("fdo.download", ownerModule) + } + }, + }) } func TestClientWithCommandModule(t *testing.T) { @@ -246,40 +252,43 @@ func TestClientWithCommandModule(t *testing.T) { } runs := make(chan runData, 1000) - fdotest.RunClientTestSuite(t, nil, nil, map[string]serviceinfo.DeviceModule{ - "fdo.command": &fsim.Command{ - Timeout: 10 * time.Second, + fdotest.RunClientTestSuite(t, fdotest.Config{ + DeviceModules: map[string]serviceinfo.DeviceModule{ + "fdo.command": &fsim.Command{ + Timeout: 10 * time.Second, + }, }, - }, func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - return func(yield func(string, serviceinfo.OwnerModule) bool) { - run := runData{exitChan: make(chan int, 1)} + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + return func(yield func(string, serviceinfo.OwnerModule) bool) { + run := runData{exitChan: make(chan int, 1)} - if !yield("fdo.command", &fsim.RunCommand{ - Command: "date", - Args: []string{"--utc"}, - Stdout: struct { - io.Writer - io.Closer - }{ - Writer: &run.outbuf, - Closer: io.NopCloser(nil), - }, - Stderr: struct { - io.Writer - io.Closer - }{ - Writer: &run.errbuf, - Closer: io.NopCloser(nil), - }, - ExitChan: run.exitChan, - }) { - return - } - if slices.Contains(supportedMods, "fdo.command") { - runs <- run + if !yield("fdo.command", &fsim.RunCommand{ + Command: "date", + Args: []string{"--utc"}, + Stdout: struct { + io.Writer + io.Closer + }{ + Writer: &run.outbuf, + Closer: io.NopCloser(nil), + }, + Stderr: struct { + io.Writer + io.Closer + }{ + Writer: &run.errbuf, + Closer: io.NopCloser(nil), + }, + ExitChan: run.exitChan, + }) { + return + } + if slices.Contains(supportedMods, "fdo.command") { + runs <- run + } } - } - }, nil) + }, + }) close(runs) for run := range runs { diff --git a/server.go b/server.go index bd5838f..ca791c1 100644 --- a/server.go +++ b/server.go @@ -186,12 +186,19 @@ type TO2Server struct { Vouchers OwnerVoucherPersistentState OwnerKeys OwnerKeyPersistentState - // Rendezvous directives + // Choose the replacement rendezvous directives based on the current + // voucher of the onboarding device. RvInfo func(context.Context, Voucher) ([][]protocol.RvInstruction, error) - // Create an iterator of service info modules for a given device + // Create an iterator of service info modules for a given device. The + // iterator returns the name of the module and its implementation. OwnerModules func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, modules []string) iter.Seq2[string, serviceinfo.OwnerModule] + // ReuseCredential, if not nil, will be called to determine whether to + // apply the Credential Reuse Protocol based on the current voucher of an + // onboarding device. + ReuseCredential func(context.Context, Voucher) bool + // VerifyVoucher, if not nil, will be called before creating and responding // with a TO2.ProveOVHdr message. Any error will cause TO2 to fail with a // not found status code. diff --git a/sqlite/sqlite_test.go b/sqlite/sqlite_test.go index f9e614f..e2ef8df 100644 --- a/sqlite/sqlite_test.go +++ b/sqlite/sqlite_test.go @@ -25,7 +25,9 @@ func TestClient(t *testing.T) { state, cleanup := newDB(t) defer func() { _ = cleanup() }() - fdotest.RunClientTestSuite(t, state, nil, nil, nil, nil) + fdotest.RunClientTestSuite(t, fdotest.Config{ + State: state, + }) } func TestServerState(t *testing.T) { diff --git a/to2.go b/to2.go index 38c3816..eb43e48 100644 --- a/to2.go +++ b/to2.go @@ -19,6 +19,7 @@ import ( "iter" "log/slog" "math" + "reflect" "runtime" "strings" "sync" @@ -109,6 +110,11 @@ type TO2Config struct { // amounts of data, but choosing the best value depends on network // configuration (e.g. jumbo packets) and transport (overhead size). MaxServiceInfoSizeReceive uint16 + + // Allow for the Credential Reuse Protocol (Section 7) to be used. If not + // enabled, TO2 will fail with CredReuseErrCode (102) if reuse is + // attempted by the owner service. + AllowCredentialReuse bool } // TO2 runs the TO2 protocol and returns a DeviceCredential with replaced GUID, @@ -120,6 +126,9 @@ type TO2Config struct { // // It has the side effect of performing service info modules, which may include // actions such as downloading files. +// +// If the Credential Reuse protocol is allowed and occurs, then the returned +// device credential will be nil. func TO2(ctx context.Context, transport Transport, to1d *cose.Sign1[protocol.To1d, []byte], c TO2Config) (*DeviceCredential, error) { ctx = contextWithErrMsg(ctx) @@ -152,23 +161,27 @@ func TO2(ctx context.Context, transport Transport, to1d *cose.Sign1[protocol.To1 errorMsg(ctx, transport, err) return nil, err } - replacementOVH := &VoucherHeader{ - Version: originalOVH.Version, - GUID: partialOVH.GUID, - RvInfo: partialOVH.RvInfo, - DeviceInfo: originalOVH.DeviceInfo, - ManufacturerKey: partialOVH.ManufacturerKey, - CertChainHash: originalOVH.CertChainHash, - } - // Select the appropriate hash algorithm - ownerPubKey, err := partialOVH.ManufacturerKey.Public() - if err != nil { - return nil, fmt.Errorf("error parsing manufacturer public key type from incomplete replacement ownership voucher header: %w", err) - } - alg, err := hashAlgFor(c.Key.Public(), ownerPubKey) - if err != nil { - return nil, fmt.Errorf("error selecting the appropriate hash algorithm: %w", err) + // Select the appropriate hash algorithm for HMAC and public key hash + alg := c.Cred.PublicKeyHash.Algorithm + var replacementOVH *VoucherHeader + if partialOVH != nil { + nextOwnerPublicKey, err := partialOVH.ManufacturerKey.Public() + if err != nil { + return nil, fmt.Errorf("error parsing manufacturer public key type from incomplete replacement ownership voucher header: %w", err) + } + alg, err = hashAlgFor(c.Key.Public(), nextOwnerPublicKey) + if err != nil { + return nil, fmt.Errorf("error selecting the appropriate hash algorithm: %w", err) + } + replacementOVH = &VoucherHeader{ + Version: originalOVH.Version, + GUID: partialOVH.GUID, + RvInfo: partialOVH.RvInfo, + DeviceInfo: originalOVH.DeviceInfo, + ManufacturerKey: partialOVH.ManufacturerKey, + CertChainHash: originalOVH.CertChainHash, + } } // Prepare to send and receive service info, determining the transmit MTU @@ -193,6 +206,11 @@ func TO2(ctx context.Context, transport Transport, to1d *cose.Sign1[protocol.To1 return nil, err } + // If using the Credential Reuse protocol the device credential is not updated + if replacementOVH == nil { + return nil, nil + } + // Hash new initial owner public key and return replacement device // credential replacementKeyDigest := alg.HashFunc().New() @@ -810,11 +828,15 @@ func proveDevice(ctx context.Context, transport Transport, proveDeviceNonce prot captureErr(ctx, protocol.InvalidMessageErrCode, "") return protocol.Nonce{}, nil, fmt.Errorf("nonce in TO2.SetupDevice did not match nonce sent in TO2.ProveDevice") } - return setupDeviceNonce, &VoucherHeader{ + replacementOVH := &VoucherHeader{ GUID: setupDevice.Payload.Val.GUID, RvInfo: setupDevice.Payload.Val.RendezvousInfo, ManufacturerKey: setupDevice.Payload.Val.Owner2Key, - }, nil + } + if credReuse, err := reuseCredentials(ctx, replacementOVH, ownerPublicKey, c); err != nil || credReuse { + return setupDeviceNonce, nil, err + } + return setupDeviceNonce, replacementOVH, nil case protocol.ErrorMsgType: var errMsg protocol.ErrorMessage @@ -829,6 +851,24 @@ func proveDevice(ctx context.Context, transport Transport, proveDeviceNonce prot } } +func reuseCredentials(ctx context.Context, replacementOVH *VoucherHeader, ownerPublicKey crypto.PublicKey, c *TO2Config) (bool, error) { + replacementOwnerPublicKey, err := replacementOVH.ManufacturerKey.Public() + if err != nil { + captureErr(ctx, protocol.InvalidMessageErrCode, "") + return false, fmt.Errorf("owner key in TO2.SetupDevice could not be parsed: %w", err) + } + if replacementOVH.GUID != c.Cred.GUID || + !reflect.DeepEqual(replacementOVH.RvInfo, c.Cred.RvInfo) || + !replacementOwnerPublicKey.(interface{ Equal(crypto.PublicKey) bool }).Equal(ownerPublicKey) { + return false, nil + } + if !c.AllowCredentialReuse { + captureErr(ctx, protocol.CredReuseErrCode, "") + return false, fmt.Errorf("credential reuse is not enabled") + } + return true, nil +} + type deviceSetup struct { RendezvousInfo [][]protocol.RvInstruction // RendezvousInfo replacement GUID protocol.GUID // GUID replacement @@ -931,28 +971,31 @@ func (s *TO2Server) setupDevice(ctx context.Context, msg io.Reader) (*cose.Sign1 return nil, fmt.Errorf("error updating associated key exchange session: %w", err) } - // Get configured RV info - rvInfo, err := s.RvInfo(ctx, *ov) - if err != nil { - return nil, fmt.Errorf("error determining rendezvous info for device: %w", err) - } - if err := s.Session.SetRvInfo(ctx, rvInfo); err != nil { - return nil, fmt.Errorf("error storing rendezvous info for device: %w", err) - } - - // Generate a replacement GUID + // Get replacement GUID and rendezvous directives var replacementGUID protocol.GUID - if _, err := rand.Read(replacementGUID[:]); err != nil { - return nil, fmt.Errorf("error generating replacement GUID for device: %w", err) - } - if err := s.Session.SetReplacementGUID(ctx, replacementGUID); err != nil { - return nil, fmt.Errorf("error storing replacement GUID for device: %w", err) + var replacementRvInfo [][]protocol.RvInstruction + if s.ReuseCredential != nil && s.ReuseCredential(ctx, *ov) { + replacementGUID = ov.Header.Val.GUID + replacementRvInfo = ov.Header.Val.RvInfo + } else { + if _, err := rand.Read(replacementGUID[:]); err != nil { + return nil, fmt.Errorf("error generating replacement GUID for device: %w", err) + } + if err := s.Session.SetReplacementGUID(ctx, replacementGUID); err != nil { + return nil, fmt.Errorf("error storing replacement GUID for device: %w", err) + } + if replacementRvInfo, err = s.RvInfo(ctx, *ov); err != nil { + return nil, fmt.Errorf("error determining rendezvous info for device: %w", err) + } + if err := s.Session.SetRvInfo(ctx, replacementRvInfo); err != nil { + return nil, fmt.Errorf("error storing rendezvous info for device: %w", err) + } } // Respond with device setup s1 := cose.Sign1[deviceSetup, []byte]{ Payload: cbor.NewByteWrap(deviceSetup{ - RendezvousInfo: rvInfo, + RendezvousInfo: replacementRvInfo, GUID: replacementGUID, NonceTO2SetupDv: setupDeviceNonce, Owner2Key: *ownerPublicKey, @@ -985,14 +1028,18 @@ func sendReadyServiceInfo(ctx context.Context, transport Transport, alg protocol default: panic("only SHA256 and SHA384 are supported in FDO") } - replacementHmac, err := hmacHash(h, replacementOVH) - if err != nil { - return 0, fmt.Errorf("error computing HMAC of ownership voucher header: %w", err) + var hmac *protocol.Hash + if replacementOVH != nil { + replacementHmac, err := hmacHash(h, replacementOVH) + if err != nil { + return 0, fmt.Errorf("error computing HMAC of ownership voucher header: %w", err) + } + hmac = &replacementHmac } // Define request structure msg := deviceServiceInfoReady{ - Hmac: &replacementHmac, + Hmac: hmac, MaxOwnerServiceInfoSize: &c.MaxServiceInfoSizeReceive, } @@ -1042,14 +1089,6 @@ func (s *TO2Server) ownerServiceInfoReady(ctx context.Context, msg io.Reader) (* return nil, fmt.Errorf("error decoding TO2.DeviceServiceInfoReady request: %w", err) } - // Store new HMAC for voucher replacement - if deviceReady.Hmac == nil { - return nil, fmt.Errorf("device did not send a replacement voucher HMAC") - } - if err := s.Session.SetReplacementHmac(ctx, *deviceReady.Hmac); err != nil { - return nil, fmt.Errorf("error storing replacement voucher HMAC for device: %w", err) - } - // Set send MTU mtu := uint16(serviceinfo.DefaultMTU) if deviceReady.MaxOwnerServiceInfoSize != nil { @@ -1059,28 +1098,35 @@ func (s *TO2Server) ownerServiceInfoReady(ctx context.Context, msg io.Reader) (* return nil, fmt.Errorf("error storing max service info size to send to device: %w", err) } - // Get voucher and voucher replacement state - currentGUID, err := s.Session.GUID(ctx) + // Get current voucher + guid, err := s.Session.GUID(ctx) if err != nil { return nil, fmt.Errorf("error retrieving associated device GUID of proof session: %w", err) } - currentOV, err := s.Vouchers.Voucher(ctx, currentGUID) - if err != nil { - return nil, fmt.Errorf("error retrieving voucher for device %x: %w", currentGUID, err) - } - replacementGUID, err := s.Session.ReplacementGUID(ctx) + ov, err := s.Vouchers.Voucher(ctx, guid) if err != nil { - return nil, fmt.Errorf("error retrieving replacement GUID for device: %w", err) + return nil, fmt.Errorf("error retrieving voucher for device %x: %w", guid, err) } - info := currentOV.Header.Val.DeviceInfo + info := ov.Header.Val.DeviceInfo var deviceCertChain []*x509.Certificate - if currentOV.CertChain != nil { - deviceCertChain = make([]*x509.Certificate, len(*currentOV.CertChain)) - for i, cert := range *currentOV.CertChain { + if ov.CertChain != nil { + deviceCertChain = make([]*x509.Certificate, len(*ov.CertChain)) + for i, cert := range *ov.CertChain { deviceCertChain[i] = (*x509.Certificate)(cert) } } + // If not using the Credential Reuse Protocol (i.e. device sends an HMAC), + // then store the HMAC and get the replacement GUID + if deviceReady.Hmac != nil { + if err := s.Session.SetReplacementHmac(ctx, *deviceReady.Hmac); err != nil { + return nil, fmt.Errorf("error storing replacement voucher HMAC for device: %w", err) + } + if guid, err = s.Session.ReplacementGUID(ctx); err != nil { + return nil, fmt.Errorf("error retrieving replacement GUID for device: %w", err) + } + } + // Initialize service info modules s.plugins = make(map[string]plugin.Module) s.nextModule, s.stop = iter.Pull2(func() iter.Seq2[string, serviceinfo.OwnerModule] { @@ -1092,7 +1138,7 @@ func (s *TO2Server) ownerServiceInfoReady(ctx context.Context, msg io.Reader) (* if !yield("devmod", &devmod) { return } - ownerModules = s.OwnerModules(ctx, replacementGUID, info, deviceCertChain, devmod.Devmod, devmod.Modules) + ownerModules = s.OwnerModules(ctx, guid, info, deviceCertChain, devmod.Devmod, devmod.Modules) } ownerModules(func(moduleName string, mod serviceinfo.OwnerModule) bool { @@ -1525,7 +1571,16 @@ func (s *TO2Server) to2Done2(ctx context.Context, msg io.Reader) (*done2Msg, err return nil, fmt.Errorf("nonce from TO2.ProveDevice did not match TO2.Done") } - // Get voucher and voucher replacement state + // If the Credential Reuse Protocol is being used (replacement HMAC is not + // found), then immediately complete TO2 without replacing the voucher. + replacementHmac, err := s.Session.ReplacementHmac(ctx) + if errors.Is(err, ErrNotFound) { + return &done2Msg{NonceTO2SetupDv: setupDeviceNonce}, nil + } else if err != nil { + return nil, fmt.Errorf("error retrieving replacement Hmac for device: %w", err) + } + + // Get current and replacement voucher values currentGUID, err := s.Session.GUID(ctx) if err != nil { return nil, fmt.Errorf("error retrieving associated device GUID of proof session: %w", err) @@ -1534,7 +1589,6 @@ func (s *TO2Server) to2Done2(ctx context.Context, msg io.Reader) (*done2Msg, err if err != nil { return nil, fmt.Errorf("error retrieving voucher for device %x: %w", currentGUID, err) } - rvInfo, err := s.Session.RvInfo(ctx) if err != nil { return nil, fmt.Errorf("error retrieving rendezvous info for device: %w", err) @@ -1543,10 +1597,6 @@ func (s *TO2Server) to2Done2(ctx context.Context, msg io.Reader) (*done2Msg, err if err != nil { return nil, fmt.Errorf("error retrieving replacement GUID for device: %w", err) } - replacementHmac, err := s.Session.ReplacementHmac(ctx) - if err != nil { - return nil, fmt.Errorf("error retrieving replacement Hmac for device: %w", err) - } // Create and store a new voucher keyType := currentOV.Header.Val.ManufacturerKey.Type @@ -1574,7 +1624,5 @@ func (s *TO2Server) to2Done2(ctx context.Context, msg io.Reader) (*done2Msg, err } // Respond with nonce - return &done2Msg{ - NonceTO2SetupDv: setupDeviceNonce, - }, nil + return &done2Msg{NonceTO2SetupDv: setupDeviceNonce}, nil } diff --git a/tpm/tpm_test.go b/tpm/tpm_test.go index b2027ed..c78298c 100644 --- a/tpm/tpm_test.go +++ b/tpm/tpm_test.go @@ -22,5 +22,7 @@ func TestTPMDevice(t *testing.T) { } }() - fdotest.RunClientTestSuite(t, nil, sim, nil, nil, nil) + fdotest.RunClientTestSuite(t, fdotest.Config{ + TPM: sim, + }) }