diff --git a/cbor/cbor.go b/cbor/cbor.go index f3cdb2a..69ed614 100644 --- a/cbor/cbor.go +++ b/cbor/cbor.go @@ -1243,15 +1243,10 @@ func (e *Encoder) encodeTextOrBinary(rv reflect.Value) error { // Unaddressable arrays cannot be made into slices, so we must create a // slice and copy contents into it - slice := reflect.MakeSlice( - reflect.SliceOf(rv.Type().Elem()), - rv.Len(), - rv.Len(), - ) - if n := reflect.Copy(slice, rv); n != rv.Len() { + b = make([]byte, rv.Len()) + if n := reflect.Copy(reflect.ValueOf(b), rv); n != rv.Len() { panic("array contents were not fully copied into a slice for encoding") } - b = slice.Bytes() } info := u64Bytes(uint64(len(b))) diff --git a/di.go b/di.go index 49426b2..a0c8ffe 100644 --- a/di.go +++ b/di.go @@ -137,7 +137,7 @@ func appStart(ctx context.Context, transport Transport, info any) (*VoucherHeade // Make request typ, resp, err := transport.Send(ctx, protocol.DIAppStartMsgType, msg, nil) if err != nil { - return nil, fmt.Errorf("error sending DI.AppStart: %w", err) + return nil, fmt.Errorf("DI.AppStart: %w", err) } defer func() { _ = resp.Close() }() @@ -296,7 +296,7 @@ func setHmac(ctx context.Context, transport Transport, hmac hash.Hash, ovh *Vouc // Make request typ, resp, err := transport.Send(ctx, protocol.DISetHmacMsgType, msg, nil) if err != nil { - return fmt.Errorf("error sending DI.SetHMAC: %w", err) + return fmt.Errorf("DI.SetHMAC: %w", err) } defer func() { _ = resp.Close() }() diff --git a/examples/cmd/server.go b/examples/cmd/server.go index d8c2b68..2a2bf0e 100644 --- a/examples/cmd/server.go +++ b/examples/cmd/server.go @@ -10,15 +10,14 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" - "crypto/tls" "crypto/x509" "crypto/x509/pkix" - "database/sql" "encoding/hex" "encoding/pem" "errors" "flag" "fmt" + "io/fs" "iter" "log" "log/slog" @@ -110,11 +109,19 @@ func server() error { //nolint:gocyclo if dbPath == "" { return errors.New("db flag is required") } + _, dbStatErr := os.Stat(dbPath) state, err := sqlite.Open(dbPath, dbPass) if err != nil { return err } + // Generate keys only if the db wasn't already created + if errors.Is(dbStatErr, fs.ErrNotExist) { + if err := generateKeys(state); err != nil { + return err + } + } + // If printing owner public key, do so and exit if printOwnerPubKey != "" { return doPrintOwnerPubKey(state) @@ -125,36 +132,21 @@ func server() error { //nolint:gocyclo return doImportVoucher(state) } + // Normalize address flags useTLS = insecureTLS - - // RV Info - prot := protocol.RVProtHTTP - if useTLS { - prot = protocol.RVProtHTTPS - } - rvInfo := [][]protocol.RvInstruction{{{Variable: protocol.RVProtocol, Value: mustMarshal(prot)}}} if extAddr == "" { extAddr = addr } - host, portStr, err := net.SplitHostPort(extAddr) - if err != nil { - return fmt.Errorf("invalid external addr: %w", err) - } - if host == "" { - rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVIPAddress, Value: mustMarshal(net.IP{127, 0, 0, 1})}) - } else if hostIP := net.ParseIP(host); hostIP.To4() != nil || hostIP.To16() != nil { - rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVIPAddress, Value: mustMarshal(hostIP)}) + + // RV Info + var rvInfo [][]protocol.RvInstruction + if to0Addr != "" { + rvInfo, err = to0AddrToRvInfo() } else { - rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVDns, Value: mustMarshal(host)}) + rvInfo, err = extAddrToRvInfo() } - portNum, err := strconv.ParseUint(portStr, 10, 16) if err != nil { - return fmt.Errorf("invalid external port: %w", err) - } - port := uint16(portNum) - rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVDevPort, Value: mustMarshal(port)}) - if rvBypass { - rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVBypass}) + return err } // Test RVDelay by introducing a delay before TO1 @@ -162,7 +154,7 @@ func server() error { //nolint:gocyclo // Invoke TO0 client if a GUID is specified if to0GUID != "" { - return registerRvBlob(host, port, state) + return registerRvBlob(state) } // Invoke resale protocol if a GUID is specified @@ -173,6 +165,110 @@ func server() error { //nolint:gocyclo return serveHTTP(rvInfo, state) } +func generateKeys(state *sqlite.DB) error { //nolint:gocyclo + // Generate manufacturing component keys + rsa2048MfgKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + rsa3072MfgKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + return err + } + ec256MfgKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + ec384MfgKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return err + } + generateCA := func(key crypto.Signer) ([]*x509.Certificate, error) { + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(30 * 365 * 24 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + der, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) + if err != nil { + return nil, err + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + return []*x509.Certificate{cert}, nil + } + rsa2048Chain, err := generateCA(rsa2048MfgKey) + if err != nil { + return err + } + rsa3072Chain, err := generateCA(rsa3072MfgKey) + if err != nil { + return err + } + ec256Chain, err := generateCA(ec256MfgKey) + if err != nil { + return err + } + ec384Chain, err := generateCA(ec384MfgKey) + if err != nil { + return err + } + if err := state.AddManufacturerKey(protocol.Rsa2048RestrKeyType, rsa2048MfgKey, rsa2048Chain); err != nil { + return err + } + if err := state.AddManufacturerKey(protocol.RsaPkcsKeyType, rsa3072MfgKey, rsa3072Chain); err != nil { + return err + } + if err := state.AddManufacturerKey(protocol.RsaPssKeyType, rsa3072MfgKey, rsa3072Chain); err != nil { + return err + } + if err := state.AddManufacturerKey(protocol.Secp256r1KeyType, ec256MfgKey, ec256Chain); err != nil { + return err + } + if err := state.AddManufacturerKey(protocol.Secp384r1KeyType, ec384MfgKey, ec384Chain); err != nil { + return err + } + + // Generate owner keys + rsa2048OwnerKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + rsa3072OwnerKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + return err + } + ec256OwnerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + ec384OwnerKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return err + } + if err := state.AddOwnerKey(protocol.Rsa2048RestrKeyType, rsa2048OwnerKey, nil); err != nil { + return err + } + if err := state.AddOwnerKey(protocol.RsaPkcsKeyType, rsa3072OwnerKey, nil); err != nil { + return err + } + if err := state.AddOwnerKey(protocol.RsaPssKeyType, rsa3072OwnerKey, nil); err != nil { + return err + } + if err := state.AddOwnerKey(protocol.Secp256r1KeyType, ec256OwnerKey, nil); err != nil { + return err + } + if err := state.AddOwnerKey(protocol.Secp384r1KeyType, ec384OwnerKey, nil); err != nil { + return err + } + return nil +} + func serveHTTP(rvInfo [][]protocol.RvInstruction, state *sqlite.DB) error { // Create FDO responder handler, err := newHandler(rvInfo, state) @@ -197,15 +293,7 @@ func serveHTTP(rvInfo [][]protocol.RvInstruction, state *sqlite.DB) error { slog.Info("Listening", "local", lis.Addr().String(), "external", extAddr) if useTLS { - cert, err := tlsCert(state.DB()) - if err != nil { - return err - } - srv.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{*cert}, - } - return srv.ServeTLS(lis, "", "") + return serveTLS(lis, srv, state.DB()) } return srv.Serve(lis) } @@ -264,7 +352,75 @@ func doImportVoucher(state *sqlite.DB) error { return state.AddVoucher(context.Background(), &ov) } -func registerRvBlob(host string, port uint16, state *sqlite.DB) error { +func to0AddrToRvInfo() ([][]protocol.RvInstruction, error) { + url, err := url.Parse(to0Addr) + if err != nil { + return nil, fmt.Errorf("cannot parse TO0 addr: %w", err) + } + prot := protocol.RVProtHTTP + if url.Scheme == "https" { + prot = protocol.RVProtHTTPS + } + rvInfo := [][]protocol.RvInstruction{{{Variable: protocol.RVProtocol, Value: mustMarshal(prot)}}} + host, portStr, err := net.SplitHostPort(url.Host) + if err != nil { + host = url.Host + } + if portStr == "" { + portStr = "80" + if url.Scheme == "https" { + portStr = "443" + } + } + if host == "" { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVIPAddress, Value: mustMarshal(net.IP{127, 0, 0, 1})}) + } else if hostIP := net.ParseIP(host); hostIP.To4() != nil || hostIP.To16() != nil { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVIPAddress, Value: mustMarshal(hostIP)}) + } else { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVDns, Value: mustMarshal(host)}) + } + portNum, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid TO0 port: %w", err) + } + port := uint16(portNum) + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVDevPort, Value: mustMarshal(port)}) + if rvBypass { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVBypass}) + } + return rvInfo, nil +} + +func extAddrToRvInfo() ([][]protocol.RvInstruction, error) { + prot := protocol.RVProtHTTP + if useTLS { + prot = protocol.RVProtHTTPS + } + rvInfo := [][]protocol.RvInstruction{{{Variable: protocol.RVProtocol, Value: mustMarshal(prot)}}} + host, portStr, err := net.SplitHostPort(extAddr) + if err != nil { + return nil, fmt.Errorf("invalid external addr: %w", err) + } + if host == "" { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVIPAddress, Value: mustMarshal(net.IP{127, 0, 0, 1})}) + } else if hostIP := net.ParseIP(host); hostIP.To4() != nil || hostIP.To16() != nil { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVIPAddress, Value: mustMarshal(hostIP)}) + } else { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVDns, Value: mustMarshal(host)}) + } + portNum, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid external port: %w", err) + } + port := uint16(portNum) + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVDevPort, Value: mustMarshal(port)}) + if rvBypass { + rvInfo[0] = append(rvInfo[0], protocol.RvInstruction{Variable: protocol.RVBypass}) + } + return rvInfo, nil +} + +func registerRvBlob(state *sqlite.DB) error { if to0Addr == "" { return fmt.Errorf("to0-guid depends on to0 flag being set") } @@ -280,11 +436,23 @@ func registerRvBlob(host string, port uint16, state *sqlite.DB) error { var guid protocol.GUID copy(guid[:], guidBytes) + // Construct TO2 addr proto := protocol.HTTPTransport if useTLS { proto = protocol.HTTPSTransport } - + host, portStr, err := net.SplitHostPort(extAddr) + if err != nil { + return fmt.Errorf("invalid external addr: %w", err) + } + if host == "" { + host = "localhost" + } + portNum, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return fmt.Errorf("invalid external port: %w", err) + } + port := uint16(portNum) to2Addrs := []protocol.RvTO2Addr{ { DNSAddress: &host, @@ -292,6 +460,8 @@ func registerRvBlob(host string, port uint16, state *sqlite.DB) error { TransportProtocol: proto, }, } + + // Register RV blob with RV server refresh, err := (&fdo.TO0Client{ Vouchers: state, OwnerKeys: state, @@ -359,109 +529,7 @@ func mustMarshal(v any) []byte { return data } -//nolint:gocyclo func newHandler(rvInfo [][]protocol.RvInstruction, state *sqlite.DB) (*transport.Handler, error) { - // Generate manufacturing component keys - rsa2048MfgKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - rsa3072MfgKey, err := rsa.GenerateKey(rand.Reader, 3072) - if err != nil { - return nil, err - } - ec256MfgKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, err - } - ec384MfgKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - return nil, err - } - generateCA := func(key crypto.Signer) ([]*x509.Certificate, error) { - template := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "Test CA"}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(30 * 365 * 24 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - der, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) - if err != nil { - return nil, err - } - cert, err := x509.ParseCertificate(der) - if err != nil { - return nil, err - } - return []*x509.Certificate{cert}, nil - } - rsa2048Chain, err := generateCA(rsa2048MfgKey) - if err != nil { - return nil, err - } - rsa3072Chain, err := generateCA(rsa3072MfgKey) - if err != nil { - return nil, err - } - ec256Chain, err := generateCA(ec256MfgKey) - if err != nil { - return nil, err - } - ec384Chain, err := generateCA(ec384MfgKey) - if err != nil { - return nil, err - } - if err := state.AddManufacturerKey(protocol.Rsa2048RestrKeyType, rsa2048MfgKey, rsa2048Chain); err != nil { - return nil, err - } - if err := state.AddManufacturerKey(protocol.RsaPkcsKeyType, rsa3072MfgKey, rsa3072Chain); err != nil { - return nil, err - } - if err := state.AddManufacturerKey(protocol.RsaPssKeyType, rsa3072MfgKey, rsa3072Chain); err != nil { - return nil, err - } - if err := state.AddManufacturerKey(protocol.Secp256r1KeyType, ec256MfgKey, ec256Chain); err != nil { - return nil, err - } - if err := state.AddManufacturerKey(protocol.Secp384r1KeyType, ec384MfgKey, ec384Chain); err != nil { - return nil, err - } - - // Generate owner keys - rsa2048OwnerKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - rsa3072OwnerKey, err := rsa.GenerateKey(rand.Reader, 3072) - if err != nil { - return nil, err - } - ec256OwnerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, err - } - ec384OwnerKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - return nil, err - } - if err := state.AddOwnerKey(protocol.Rsa2048RestrKeyType, rsa2048OwnerKey, nil); err != nil { - return nil, err - } - if err := state.AddOwnerKey(protocol.RsaPkcsKeyType, rsa3072OwnerKey, nil); err != nil { - return nil, err - } - if err := state.AddOwnerKey(protocol.RsaPssKeyType, rsa3072OwnerKey, nil); err != nil { - return nil, err - } - if err := state.AddOwnerKey(protocol.Secp256r1KeyType, ec256OwnerKey, nil); err != nil { - return nil, err - } - if err := state.AddOwnerKey(protocol.Secp384r1KeyType, ec384OwnerKey, nil); err != nil { - return nil, err - } - // Auto-register RV blob so that TO1 can be tested unless a TO0 address is // given or RV bypass is set var autoTO0 fdo.AutoTO0 @@ -526,6 +594,7 @@ func newHandler(rvInfo [][]protocol.RvInstruction, state *sqlite.DB) (*transport }, nil } +//nolint:gocyclo func ownerModules(ctx context.Context, guid protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, modules []string) iter.Seq2[string, serviceinfo.OwnerModule] { return func(yield func(string, serviceinfo.OwnerModule) bool) { if slices.Contains(modules, "fdo.download") { @@ -584,67 +653,3 @@ func ownerModules(ctx context.Context, guid protocol.GUID, info string, chain [] } } } - -func tlsCert(db *sql.DB) (*tls.Certificate, error) { - // Ensure that the https table exists - if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS https - ( cert BLOB NOT NULL - , key BLOB NOT NULL - )`); err != nil { - return nil, err - } - - // Load a TLS cert and key from the database - row := db.QueryRow("SELECT cert, key FROM https LIMIT 1") - var certDer, keyDer []byte - if err := row.Scan(&certDer, &keyDer); err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - if len(keyDer) > 0 { - key, err := x509.ParsePKCS8PrivateKey(keyDer) - if err != nil { - return nil, fmt.Errorf("bad HTTPS key stored: %w", err) - } - return &tls.Certificate{ - Certificate: [][]byte{certDer}, - PrivateKey: key, - }, nil - } - - // Generate a new self-signed TLS CA - tlsKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - return nil, err - } - caTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "Test CA"}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(30 * 365 * 24 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, tlsKey.Public(), tlsKey) - if err != nil { - return nil, err - } - tlsCA, err := x509.ParseCertificate(caDER) - if err != nil { - return nil, err - } - - // Store TLS cert and key to the database - keyDER, err := x509.MarshalPKCS8PrivateKey(tlsKey) - if err != nil { - return nil, err - } - if _, err := db.Exec("INSERT INTO https (cert, key) VALUES (?, ?)", caDER, keyDER); err != nil { - return nil, err - } - - // Use CA to serve TLS - return &tls.Certificate{ - Certificate: [][]byte{tlsCA.Raw}, - PrivateKey: tlsKey, - }, nil -} diff --git a/examples/cmd/tls.go b/examples/cmd/tls.go index 2b5da35..8250cc2 100644 --- a/examples/cmd/tls.go +++ b/examples/cmd/tls.go @@ -1,10 +1,21 @@ // SPDX-FileCopyrightText: (C) 2024 Intel Corporation // SPDX-License-Identifier: Apache 2.0 +//go:build !tinygo + package main import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "database/sql" + "errors" + "fmt" + "math/big" "net" net_http "net/http" "time" @@ -39,3 +50,79 @@ func tlsTransport(baseURL string, conf *tls.Config) fdo.Transport { }}, } } + +func serveTLS(lis net.Listener, srv *net_http.Server, db *sql.DB) error { + cert, err := tlsCert(db) + if err != nil { + return err + } + srv.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{*cert}, + } + return srv.ServeTLS(lis, "", "") +} + +func tlsCert(db *sql.DB) (*tls.Certificate, error) { + // Ensure that the https table exists + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS https + ( cert BLOB NOT NULL + , key BLOB NOT NULL + )`); err != nil { + return nil, err + } + + // Load a TLS cert and key from the database + row := db.QueryRow("SELECT cert, key FROM https LIMIT 1") + var certDer, keyDer []byte + if err := row.Scan(&certDer, &keyDer); err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + if len(keyDer) > 0 { + key, err := x509.ParsePKCS8PrivateKey(keyDer) + if err != nil { + return nil, fmt.Errorf("bad HTTPS key stored: %w", err) + } + return &tls.Certificate{ + Certificate: [][]byte{certDer}, + PrivateKey: key, + }, nil + } + + // Generate a new self-signed TLS CA + tlsKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return nil, err + } + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(30 * 365 * 24 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, tlsKey.Public(), tlsKey) + if err != nil { + return nil, err + } + tlsCA, err := x509.ParseCertificate(caDER) + if err != nil { + return nil, err + } + + // Store TLS cert and key to the database + keyDER, err := x509.MarshalPKCS8PrivateKey(tlsKey) + if err != nil { + return nil, err + } + if _, err := db.Exec("INSERT INTO https (cert, key) VALUES (?, ?)", caDER, keyDER); err != nil { + return nil, err + } + + // Use CA to serve TLS + return &tls.Certificate{ + Certificate: [][]byte{tlsCA.Raw}, + PrivateKey: tlsKey, + }, nil +} diff --git a/examples/cmd/tls_tinygo.go b/examples/cmd/tls_tinygo.go new file mode 100644 index 0000000..8a196a3 --- /dev/null +++ b/examples/cmd/tls_tinygo.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build tinygo + +package main + +import ( + "crypto/tls" + "database/sql" + "net" + net_http "net/http" + "strings" + + "github.com/fido-device-onboard/go-fdo" + "github.com/fido-device-onboard/go-fdo/http" +) + +var insecureTLS bool + +func tlsTransport(baseURL string, conf *tls.Config) fdo.Transport { + if conf != nil || strings.HasPrefix(baseURL, "https") { + panic("TLS unsupported by TinyGo") + } + return &http.Transport{ + BaseURL: baseURL, + Client: &net_http.Client{Transport: net_http.DefaultTransport}, + } +} + +func serveTLS(lis net.Listener, srv *net_http.Server, db *sql.DB) error { + panic("TLS unsupported by TinyGo") +} diff --git a/examples/go.mod b/examples/go.mod index 108a3eb..eadd5bf 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -17,6 +17,7 @@ require ( github.com/fido-device-onboard/go-fdo/tpm v0.0.0-00010101000000-000000000000 github.com/google/go-tpm v0.9.2-0.20240920144513-364d5f2f78b9 github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba + github.com/syumai/workers v0.27.0 hermannm.dev/devlog v0.4.1 ) diff --git a/examples/go.sum b/examples/go.sum index 7cad18f..592e7dc 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -47,6 +47,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/syumai/workers v0.27.0 h1:Y3J4KtlYveAaXXQYnoE/JjiQzKY0a3/O2GQDHApC55Y= +github.com/syumai/workers v0.27.0/go.mod h1:ZnqmdiHNBrbxOLrZ/HJ5jzHy6af9cmiNZk10R9NrIEA= github.com/tetratelabs/wazero v1.8.1 h1:NrcgVbWfkWvVc4UtT4LRLDf91PsOzDzefMdwhLfA550= github.com/tetratelabs/wazero v1.8.1/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= diff --git a/examples/wasm/.gitignore b/examples/wasm/.gitignore new file mode 100644 index 0000000..524856e --- /dev/null +++ b/examples/wasm/.gitignore @@ -0,0 +1,3 @@ +/build/ +node_modules/ +.wrangler/ diff --git a/examples/wasm/README.md b/examples/wasm/README.md new file mode 100644 index 0000000..69ce670 --- /dev/null +++ b/examples/wasm/README.md @@ -0,0 +1,80 @@ +# Rendzvous Service on Cloudflare Workers + +## Prerequisites + +Download TinyGo from your OS package manager and Wrangler from npm. + +## Deploy + +Create a Cloudflare Worker with the name "rv" and start a new repo with the configuration in `wrangler.toml`. + +Create a D1 database instance with name "rv" and update `wrangler.toml` with its UUID. + +Execute the included schema.sql setup. + +```console +wrangler d1 execute rv --remote --file=./schema.sql +``` + +Then deploy the application. + +```console +wrangler deploy +``` + +## Usage + +Add users by email address. + +```console +wrangler d1 execute rv --remote --command 'INSERT INTO trusted_emails (email) VALUES ("user@example.com")' +``` + +Add owner keys, connected to user accounts for auditability. + +```console +wrangler d1 execute rv --remote --command "INSERT INTO trusted_owners (email, pkix) VALUES ('user@example.com', UNHEX('$OWNER_KEY'))" +``` + +### Example + +Get the owner key from the example application to add to RV server. + +```bash +OWNER_KEY=$(go run ./examples/cmd server -db db.test -print-owner-public SECP384R1 | head -n -1 | tail -n +2 | tr -d '\n' | base64 -d | xxd -p -c 0) +``` + +Initialize device credentials. + +```console +$ go run ./examples/cmd server -db db.test -http 127.0.0.1:9999 -to0 https://rv.${SUBDOMAIN}.workers.dev +[2024-11-01 00:00:00] INFO: Listening + local: 127.0.0.1:9999 + external: 127.0.0.1:9999 +``` + +```console +$ go run ./examples/cmd client -di http://127.0.0.1:9999 +$ go run ./examples/cmd client -print +blobcred[ + ... + GUID d21d841a3f54f4e89a60ed9b9779e9e8 + ... +] +$ go run ./examples/cmd client -rv-only +``` + +Register RV blob with RV server. + +```console +$ go run ./examples/cmd server -db db.test -http 127.0.0.1:9999 -to0 https://rv.${SUBDOMAIN}.workers.dev -to0-guid d21d841a3f54f4e89a60ed9b9779e9e8 +[2024-11-01 00:00:00] INFO: RV blob registered + ttl: 168h0m0s +``` + +Transfer ownership using the Cloudflare RV service and local owner service. + +```console +$ go run ./examples/cmd client +Success +``` diff --git a/examples/wasm/main.go b/examples/wasm/main.go new file mode 100644 index 0000000..a0f98a4 --- /dev/null +++ b/examples/wasm/main.go @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build tinygo + +// Package main implements a Rendezvous Server which can be compiled with +// TinyGo and run on Cloudflare Workers within the free tier (under reasonable +// load). +package main + +import ( + "context" + "crypto/x509" + "database/sql" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + "time" + + "github.com/syumai/workers" + "github.com/syumai/workers/cloudflare/cron" + _ "github.com/syumai/workers/cloudflare/d1" + + "github.com/fido-device-onboard/go-fdo" + fdo_http "github.com/fido-device-onboard/go-fdo/http" + "github.com/fido-device-onboard/go-fdo/sqlite" +) + +const oneWeekInSeconds uint32 = 7 * 24 * 60 * 60 + +func main() { + slog.SetLogLoggerLevel(slog.LevelDebug) + + db, err := sql.Open("d1", "RendezvousDB") + if err != nil { + slog.Error("d1 connect", "error", err) + os.Exit(1) + } + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + slog.Error("d1 pragma", "error", err) + os.Exit(1) + } + + // Handle FDO protocol endpoint + state := sqlite.New(db) + handler := http.NewServeMux() + + // If building with Go instead of TinyGo, use: + //handler.Handle("POST /fdo/101/msg/{msg}", &fdo_http.Handler{ + handler.Handle("/fdo/101/msg/", &fdo_http.Handler{ + Tokens: state, + TO0Responder: &fdo.TO0Server{ + Session: state, + RVBlobs: state, + AcceptVoucher: func(ctx context.Context, ov fdo.Voucher) (accept bool, err error) { + owner, err := ov.OwnerPublicKey() + if err != nil { + return false, fmt.Errorf("error getting voucher owner key: %w", err) + } + der, err := x509.MarshalPKIXPublicKey(owner) + if err != nil { + return false, fmt.Errorf("error marshaling voucher owner key: %w", err) + } + return trustedOwner(ctx, db, der) + }, + NegotiateTTL: func(requestedSeconds uint32, ov fdo.Voucher) (waitSeconds uint32) { + return min(requestedSeconds, oneWeekInSeconds) + }, + }, + TO1Responder: &fdo.TO1Server{ + Session: state, + RVBlobs: state, + }, + }) + + // Schedule a daily task to cleanup expired RV blobs + cron.ScheduleTaskNonBlock(func(ctx context.Context) error { + e, err := cron.NewEvent(ctx) + if err != nil { + return err + } + return removeExpiredBlobs(ctx, db, e.ScheduledTime) + }) + + workers.Serve(handler) +} + +func trustedOwner(ctx context.Context, db *sql.DB, pkixKey []byte) (bool, error) { + var email string + row := db.QueryRowContext(ctx, `SELECT email FROM trusted_owners WHERE pkix = ?`, pkixKey) + if err := row.Scan(&email); errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil { + return false, err + } + slog.Info("accepting voucher", "user", email) + + return true, nil +} + +func removeExpiredBlobs(ctx context.Context, db *sql.DB, nowish time.Time) error { + _, err := db.ExecContext(ctx, `DELETE FROM rv_blobs WHERE exp < ?`, nowish.Unix()) + return err +} diff --git a/examples/wasm/schema.sql b/examples/wasm/schema.sql new file mode 100644 index 0000000..fdbc639 --- /dev/null +++ b/examples/wasm/schema.sql @@ -0,0 +1,35 @@ +PRAGMA foreign_keys = ON; +CREATE TABLE IF NOT EXISTS secrets + ( type TEXT NOT NULL + , secret BLOB NOT NULL + ); +CREATE TABLE IF NOT EXISTS sessions + ( id BLOB PRIMARY KEY + , protocol INTEGER NOT NULL + ); +CREATE TABLE IF NOT EXISTS to0_sessions + ( session BLOB UNIQUE NOT NULL + , nonce BLOB + , FOREIGN KEY(session) REFERENCES sessions(id) ON DELETE CASCADE + ); +CREATE TABLE IF NOT EXISTS to1_sessions + ( session BLOB UNIQUE NOT NULL + , nonce BLOB + , alg INTEGER + , FOREIGN KEY(session) REFERENCES sessions(id) ON DELETE CASCADE + ); +CREATE TABLE IF NOT EXISTS rv_blobs + ( guid BLOB PRIMARY KEY + , rv BLOB NOT NULL + , voucher BLOB NOT NULL + , exp INTEGER NOT NULL + ); +CREATE INDEX IF NOT EXISTS rv_blob_exp ON rv_blobs(exp ASC); +CREATE TABLE IF NOT EXISTS trusted_emails + ( email TEXT PRIMARY KEY + ); +CREATE TABLE IF NOT EXISTS trusted_owners + ( pkix BLOB PRIMARY KEY + , email TEXT NOT NULL + , FOREIGN KEY(email) REFERENCES trusted_emails(email) ON DELETE CASCADE + ); diff --git a/examples/wasm/wrangler.toml b/examples/wasm/wrangler.toml new file mode 100644 index 0000000..be224d6 --- /dev/null +++ b/examples/wasm/wrangler.toml @@ -0,0 +1,29 @@ +name = "rv" +main = "./build/worker.mjs" +compatibility_date = "2024-10-22" + +[build] +command = """ +go run github.com/syumai/workers/cmd/workers-assets-gen@v0.26.3 -mode tinygo && +tinygo build -o ./build/app.wasm -target wasm -gc=leaking -no-debug -opt=2 ./main.go +""" + +# For testing locally, you can use Go instead of TinyGo for better debug support, such as +# backtraces for panics. +# +# [build] +# command = """ +# go run github.com/syumai/workers/cmd/workers-assets-gen@v0.26.3 -mode go && +# GOOS=js GOARCH=wasm go build -ldflags '-s -w' -o ./build/app.wasm ./main.go +# """ + +[triggers] +crons = [ "0 3 * * *"] # 3AM daily + +[observability] +enabled = true + +[[ d1_databases ]] +binding = "RendezvousDB" +database_name = "rv" +database_id = "COPY_YOUR_UUID_HERE" diff --git a/fdotest/client.go b/fdotest/client.go index 525f47c..8aa4a43 100644 --- a/fdotest/client.go +++ b/fdotest/client.go @@ -56,9 +56,15 @@ type Config struct { // the device credential. Otherwise the blob package will be used. NewCredential func(protocol.KeyType) (hmacSha256, hmacSha384 hash.Hash, key crypto.Signer, toDeviceCred func(fdo.DeviceCredential) any) + // If NewTransport is non-nil, then it will be used in place of + // fdo.Transport. + NewTransport func(t *testing.T, tokens protocol.TokenService, di, to0, to1, to2 protocol.Responder) fdo.Transport + // Use the Credential Reuse Protocol Reuse bool + NoDebug bool + DeviceModules map[string]serviceinfo.DeviceModule OwnerModules OwnerModulesFunc @@ -70,7 +76,11 @@ type Config struct { // //nolint:gocyclo func RunClientTestSuite(t *testing.T, conf Config) { - slog.SetDefault(slog.New(slog.NewTextHandler(TestingLog(t), &slog.HandlerOptions{Level: slog.LevelDebug}))) + level := slog.LevelDebug + if conf.NoDebug { + level = slog.LevelInfo + } + slog.SetDefault(slog.New(slog.NewTextHandler(TestingLog(t), &slog.HandlerOptions{Level: level}))) if conf.State == nil { stateless, err := token.NewService() @@ -89,91 +99,99 @@ func RunClientTestSuite(t *testing.T, conf Config) { }{stateless, inMemory} } - transport := &Transport{ - Tokens: conf.State, - DIResponder: &fdo.DIServer[custom.DeviceMfgInfo]{ - Session: conf.State, - Vouchers: conf.State, - SignDeviceCertificate: func(info *custom.DeviceMfgInfo) ([]*x509.Certificate, error) { - // Validate device info - csr := x509.CertificateRequest(info.CertInfo) - if err := csr.CheckSignature(); err != nil { - return nil, fmt.Errorf("invalid CSR: %w", err) - } + diResponder := &fdo.DIServer[custom.DeviceMfgInfo]{ + Session: conf.State, + Vouchers: conf.State, + SignDeviceCertificate: func(info *custom.DeviceMfgInfo) ([]*x509.Certificate, error) { + // Validate device info + csr := x509.CertificateRequest(info.CertInfo) + if err := csr.CheckSignature(); err != nil { + return nil, fmt.Errorf("invalid CSR: %w", err) + } - // Sign CSR - key, chain, err := conf.State.ManufacturerKey(info.KeyType) - if err != nil { - var unsupportedErr fdo.ErrUnsupportedKeyType - if errors.As(err, &unsupportedErr) { - return nil, unsupportedErr - } - return nil, fmt.Errorf("error retrieving manufacturer key [type=%s]: %w", info.KeyType, err) + // Sign CSR + key, chain, err := conf.State.ManufacturerKey(info.KeyType) + if err != nil { + var unsupportedErr fdo.ErrUnsupportedKeyType + if errors.As(err, &unsupportedErr) { + return nil, unsupportedErr } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return nil, fmt.Errorf("error generating certificate serial number: %w", err) - } - template := &x509.Certificate{ - SerialNumber: serialNumber, - Issuer: chain[0].Subject, - Subject: csr.Subject, - NotBefore: time.Now(), - NotAfter: time.Now().Add(30 * 360 * 24 * time.Hour), // Matches Java impl - KeyUsage: x509.KeyUsageDigitalSignature, - } - der, err := x509.CreateCertificate(rand.Reader, template, chain[0], csr.PublicKey, key) - if err != nil { - return nil, fmt.Errorf("error signing CSR: %w", err) - } - cert, err := x509.ParseCertificate(der) - if err != nil { - return nil, fmt.Errorf("error parsing signed device cert: %w", err) - } - chain = append([]*x509.Certificate{cert}, chain...) - return chain, nil - }, - AutoExtend: conf.State, - RvInfo: func(context.Context, *fdo.Voucher) ([][]protocol.RvInstruction, error) { - return [][]protocol.RvInstruction{}, nil - }, + return nil, fmt.Errorf("error retrieving manufacturer key [type=%s]: %w", info.KeyType, err) + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("error generating certificate serial number: %w", err) + } + template := &x509.Certificate{ + SerialNumber: serialNumber, + Issuer: chain[0].Subject, + Subject: csr.Subject, + NotBefore: time.Now(), + NotAfter: time.Now().Add(30 * 360 * 24 * time.Hour), // Matches Java impl + KeyUsage: x509.KeyUsageDigitalSignature, + } + der, err := x509.CreateCertificate(rand.Reader, template, chain[0], csr.PublicKey, key) + if err != nil { + return nil, fmt.Errorf("error signing CSR: %w", err) + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, fmt.Errorf("error parsing signed device cert: %w", err) + } + chain = append([]*x509.Certificate{cert}, chain...) + return chain, nil }, - TO0Responder: &fdo.TO0Server{ - Session: conf.State, - RVBlobs: conf.State, + AutoExtend: conf.State, + RvInfo: func(context.Context, *fdo.Voucher) ([][]protocol.RvInstruction, error) { + return [][]protocol.RvInstruction{}, nil }, - TO1Responder: &fdo.TO1Server{ - Session: conf.State, - RVBlobs: conf.State, + } + to0Responder := &fdo.TO0Server{ + Session: conf.State, + RVBlobs: conf.State, + } + to1Responder := &fdo.TO1Server{ + Session: conf.State, + RVBlobs: conf.State, + } + to2Responder := &fdo.TO2Server{ + Session: conf.State, + Vouchers: conf.State, + OwnerKeys: conf.State, + RvInfo: func(context.Context, fdo.Voucher) ([][]protocol.RvInstruction, error) { + return [][]protocol.RvInstruction{}, nil }, - TO2Responder: &fdo.TO2Server{ - Session: conf.State, - Vouchers: conf.State, - OwnerKeys: conf.State, - RvInfo: func(context.Context, fdo.Voucher) ([][]protocol.RvInstruction, error) { - return [][]protocol.RvInstruction{}, nil - }, - OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { - if conf.OwnerModules == nil { - return func(yield func(string, serviceinfo.OwnerModule) bool) {} - } + OwnerModules: func(ctx context.Context, replacementGUID protocol.GUID, info string, chain []*x509.Certificate, devmod serviceinfo.Devmod, supportedMods []string) iter.Seq2[string, serviceinfo.OwnerModule] { + if conf.OwnerModules == nil { + return func(yield func(string, serviceinfo.OwnerModule) bool) {} + } - mods := conf.OwnerModules(ctx, replacementGUID, info, chain, devmod, supportedMods) - return func(yield func(string, serviceinfo.OwnerModule) bool) { - for modName, mod := range mods { - if slices.Contains(supportedMods, modName) { - if !yield(modName, mod) { - return - } + mods := conf.OwnerModules(ctx, replacementGUID, info, chain, devmod, supportedMods) + return func(yield func(string, serviceinfo.OwnerModule) bool) { + for modName, mod := range mods { + if slices.Contains(supportedMods, modName) { + if !yield(modName, mod) { + return } } } - }, - ReuseCredential: func(context.Context, fdo.Voucher) bool { return conf.Reuse }, - VerifyVoucher: func(context.Context, fdo.Voucher) error { return nil }, + } }, - T: t, + ReuseCredential: func(context.Context, fdo.Voucher) bool { return conf.Reuse }, + VerifyVoucher: func(context.Context, fdo.Voucher) error { return nil }, + } + + var transport fdo.Transport = &Transport{ + Tokens: conf.State, + DIResponder: diResponder, + TO0Responder: to0Responder, + TO1Responder: to1Responder, + TO2Responder: to2Responder, + T: t, + } + if conf.NewTransport != nil { + transport = conf.NewTransport(t, conf.State, diResponder, to0Responder, to1Responder, to2Responder) } to0 := &fdo.TO0Client{ @@ -213,7 +231,7 @@ func RunClientTestSuite(t *testing.T, conf Config) { }, } { t.Run(fmt.Sprintf("Key %q Encoding %q Exchange %q Cipher %q", table.keyType, table.keyEncoding, table.keyExchange, table.cipherSuite), func(t *testing.T) { - transport.DIResponder.DeviceInfo = func(context.Context, *custom.DeviceMfgInfo, []*x509.Certificate) (string, protocol.KeyType, protocol.KeyEncoding, error) { + diResponder.DeviceInfo = func(context.Context, *custom.DeviceMfgInfo, []*x509.Certificate) (string, protocol.KeyType, protocol.KeyEncoding, error) { return "test_device", table.keyType, table.keyEncoding, nil } diff --git a/http/debug.go b/http/debug.go index f1e6b79..1227a7a 100644 --- a/http/debug.go +++ b/http/debug.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "log/slog" + "github.com/fido-device-onboard/go-fdo/cbor" "github.com/fido-device-onboard/go-fdo/cbor/cdn" ) @@ -22,3 +23,18 @@ func tryDebugNotation(b []byte) string { } return d } + +func debugUnencryptedMessage(msgType uint8, msg any) { + if debugEnabled() { + return + } + body, _ := cbor.Marshal(msg) + slog.Debug("unencrypted request", "msg", msgType, "body", tryDebugNotation(body)) +} + +func debugDecryptedMessage(msgType uint8, decrypted []byte) { + if debugEnabled() { + return + } + slog.Debug("decrypted response", "msg", msgType, "body", tryDebugNotation(decrypted)) +} diff --git a/http/handler.go b/http/handler.go index 067aeba..3e01f3b 100644 --- a/http/handler.go +++ b/http/handler.go @@ -12,8 +12,6 @@ import ( "io" "log/slog" "net/http" - "net/http/httptest" - "net/http/httputil" "strconv" "strings" "time" @@ -41,13 +39,15 @@ type Handler struct { } func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.Tokens == nil { + panic("token service not set") + } + // Parse message type from request URL - typ, err := strconv.ParseUint(r.PathValue("msg"), 10, 8) - if err != nil { - writeErr(w, 0, fmt.Errorf("invalid message type")) + msgType, ok := msgTypeFromPath(w, r) + if !ok { return } - msgType := uint8(typ) proto := protocol.Of(msgType) // Parse request headers @@ -107,38 +107,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = h.Tokens.TokenContext(ctx, initToken) } - if debugEnabled() { - h.debugRequest(ctx, w, r, msgType, resp) - return - } - h.handleRequest(ctx, w, r, msgType, resp) -} - -func (h Handler) debugRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, msgType uint8, resp protocol.Responder) { - // Dump request - debugReq, _ := httputil.DumpRequest(r, false) - var saveBody bytes.Buffer - if _, err := saveBody.ReadFrom(r.Body); err == nil { - r.Body = io.NopCloser(&saveBody) - } - slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), - "body", tryDebugNotation(saveBody.Bytes())) - - // Dump response - rr := httptest.NewRecorder() - h.handleRequest(ctx, rr, r, msgType, resp) - debugResp, _ := httputil.DumpResponse(rr.Result(), false) - slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), - "body", tryDebugNotation(rr.Body.Bytes())) - - // Copy recorded response into response writer - for key, values := range rr.Header() { - for _, value := range values { - w.Header().Add(key, value) - } - } - w.WriteHeader(rr.Code) - _, _ = w.Write(rr.Body.Bytes()) + debugRequest(w, r, func(w http.ResponseWriter, r *http.Request) { + h.handleRequest(ctx, w, r, msgType, resp) + }) } func (h Handler) handleRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, msgType uint8, resp protocol.Responder) { diff --git a/http/http_go_test.go b/http/http_go_test.go new file mode 100644 index 0000000..910a883 --- /dev/null +++ b/http/http_go_test.go @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build !tinygo + +package http_test + +import "net/http" + +func setPathValue(req *http.Request, name, value string) { req.SetPathValue(name, value) } diff --git a/http/http_test.go b/http/http_test.go new file mode 100644 index 0000000..915fe29 --- /dev/null +++ b/http/http_test.go @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +package http_test + +import ( + "net/http" + "path" + "testing" + + "github.com/fido-device-onboard/go-fdo" + "github.com/fido-device-onboard/go-fdo/fdotest" + fdo_http "github.com/fido-device-onboard/go-fdo/http" + "github.com/fido-device-onboard/go-fdo/http/internal/httputil" + "github.com/fido-device-onboard/go-fdo/protocol" +) + +func TestClient(t *testing.T) { + newTransport := func(t *testing.T, tokens protocol.TokenService, di, to0, to1, to2 protocol.Responder) fdo.Transport { + return &fdo_http.Transport{ + BaseURL: "http://example.com", + Client: &http.Client{Transport: &transport{ + T: t, + Handler: &fdo_http.Handler{ + Tokens: tokens, + DIResponder: di, + TO0Responder: to0, + TO1Responder: to1, + TO2Responder: to2, + }, + }}, + } + } + + t.Run("Without Debug", func(t *testing.T) { + fdotest.RunClientTestSuite(t, fdotest.Config{ + NoDebug: true, + NewTransport: newTransport, + }) + }) + + t.Run("With Debug", func(t *testing.T) { + fdotest.RunClientTestSuite(t, fdotest.Config{NewTransport: newTransport}) + }) +} + +type transport struct { + T *testing.T + Handler http.Handler +} + +// Assume request is well-formed and ignore timeouts, retries, etc. +func (tr *transport) RoundTrip(req *http.Request) (*http.Response, error) { + setPathValue(req, "msg", path.Base(req.URL.Path)) + rr := new(httputil.ResponseRecorder) + tr.Handler.ServeHTTP(rr, req) + resp := rr.Result() + resp.Request = req + return resp, nil +} diff --git a/http/http_tinygo_test.go b/http/http_tinygo_test.go new file mode 100644 index 0000000..17f6a7f --- /dev/null +++ b/http/http_tinygo_test.go @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build tinygo + +package http_test + +import "net/http" + +func setPathValue(req *http.Request, name, value string) {} diff --git a/http/internal/httputil/httputil_tinygo.go b/http/internal/httputil/httputil_tinygo.go new file mode 100644 index 0000000..4359633 --- /dev/null +++ b/http/internal/httputil/httputil_tinygo.go @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +// Package httputil implements APIs misssing from the TinyGo stdlib. +package httputil + +import ( + "bytes" + "fmt" + "io" + "math" + "net/http" + "strconv" + "strings" +) + +// ResponseRecorder implements a simplfied version of the same type in the Go +// stdlib. +type ResponseRecorder struct { + body *bytes.Buffer + code int + + header http.Header + headerAtFirstWrite http.Header + wroteHeader bool + + result *http.Response +} + +// Header implements http.ResponseWriter. +func (rr *ResponseRecorder) Header() http.Header { + if rr.header == nil { + rr.header = make(http.Header) + } + return rr.header +} + +// Write implements http.ResponseWriter. +func (rr *ResponseRecorder) Write(p []byte) (int, error) { + if !rr.wroteHeader { + m := rr.Header() + if _, hasType := m["Content-Type"]; !hasType && m.Get("Transfer-Encoding") == "" { + m.Set("Content-Type", http.DetectContentType(p)) + } + rr.WriteHeader(200) + } + if rr.body == nil { + rr.body = bytes.NewBuffer(p) + return len(p), nil + } + return rr.body.Write(p) +} + +// WriteHeader implements http.ResponseWriter. +func (rr *ResponseRecorder) WriteHeader(statusCode int) { + if rr.wroteHeader { + return + } + + rr.code = statusCode + rr.wroteHeader = true + rr.headerAtFirstWrite = rr.Header().Clone() +} + +// Result returns the recorded response. +func (rr *ResponseRecorder) Result() *http.Response { + if rr.result != nil { + return rr.result + } + if rr.code == 0 { + rr.code = 200 + } + if rr.headerAtFirstWrite == nil { + rr.headerAtFirstWrite = rr.Header().Clone() + } + if rr.body == nil { + rr.body = bytes.NewBuffer([]byte{}) + } + + res := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: rr.code, + Header: rr.headerAtFirstWrite, + } + if res.StatusCode == 0 { + res.StatusCode = 200 + } + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) + res.Body = io.NopCloser(bytes.NewReader(rr.body.Bytes())) + res.ContentLength = func(length string) int64 { + n, err := strconv.ParseUint(strings.TrimSpace(length), 10, 63) + if err != nil { + return -1 + } + if n > math.MaxInt64 { + panic("unreachable") + } + return int64(n) + }(res.Header.Get("Content-Length")) + // Trailers are not used in FDO + + rr.result = res + return res +} diff --git a/http/transport.go b/http/transport.go index 8ff15ab..0116439 100644 --- a/http/transport.go +++ b/http/transport.go @@ -10,9 +10,7 @@ import ( "errors" "fmt" "io" - "log/slog" "net/http" - "net/http/httputil" "net/url" "path" "strconv" @@ -51,8 +49,6 @@ type Transport struct { } // Send sends a single message and receives a single response message. -// -//nolint:gocyclo func (t *Transport) Send(ctx context.Context, msgType uint8, msg any, sess kex.Session) (respType uint8, _ io.ReadCloser, _ error) { // Initialize default values if t.Client == nil { @@ -64,10 +60,7 @@ func (t *Transport) Send(ctx context.Context, msgType uint8, msg any, sess kex.S // Encrypt if a key exchange session is provided if sess != nil { - if debugEnabled() { - body, _ := cbor.Marshal(msg) - slog.Debug("unencrypted request", "msg", msgType, "body", tryDebugNotation(body)) - } + debugUnencryptedMessage(msgType, msg) var err error msg, err = sess.Encrypt(rand.Reader, msg) if err != nil { @@ -105,24 +98,12 @@ func (t *Transport) Send(ctx context.Context, msgType uint8, msg any, sess kex.S } // Perform HTTP request - if debugEnabled() { - debugReq, _ := httputil.DumpRequestOut(req, false) - slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), - "body", tryDebugNotation(body.Bytes())) - } + debugRequestOut(req, body) resp, err := t.Client.Do(req) if err != nil { return 0, nil, fmt.Errorf("error making HTTP request for message %d: %w", msgType, err) } - if debugEnabled() { - debugResp, _ := httputil.DumpResponse(resp, false) - var saveBody bytes.Buffer - if _, err := saveBody.ReadFrom(resp.Body); err == nil { - resp.Body = io.NopCloser(&saveBody) - } - slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), - "body", tryDebugNotation(saveBody.Bytes())) - } + debugResponse(resp) return t.handleResponse(resp, sess) } @@ -149,6 +130,10 @@ func (t *Transport) handleResponse(resp *http.Response, sess kex.Session) (msgTy } msgType = uint8(typ) case http.StatusInternalServerError: + if resp.Header.Get("Content-Type") != "application/cbor" { + _ = resp.Body.Close() + return 0, nil, fmt.Errorf("%s did not include an error message body", resp.Status) + } msgType = 255 default: _ = resp.Body.Close() @@ -189,10 +174,7 @@ func (t *Transport) handleResponse(resp *http.Response, sess kex.Session) (msgTy if err != nil { return 0, nil, fmt.Errorf("error decrypting message %d: %w", msgType, err) } - - if debugEnabled() { - slog.Debug("decrypted response", "msg", msgType, "body", tryDebugNotation(decrypted)) - } + debugDecryptedMessage(msgType, decrypted) content = io.NopCloser(bytes.NewBuffer(decrypted)) } diff --git a/http/util.go b/http/util.go new file mode 100644 index 0000000..9165efe --- /dev/null +++ b/http/util.go @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build !tinygo + +package http + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "net/http/httputil" + "strconv" +) + +func msgTypeFromPath(w http.ResponseWriter, r *http.Request) (uint8, bool) { + typ, err := strconv.ParseUint(r.PathValue("msg"), 10, 8) + if err != nil { + writeErr(w, 0, fmt.Errorf("invalid message type")) + return 0, false + } + return uint8(typ), true +} + +func debugRequest(w http.ResponseWriter, r *http.Request, handler http.HandlerFunc) { + if !debugEnabled() { + handler.ServeHTTP(w, r) + return + } + + // Dump request + debugReq, _ := httputil.DumpRequest(r, false) + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(r.Body); err == nil { + r.Body = io.NopCloser(&saveBody) + } + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(saveBody.Bytes())) + + // Dump response + rr := httptest.NewRecorder() + handler(rr, r) + debugResp, _ := httputil.DumpResponse(rr.Result(), false) + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(rr.Body.Bytes())) + + // Copy recorded response into response writer + for key, values := range rr.Header() { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(rr.Code) + _, _ = w.Write(rr.Body.Bytes()) +} + +func debugRequestOut(req *http.Request, body *bytes.Buffer) { + if !debugEnabled() { + return + } + debugReq, _ := httputil.DumpRequestOut(req, false) + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(body.Bytes())) +} + +func debugResponse(resp *http.Response) { + if !debugEnabled() { + return + } + debugResp, _ := httputil.DumpResponse(resp, false) + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(resp.Body); err == nil { + resp.Body = io.NopCloser(&saveBody) + } + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(saveBody.Bytes())) +} diff --git a/http/util_tinygo.go b/http/util_tinygo.go new file mode 100644 index 0000000..f124ee7 --- /dev/null +++ b/http/util_tinygo.go @@ -0,0 +1,156 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build tinygo + +package http + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "strconv" + "strings" + + "github.com/fido-device-onboard/go-fdo/http/internal/httputil" +) + +func msgTypeFromPath(w http.ResponseWriter, r *http.Request) (uint8, bool) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return 0, false + } + path := strings.TrimPrefix(r.URL.Path, "/fdo/101/msg/") + if strings.Contains(path, "/") { + w.WriteHeader(http.StatusNotFound) + return 0, false + } + typ, err := strconv.ParseUint(path, 10, 8) + if err != nil { + writeErr(w, 0, fmt.Errorf("invalid message type")) + return 0, false + } + return uint8(typ), true +} + +func debugRequest(w http.ResponseWriter, r *http.Request, handler http.HandlerFunc) { + if !debugEnabled() { + handler(w, r) + return + } + + // Dump request + debugReq, _ := dumpRequest(r) + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(r.Body); err == nil { + r.Body = io.NopCloser(&saveBody) + } + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(saveBody.Bytes())) + + // Dump response + rr := new(httputil.ResponseRecorder) + handler(rr, r) + resp := rr.Result() + debugResp, _ := dumpResponse(resp) + respBody, _ := io.ReadAll(resp.Body) + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(respBody)) + + // Copy recorded response into response writer + for key, values := range rr.Header() { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(resp.StatusCode) + _, _ = w.Write(respBody) +} + +func debugRequestOut(req *http.Request, body *bytes.Buffer) { + if !debugEnabled() { + return + } + + // Unlike httputil.DumpRequestOut, this does not use an actual HTTP + // transport to ensure that the output has all relevant headers updated and + // canonicalized. Improvements are welcome. + debugReq, _ := dumpRequest(req) + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(body.Bytes())) +} + +func debugResponse(resp *http.Response) { + if !debugEnabled() { + return + } + + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(resp.Body); err == nil { + _ = resp.Body.Close() + resp.Body = io.NopCloser(&saveBody) + } + debugResp, _ := dumpResponse(resp) + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(saveBody.Bytes())) +} + +func dumpRequest(req *http.Request) ([]byte, error) { + var out bytes.Buffer + + fmt.Fprintf(&out, "%s %s HTTP/%d.%d\r\n", req.Method, req.RequestURI, req.ProtoMajor, req.ProtoMinor) + + absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://") + if !absRequestURI { + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&out, "Host: %s\r\n", host) + } + } + + if len(req.TransferEncoding) > 0 { + fmt.Fprintf(&out, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) + } + + if err := req.Header.WriteSubset(&out, map[string]bool{ + "Transfer-Encoding": true, + "Trailer": true, + }); err != nil { + return nil, err + } + + _, _ = io.WriteString(&out, "\r\n") + + return out.Bytes(), nil +} + +var errNoBody = fmt.Errorf("no body") + +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +func dumpResponse(resp *http.Response) ([]byte, error) { + saveBody := resp.Body + defer func() { resp.Body = saveBody }() + + var out bytes.Buffer + savecl := resp.ContentLength + if resp.ContentLength == 0 { + resp.Body = io.NopCloser(strings.NewReader("")) + } else { + resp.Body = failureToReadBody{} + } + err := resp.Write(&out) + resp.ContentLength = savecl + if err != nil && err != errNoBody { + return nil, err + } + return out.Bytes(), nil +} diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index 168a34b..51ff476 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -19,15 +19,10 @@ import ( "fmt" "io" "maps" - "path/filepath" "slices" "strings" "time" - "github.com/ncruces/go-sqlite3/driver" // Load database/sql driver - _ "github.com/ncruces/go-sqlite3/embed" // Load sqlite WASM binary - _ "github.com/ncruces/go-sqlite3/vfs/xts" // Encryption VFS - "github.com/fido-device-onboard/go-fdo" "github.com/fido-device-onboard/go-fdo/cbor" "github.com/fido-device-onboard/go-fdo/cose" @@ -44,27 +39,8 @@ type DB struct { db *sql.DB } -// Open creates or opens a SQLite database file using a single non-pooled -// connection. If a password is specified, then the xts VFS will be used -// with a text key. -func Open(filename, password string) (*DB, error) { - var query string - if password != "" { - query += fmt.Sprintf("?vfs=xts&_pragma=textkey(%q)&_pragma=temp_store(memory)", password) - } - connector, err := (&driver.SQLite{}).OpenConnector("file:" + filepath.Clean(filename) + query) - if err != nil { - return nil, fmt.Errorf("error creating sqlite connector: %w", err) - } - db := sql.OpenDB(connector) - if err := Init(db); err != nil { - return nil, err - } - return New(db), nil -} - -// New creates a DB. The expected tables must already be created and pragmas -// must already be set, including foreign_keys=ON. +// New creates a DB. The expected tables must be created and FOREIGN_KEYS must +// be enabled before the database is used for FDO server state. func New(db *sql.DB) *DB { return &DB{db: db} } // Init ensures all tables are created and pragma are set. It does not @@ -75,7 +51,6 @@ func New(db *sql.DB) *DB { return &DB{db: db} } // local file, such as Cloudflare D1. func Init(db *sql.DB) error { stmts := []string{ - `PRAGMA foreign_keys = ON`, `CREATE TABLE IF NOT EXISTS secrets ( type TEXT NOT NULL , secret BLOB NOT NULL @@ -96,6 +71,8 @@ func Init(db *sql.DB) error { , voucher BLOB NOT NULL , exp INTEGER NOT NULL )`, + `CREATE INDEX IF NOT EXISTS rv_blob_exp + ON rv_blobs(exp ASC)`, `CREATE TABLE IF NOT EXISTS sessions ( id BLOB PRIMARY KEY , protocol INTEGER NOT NULL @@ -155,6 +132,7 @@ func Init(db *sql.DB) error { , cbor BLOB NOT NULL , FOREIGN KEY(session) REFERENCES sessions(id) ON DELETE CASCADE )`, + `PRAGMA foreign_keys = ON`, } for _, sql := range stmts { if _, err := db.Exec(sql); err != nil { diff --git a/sqlite/wasm.go b/sqlite/wasm.go new file mode 100644 index 0000000..957a7d5 --- /dev/null +++ b/sqlite/wasm.go @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build !tinygo + +// Open is not implemented for tinygo, because it requires embedding a WASM +// runtime in the binary. + +package sqlite + +import ( + "database/sql" + "fmt" + "path/filepath" + + "github.com/ncruces/go-sqlite3/driver" // Load database/sql driver + _ "github.com/ncruces/go-sqlite3/embed" // Load sqlite WASM binary + _ "github.com/ncruces/go-sqlite3/vfs/xts" // Encryption VFS +) + +// Open creates or opens a SQLite database file using a single non-pooled +// connection. If a password is specified, then the xts VFS will be used +// with a text key. +func Open(filename, password string) (*DB, error) { + query := "?_pragma=foreign_keys(on)" + if password != "" { + query += fmt.Sprintf("&vfs=xts&_pragma=textkey(%q)&_pragma=temp_store(memory)", password) + } + connector, err := (&driver.SQLite{}).OpenConnector("file:" + filepath.Clean(filename) + query) + if err != nil { + return nil, fmt.Errorf("error creating sqlite connector: %w", err) + } + db := sql.OpenDB(connector) + if err := Init(db); err != nil { + return nil, err + } + return New(db), nil +} diff --git a/to0.go b/to0.go index d8fab00..06e2262 100644 --- a/to0.go +++ b/to0.go @@ -65,7 +65,7 @@ func (c *TO0Client) hello(ctx context.Context, transport Transport) (protocol.No // Make request typ, resp, err := transport.Send(ctx, protocol.TO0HelloMsgType, msg, nil) if err != nil { - return protocol.Nonce{}, fmt.Errorf("error sending TO0.Hello: %w", err) + return protocol.Nonce{}, fmt.Errorf("TO0.Hello: %w", err) } defer func() { _ = resp.Close() }() @@ -182,7 +182,7 @@ func (c *TO0Client) ownerSign(ctx context.Context, transport Transport, guid pro // Make request typ, resp, err := transport.Send(ctx, protocol.TO0OwnerSignMsgType, msg, nil) if err != nil { - return 0, fmt.Errorf("error sending TO0.OwnerSign: %w", err) + return 0, fmt.Errorf("TO0.OwnerSign: %w", err) } defer func() { _ = resp.Close() }() diff --git a/to1.go b/to1.go index 346a080..990d690 100644 --- a/to1.go +++ b/to1.go @@ -80,7 +80,7 @@ func helloRv(ctx context.Context, transport Transport, cred DeviceCredential, ke // Make request typ, resp, err := transport.Send(ctx, protocol.TO1HelloRVMsgType, msg, nil) if err != nil { - return protocol.Nonce{}, fmt.Errorf("error sending TO1.HelloRV: %w", err) + return protocol.Nonce{}, fmt.Errorf("TO1.HelloRV: %w", err) } defer func() { _ = resp.Close() }() @@ -157,7 +157,7 @@ func proveToRv(ctx context.Context, transport Transport, cred DeviceCredential, // Make request typ, resp, err := transport.Send(ctx, protocol.TO1ProveToRVMsgType, msg, nil) if err != nil { - return nil, fmt.Errorf("error sending TO1.ProveToRV: %w", err) + return nil, fmt.Errorf("TO1.ProveToRV: %w", err) } defer func() { _ = resp.Close() }() diff --git a/to2.go b/to2.go index eb43e48..e41bbcc 100644 --- a/to2.go +++ b/to2.go @@ -706,7 +706,7 @@ func sendNextOVEntry(ctx context.Context, transport Transport, i int) (*cose.Sig // Make request typ, resp, err := transport.Send(ctx, protocol.TO2GetOVNextEntryMsgType, msg, nil) if err != nil { - return nil, fmt.Errorf("error sending TO2.GetOVNextEntry: %w", err) + return nil, fmt.Errorf("TO2.GetOVNextEntry: %w", err) } defer func() { _ = resp.Close() }() @@ -811,7 +811,7 @@ func proveDevice(ctx context.Context, transport Transport, proveDeviceNonce prot // Make request typ, resp, err := transport.Send(ctx, protocol.TO2ProveDeviceMsgType, msg, kex.DecryptOnly{Session: sess}) if err != nil { - return protocol.Nonce{}, nil, fmt.Errorf("error sending TO2.ProveDevice: %w", err) + return protocol.Nonce{}, nil, fmt.Errorf("TO2.ProveDevice: %w", err) } defer func() { _ = resp.Close() }() @@ -1046,7 +1046,7 @@ func sendReadyServiceInfo(ctx context.Context, transport Transport, alg protocol // Make request typ, resp, err := transport.Send(ctx, protocol.TO2DeviceServiceInfoReadyMsgType, msg, sess) if err != nil { - return 0, fmt.Errorf("error sending TO2.DeviceServiceInfoReady: %w", err) + return 0, fmt.Errorf("TO2.DeviceServiceInfoReady: %w", err) } defer func() { _ = resp.Close() }() @@ -1308,7 +1308,7 @@ func sendDone(ctx context.Context, transport Transport, proveDvNonce, setupDvNon // Make request typ, resp, err := transport.Send(ctx, protocol.TO2DoneMsgType, msg, sess) if err != nil { - return fmt.Errorf("error sending TO2.Done: %w", err) + return fmt.Errorf("TO2.Done: %w", err) } defer func() { _ = resp.Close() }() @@ -1419,7 +1419,7 @@ func sendDeviceServiceInfo(ctx context.Context, transport Transport, msg deviceS // Make request typ, resp, err := transport.Send(ctx, protocol.TO2DeviceServiceInfoMsgType, msg, sess) if err != nil { - return nil, fmt.Errorf("error sending TO2.DeviceServiceInfo: %w", err) + return nil, fmt.Errorf("TO2.DeviceServiceInfo: %w", err) } defer func() { _ = resp.Close() }()