From 3d10af34e80d218ea8c8049d4231e6fead20f5cd Mon Sep 17 00:00:00 2001 From: Ryan Turner Date: Fri, 13 Dec 2024 19:50:34 -0800 Subject: [PATCH 1/3] Remove github.com/zeebo/errs dependency We don't really use this dependency for much other than to group some errors together with a common error message prefix. The same can now be accomplished with a couple custom error types and the `errors` standard library package. This package also wasn't consistently adopted throughout the project, so at this point it's probably better to just rely on the standard library functionality, since it's sufficient for the project's use cases. Signed-off-by: Ryan Turner --- cmd/spire-server/cli/bundle/common.go | 5 +- go.mod | 2 +- pkg/agent/attestor/node/node.go | 9 +- pkg/agent/endpoints/sdsv3/handler.go | 9 +- pkg/agent/endpoints/workload/handler.go | 5 +- pkg/agent/plugin/nodeattestor/k8spsat/psat.go | 6 +- pkg/agent/plugin/nodeattestor/k8ssat/sat.go | 5 +- pkg/common/bundleutil/unmarshal.go | 13 +- pkg/common/catalog/builtin.go | 3 +- pkg/common/catalog/closers.go | 11 +- pkg/common/catalog/external.go | 3 +- pkg/common/cryptoutil/keys.go | 7 +- pkg/common/jwtsvid/common.go | 5 +- pkg/common/jwtsvid/validate.go | 21 +- pkg/common/jwtutil/keyset.go | 25 +- pkg/common/plugin/aws/iid.go | 11 +- pkg/common/plugin/azure/msi.go | 27 +- pkg/common/profiling/dumpers.go | 6 +- pkg/common/util/csr.go | 3 +- pkg/server/bundle/client/client.go | 7 +- pkg/server/bundle/client/manager_test.go | 4 +- pkg/server/bundle/client/updater.go | 3 +- pkg/server/ca/manager/journal.go | 7 +- pkg/server/ca/manager/manager.go | 18 +- pkg/server/ca/manager/slot.go | 12 +- pkg/server/ca/rotator/rotator.go | 3 +- pkg/server/datastore/sqlstore/errors.go | 118 +++++++ pkg/server/datastore/sqlstore/errors_test.go | 74 +++++ pkg/server/datastore/sqlstore/migration.go | 36 +-- pkg/server/datastore/sqlstore/mysql.go | 4 +- pkg/server/datastore/sqlstore/sqlite.go | 6 +- pkg/server/datastore/sqlstore/sqlstore.go | 300 +++++++++--------- .../datastore/sqlstore/sqlstore_test.go | 215 ++++++++----- pkg/server/datastore/sqlstore/stmt_cache.go | 2 +- pkg/server/endpoints/bundle/acme_auth.go | 4 +- pkg/server/endpoints/bundle/server.go | 5 +- .../identityprovider/identityprovider.go | 3 +- support/oidc-discovery-provider/config.go | 35 +- support/oidc-discovery-provider/main.go | 8 +- support/oidc-discovery-provider/main_posix.go | 22 +- .../oidc-discovery-provider/main_windows.go | 20 +- support/oidc-discovery-provider/server_api.go | 3 +- .../oidc-discovery-provider/workload_api.go | 7 +- 43 files changed, 661 insertions(+), 431 deletions(-) create mode 100644 pkg/server/datastore/sqlstore/errors.go create mode 100644 pkg/server/datastore/sqlstore/errors_test.go diff --git a/cmd/spire-server/cli/bundle/common.go b/cmd/spire-server/cli/bundle/common.go index a9335a37b7..343f63d0be 100644 --- a/cmd/spire-server/cli/bundle/common.go +++ b/cmd/spire-server/cli/bundle/common.go @@ -17,7 +17,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/cmd/spire-server/util" - "github.com/zeebo/errs" ) const ( @@ -78,7 +77,7 @@ func printBundle(out io.Writer, bundle *types.Bundle) error { docBytes, err := b.Marshal() if err != nil { - return errs.Wrap(err) + return err } var o bytes.Buffer @@ -87,7 +86,7 @@ func printBundle(out io.Writer, bundle *types.Bundle) error { } if _, err := fmt.Fprintln(out, o.String()); err != nil { - return errs.Wrap(err) + return err } return nil diff --git a/go.mod b/go.mod index 22c17f693d..3252a2c463 100644 --- a/go.mod +++ b/go.mod @@ -78,7 +78,6 @@ require ( github.com/stretchr/testify v1.10.0 github.com/uber-go/tally/v4 v4.1.16 github.com/valyala/fastjson v1.6.4 - github.com/zeebo/errs v1.4.0 golang.org/x/crypto v0.31.0 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/net v0.32.0 @@ -290,6 +289,7 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yashtewari/glob-intersection v0.2.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect + github.com/zeebo/errs v1.4.0 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.29.0 // indirect diff --git a/pkg/agent/attestor/node/node.go b/pkg/agent/attestor/node/node.go index c7d0cdca3e..da7f024f98 100644 --- a/pkg/agent/attestor/node/node.go +++ b/pkg/agent/attestor/node/node.go @@ -28,7 +28,6 @@ import ( "github.com/spiffe/spire/pkg/common/tlspolicy" "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/pkg/common/x509util" - "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -101,7 +100,7 @@ func (a *attestor) Attest(ctx context.Context) (res *AttestationResult, err erro // This is a bizarre case where we have an SVID but were unable to // load a bundle from the cache which suggests some tampering with the // cache on disk. - return nil, errs.New("SVID loaded but no bundle in cache") + return nil, errors.New("SVID loaded but no bundle in cache") default: log.WithField(telemetry.SPIFFEID, svid[0].URIs[0].String()).Info("SVID loaded") } @@ -265,7 +264,7 @@ func (a *attestor) serverConn(ctx context.Context, bundle *spiffebundle.Bundle) if !a.c.InsecureBootstrap { // We shouldn't get here since loadBundle() should fail if the bundle // is empty, but just in case... - return nil, errs.New("no bundle and not doing insecure bootstrap") + return nil, errors.New("no bundle and not doing insecure bootstrap") } // Insecure bootstrapping. Do not verify the server chain but rather do a @@ -279,7 +278,7 @@ func (a *attestor) serverConn(ctx context.Context, bundle *spiffebundle.Bundle) if len(rawCerts) == 0 { // This is not really possible without a catastrophic bug // creeping into the TLS stack. - return errs.New("server chain is unexpectedly empty") + return errors.New("server chain is unexpectedly empty") } expectedServerID, err := idutil.ServerID(a.c.TrustDomain) @@ -292,7 +291,7 @@ func (a *attestor) serverConn(ctx context.Context, bundle *spiffebundle.Bundle) return err } if len(serverCert.URIs) != 1 || serverCert.URIs[0].String() != expectedServerID.String() { - return errs.New("expected server SPIFFE ID %q; got %q", expectedServerID, serverCert.URIs) + return fmt.Errorf("expected server SPIFFE ID %q; got %q", expectedServerID, serverCert.URIs) } return nil }, diff --git a/pkg/agent/endpoints/sdsv3/handler.go b/pkg/agent/endpoints/sdsv3/handler.go index 664e9c9f85..64188a4ae1 100644 --- a/pkg/agent/endpoints/sdsv3/handler.go +++ b/pkg/agent/endpoints/sdsv3/handler.go @@ -22,7 +22,6 @@ import ( "github.com/spiffe/spire/pkg/common/pemutil" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" @@ -99,7 +98,7 @@ func (h *Handler) StreamSecrets(stream secret_v3.SecretDiscoveryService_StreamSe }() var versionCounter int64 - var versionInfo = strconv.FormatInt(versionCounter, 10) + versionInfo := strconv.FormatInt(versionCounter, 10) var lastNonce string var lastNode *core_v3.Node var upd *cache.WorkloadUpdate @@ -150,7 +149,7 @@ func (h *Handler) StreamSecrets(stream secret_v3.SecretDiscoveryService_StreamSe // We need to send updates if the requested resource list has changed // either explicitly, or implicitly because this is the first request. - var sendUpdates = lastReq == nil || subListChanged(lastReq.ResourceNames, newReq.ResourceNames) + sendUpdates := lastReq == nil || subListChanged(lastReq.ResourceNames, newReq.ResourceNames) // save request so that all future workload updates lead to SDS updates for the last request lastReq = newReq @@ -206,7 +205,7 @@ func subListChanged(oldSubs []string, newSubs []string) (b bool) { if len(oldSubs) != len(newSubs) { return true } - var subMap = make(map[string]bool) + subMap := make(map[string]bool) for _, sub := range oldSubs { subMap[sub] = true } @@ -582,7 +581,7 @@ func nextNonce() (string, error) { b := make([]byte, 4) _, err := rand.Read(b) if err != nil { - return "", errs.Wrap(err) + return "", err } return hex.EncodeToString(b), nil } diff --git a/pkg/agent/endpoints/workload/handler.go b/pkg/agent/endpoints/workload/handler.go index 9f191a1471..68cf81087a 100644 --- a/pkg/agent/endpoints/workload/handler.go +++ b/pkg/agent/endpoints/workload/handler.go @@ -22,7 +22,6 @@ import ( "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" @@ -512,12 +511,12 @@ func keyStoreFromBundles(bundles []*spiffebundle.Bundle) (jwtsvid.KeyStore, erro func structFromValues(values map[string]any) (*structpb.Struct, error) { valuesJSON, err := json.Marshal(values) if err != nil { - return nil, errs.Wrap(err) + return nil, err } s := new(structpb.Struct) if err := protojson.Unmarshal(valuesJSON, s); err != nil { - return nil, errs.Wrap(err) + return nil, err } return s, nil diff --git a/pkg/agent/plugin/nodeattestor/k8spsat/psat.go b/pkg/agent/plugin/nodeattestor/k8spsat/psat.go index 20e33c4c84..47f95ba21b 100644 --- a/pkg/agent/plugin/nodeattestor/k8spsat/psat.go +++ b/pkg/agent/plugin/nodeattestor/k8spsat/psat.go @@ -3,6 +3,7 @@ package k8spsat import ( "context" "encoding/json" + "fmt" "os" "sync" @@ -12,7 +13,6 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" "github.com/spiffe/spire/pkg/common/pluginconf" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -145,10 +145,10 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { func loadTokenFromFile(path string) (string, error) { data, err := os.ReadFile(path) if err != nil { - return "", errs.Wrap(err) + return "", err } if len(data) == 0 { - return "", errs.New("%q is empty", path) + return "", fmt.Errorf("%q is empty", path) } return string(data), nil } diff --git a/pkg/agent/plugin/nodeattestor/k8ssat/sat.go b/pkg/agent/plugin/nodeattestor/k8ssat/sat.go index bce6fd91e6..d93d39a1d9 100644 --- a/pkg/agent/plugin/nodeattestor/k8ssat/sat.go +++ b/pkg/agent/plugin/nodeattestor/k8ssat/sat.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" "github.com/spiffe/spire/pkg/common/pluginconf" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -148,10 +147,10 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { func loadTokenFromFile(path string) (string, error) { data, err := os.ReadFile(path) if err != nil { - return "", errs.Wrap(err) + return "", err } if len(data) == 0 { - return "", errs.New("%q is empty", path) + return "", fmt.Errorf("%q is empty", path) } return string(data), nil } diff --git a/pkg/common/bundleutil/unmarshal.go b/pkg/common/bundleutil/unmarshal.go index c49fbadcb2..4173e44b3e 100644 --- a/pkg/common/bundleutil/unmarshal.go +++ b/pkg/common/bundleutil/unmarshal.go @@ -8,7 +8,6 @@ import ( "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/zeebo/errs" ) func Decode(trustDomain spiffeid.TrustDomain, r io.Reader) (*spiffebundle.Bundle, error) { @@ -22,7 +21,7 @@ func Decode(trustDomain spiffeid.TrustDomain, r io.Reader) (*spiffebundle.Bundle func Unmarshal(trustDomain spiffeid.TrustDomain, data []byte) (*spiffebundle.Bundle, error) { doc := new(bundleDoc) if err := json.Unmarshal(data, doc); err != nil { - return nil, errs.Wrap(err) + return nil, err } return unmarshal(trustDomain, doc) } @@ -35,20 +34,20 @@ func unmarshal(trustDomain spiffeid.TrustDomain, doc *bundleDoc) (*spiffebundle. switch key.Use { case x509SVIDUse: if len(key.Certificates) != 1 { - return nil, errs.New("expected a single certificate in x509-svid entry %d; got %d", i, len(key.Certificates)) + return nil, fmt.Errorf("expected a single certificate in x509-svid entry %d; got %d", i, len(key.Certificates)) } bundle.AddX509Authority(key.Certificates[0]) case jwtSVIDUse: if key.KeyID == "" { - return nil, errs.New("missing key ID in jwt-svid entry %d", i) + return nil, fmt.Errorf("missing key ID in jwt-svid entry %d", i) } if err := bundle.AddJWTAuthority(key.KeyID, key.Key); err != nil { - return nil, errs.New("failed to add jwt-svid entry %d: %v", i, err) + return nil, fmt.Errorf("failed to add jwt-svid entry %d: %v", i, err) } case "": - return nil, errs.New("missing use for key entry %d", i) + return nil, fmt.Errorf("missing use for key entry %d", i) default: - return nil, errs.New("unrecognized use %q for key entry %d", key.Use, i) + return nil, fmt.Errorf("unrecognized use %q for key entry %d", key.Use, i) } } diff --git a/pkg/common/catalog/builtin.go b/pkg/common/catalog/builtin.go index ae246e8164..919681fdba 100644 --- a/pkg/common/catalog/builtin.go +++ b/pkg/common/catalog/builtin.go @@ -11,7 +11,6 @@ import ( "github.com/spiffe/spire-plugin-sdk/pluginsdk" "github.com/spiffe/spire-plugin-sdk/private" "github.com/spiffe/spire/pkg/common/log" - "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -147,7 +146,7 @@ func startPipeServer(server *grpc.Server, log logrus.FieldLogger) (_ *pipeConn, // Dial the server conn, err := grpc.Dial("IGNORED", grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(pipeNet.DialContext)) //nolint: staticcheck // It is going to be resolved on #5152 if err != nil { - return nil, errs.Wrap(err) + return nil, err } closers = append(closers, conn) diff --git a/pkg/common/catalog/closers.go b/pkg/common/catalog/closers.go index 4e418ca905..dc7571dc16 100644 --- a/pkg/common/catalog/closers.go +++ b/pkg/common/catalog/closers.go @@ -1,10 +1,10 @@ package catalog import ( + "errors" "io" "time" - "github.com/zeebo/errs" "google.golang.org/grpc" ) @@ -12,11 +12,14 @@ type closerGroup []io.Closer func (cs closerGroup) Close() error { // Close in reverse order. - var errs errs.Group + var errs error for i := len(cs) - 1; i >= 0; i-- { - errs.Add(cs[i].Close()) + if err := cs[i].Close(); err != nil { + errs = errors.Join(errs, err) + } } - return errs.Err() + + return errs } type closerFunc func() diff --git a/pkg/common/catalog/external.go b/pkg/common/catalog/external.go index 1a65b19f53..177de77b59 100644 --- a/pkg/common/catalog/external.go +++ b/pkg/common/catalog/external.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/spire-plugin-sdk/pluginsdk" "github.com/spiffe/spire-plugin-sdk/private" "github.com/spiffe/spire/pkg/common/log" - "github.com/zeebo/errs" "google.golang.org/grpc" ) @@ -154,7 +153,7 @@ func (p *hcClientPlugin) GRPCClient(ctx context.Context, b *goplugin.GRPCBroker, // does not work yet anyway, so it is a moot point. listener, err := b.Accept(private.HostServiceProviderID) if err != nil { - return nil, errs.Wrap(err) + return nil, err } server := newHostServer(p.config.Log, p.config.Name, p.config.HostServices) diff --git a/pkg/common/cryptoutil/keys.go b/pkg/common/cryptoutil/keys.go index db73567185..fa4a1e938a 100644 --- a/pkg/common/cryptoutil/keys.go +++ b/pkg/common/cryptoutil/keys.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/go-jose/go-jose/v4" - "github.com/zeebo/errs" ) func RSAPublicKeyEqual(a, b *rsa.PublicKey) bool { @@ -58,7 +57,7 @@ func JoseAlgFromPublicKey(publicKey any) (jose.SignatureAlgorithm, error) { case *rsa.PublicKey: // Prevent the use of keys smaller than 2048 bits if publicKey.Size() < 256 { - return "", errs.New("unsupported RSA key size: %d", publicKey.Size()) + return "", fmt.Errorf("unsupported RSA key size: %d", publicKey.Size()) } alg = jose.RS256 case *ecdsa.PublicKey: @@ -69,10 +68,10 @@ func JoseAlgFromPublicKey(publicKey any) (jose.SignatureAlgorithm, error) { case 384: alg = jose.ES384 default: - return "", errs.New("unable to determine signature algorithm for EC public key size %d", params.BitSize) + return "", fmt.Errorf("unable to determine signature algorithm for EC public key size %d", params.BitSize) } default: - return "", errs.New("unable to determine signature algorithm for public key type %T", publicKey) + return "", fmt.Errorf("unable to determine signature algorithm for public key type %T", publicKey) } return alg, nil } diff --git a/pkg/common/jwtsvid/common.go b/pkg/common/jwtsvid/common.go index 6d529bedbf..b1e84e30a3 100644 --- a/pkg/common/jwtsvid/common.go +++ b/pkg/common/jwtsvid/common.go @@ -5,18 +5,17 @@ import ( "time" "github.com/go-jose/go-jose/v4/jwt" - "github.com/zeebo/errs" ) func GetTokenExpiry(token string) (time.Time, time.Time, error) { tok, err := jwt.ParseSigned(token, AllowedSignatureAlgorithms) if err != nil { - return time.Time{}, time.Time{}, errs.Wrap(err) + return time.Time{}, time.Time{}, err } claims := jwt.Claims{} if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return time.Time{}, time.Time{}, errs.Wrap(err) + return time.Time{}, time.Time{}, err } if claims.IssuedAt == nil { return time.Time{}, time.Time{}, errors.New("JWT missing iat claim") diff --git a/pkg/common/jwtsvid/validate.go b/pkg/common/jwtsvid/validate.go index dce51831d5..7ee3f16e17 100644 --- a/pkg/common/jwtsvid/validate.go +++ b/pkg/common/jwtsvid/validate.go @@ -9,7 +9,6 @@ import ( "github.com/go-jose/go-jose/v4/jwt" "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/zeebo/errs" ) type KeyStore interface { @@ -41,17 +40,17 @@ func (t *keyStore) FindPublicKey(_ context.Context, td spiffeid.TrustDomain, key func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audience []string) (spiffeid.ID, map[string]any, error) { tok, err := jwt.ParseSigned(token, AllowedSignatureAlgorithms) if err != nil { - return spiffeid.ID{}, nil, errs.New("unable to parse JWT token: %v", err) + return spiffeid.ID{}, nil, fmt.Errorf("unable to parse JWT token: %v", err) } if len(tok.Headers) != 1 { - return spiffeid.ID{}, nil, errs.New("expected a single token header; got %d", len(tok.Headers)) + return spiffeid.ID{}, nil, fmt.Errorf("expected a single token header; got %d", len(tok.Headers)) } // Obtain the key ID from the header keyID := tok.Headers[0].KeyID if keyID == "" { - return spiffeid.ID{}, nil, errs.New("token header missing key id") + return spiffeid.ID{}, nil, fmt.Errorf("token header missing key id") } // Parse out the unverified claims. We need to look up the key by the trust @@ -59,14 +58,14 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc // when creating the generic map of claims that we return to the caller. var claims jwt.Claims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return spiffeid.ID{}, nil, errs.Wrap(err) + return spiffeid.ID{}, nil, err } if claims.Subject == "" { - return spiffeid.ID{}, nil, errs.New("token missing subject claim") + return spiffeid.ID{}, nil, errors.New("token missing subject claim") } spiffeID, err := spiffeid.FromString(claims.Subject) if err != nil { - return spiffeid.ID{}, nil, errs.New("token has in invalid subject claim: %v", err) + return spiffeid.ID{}, nil, fmt.Errorf("token has in invalid subject claim: %v", err) } // Construct the trust domain id from the SPIFFE ID and look up key by ID @@ -78,7 +77,7 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc // Now obtain the generic claims map verified using the obtained key claimsMap := make(map[string]any) if err := tok.Claims(key, &claimsMap); err != nil { - return spiffeid.ID{}, nil, errs.Wrap(err) + return spiffeid.ID{}, nil, err } // Now that the signature over the claims has been verified, validate the @@ -90,11 +89,11 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc // Convert expected validation errors for pretty errors switch { case errors.Is(err, jwt.ErrExpired): - err = errs.New("token has expired") + err = errors.New("token has expired") case errors.Is(err, jwt.ErrInvalidAudience): - err = errs.New("expected audience in %q (audience=%q)", audience, claims.Audience) + err = fmt.Errorf("expected audience in %q (audience=%q)", audience, claims.Audience) default: - err = errs.Wrap(err) + err = err } return spiffeid.ID{}, nil, err } diff --git a/pkg/common/jwtutil/keyset.go b/pkg/common/jwtutil/keyset.go index a188fe7b29..56f07baf30 100644 --- a/pkg/common/jwtutil/keyset.go +++ b/pkg/common/jwtutil/keyset.go @@ -3,6 +3,8 @@ package jwtutil import ( "context" "encoding/json" + "errors" + "fmt" "io" "net/http" "net/url" @@ -12,7 +14,6 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/sirupsen/logrus" - "github.com/zeebo/errs" ) const ( @@ -34,7 +35,7 @@ type OIDCIssuer string func (c OIDCIssuer) GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { u, err := url.Parse(string(c)) if err != nil { - return nil, errs.Wrap(err) + return nil, err } u.Path = path.Join(u.Path, wellKnownOpenIDConfiguration) @@ -86,7 +87,7 @@ func (c *CachingKeySetProvider) GetKeySet(ctx context.Context) (*jose.JSONWebKey } else { logrus.WithError(err).Warn("Unable to refresh key set") if c.jwks == nil { - return nil, errs.Wrap(err) + return nil, err } } @@ -96,27 +97,27 @@ func (c *CachingKeySetProvider) GetKeySet(ctx context.Context) (*jose.JSONWebKey func DiscoverKeySetURI(ctx context.Context, configURL string) (string, error) { req, err := http.NewRequest("GET", configURL, nil) if err != nil { - return "", errs.Wrap(err) + return "", err } req = req.WithContext(ctx) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", errs.Wrap(err) + return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } config := &struct { JWKSURI string `json:"jwks_uri"` }{} if err := json.NewDecoder(resp.Body).Decode(config); err != nil { - return "", errs.New("failed to decode configuration: %v", err) + return "", fmt.Errorf("failed to decode configuration: %v", err) } if config.JWKSURI == "" { - return "", errs.New("configuration missing JWKS URI") + return "", errors.New("configuration missing JWKS URI") } return config.JWKSURI, nil @@ -125,22 +126,22 @@ func DiscoverKeySetURI(ctx context.Context, configURL string) (string, error) { func FetchKeySet(ctx context.Context, jwksURI string) (*jose.JSONWebKeySet, error) { req, err := http.NewRequest("GET", jwksURI, nil) if err != nil { - return nil, errs.Wrap(err) + return nil, err } req = req.WithContext(ctx) resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, errs.Wrap(err) + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } jwks := new(jose.JSONWebKeySet) if err := json.NewDecoder(resp.Body).Decode(jwks); err != nil { - return nil, errs.New("failed to decode key set: %v", err) + return nil, fmt.Errorf("failed to decode key set: %v", err) } return jwks, nil diff --git a/pkg/common/plugin/aws/iid.go b/pkg/common/plugin/aws/iid.go index 8b8fcea741..6da18e5c82 100644 --- a/pkg/common/plugin/aws/iid.go +++ b/pkg/common/plugin/aws/iid.go @@ -1,19 +1,12 @@ package aws -import ( - "github.com/zeebo/errs" -) +import "fmt" const ( // PluginName for AWS IID PluginName = "aws_iid" ) -var ( - IidErrorClass = errs.Class("aws-iid") - iidError = IidErrorClass -) - // IIDAttestationData AWS IID attestation data type IIDAttestationData struct { Document string `json:"document"` @@ -23,5 +16,5 @@ type IIDAttestationData struct { // AttestationStepError error with attestation func AttestationStepError(step string, cause error) error { - return iidError.New("attempted attestation but an error occurred %s: %w", step, cause) + return fmt.Errorf("aws-iid: attempted attestation but an error occurred %s: %w", step, cause) } diff --git a/pkg/common/plugin/azure/msi.go b/pkg/common/plugin/azure/msi.go index 99356cbbc3..bb5461eb2f 100644 --- a/pkg/common/plugin/azure/msi.go +++ b/pkg/common/plugin/azure/msi.go @@ -2,6 +2,8 @@ package azure import ( "encoding/json" + "errors" + "fmt" "io" "net/http" @@ -9,7 +11,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/agentpathtemplate" "github.com/spiffe/spire/pkg/common/idutil" - "github.com/zeebo/errs" ) const ( @@ -56,7 +57,7 @@ func (fn HTTPClientFunc) Do(req *http.Request) (*http.Response, error) { func FetchMSIToken(cl HTTPClient, resource string) (string, error) { req, err := http.NewRequest("GET", "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01", nil) if err != nil { - return "", errs.Wrap(err) + return "", err } req.Header.Add("Metadata", "true") @@ -66,11 +67,11 @@ func FetchMSIToken(cl HTTPClient, resource string) (string, error) { resp, err := cl.Do(req) if err != nil { - return "", errs.Wrap(err) + return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } r := struct { @@ -78,11 +79,11 @@ func FetchMSIToken(cl HTTPClient, resource string) (string, error) { }{} if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { - return "", errs.New("unable to decode response: %v", err) + return "", fmt.Errorf("unable to decode response: %v", err) } if r.AccessToken == "" { - return "", errs.New("response missing access token") + return "", fmt.Errorf("response missing access token") } return r.AccessToken, nil @@ -91,31 +92,31 @@ func FetchMSIToken(cl HTTPClient, resource string) (string, error) { func FetchInstanceMetadata(cl HTTPClient) (*InstanceMetadata, error) { req, err := http.NewRequest("GET", "http://169.254.169.254/metadata/instance?api-version=2017-08-01&format=json", nil) if err != nil { - return nil, errs.Wrap(err) + return nil, err } req.Header.Add("Metadata", "true") resp, err := cl.Do(req) if err != nil { - return nil, errs.Wrap(err) + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } metadata := new(InstanceMetadata) if err := json.NewDecoder(resp.Body).Decode(metadata); err != nil { - return nil, errs.New("unable to decode response: %v", err) + return nil, fmt.Errorf("unable to decode response: %v", err) } switch { case metadata.Compute.Name == "": - return nil, errs.New("response missing instance name") + return nil, errors.New("response missing instance name") case metadata.Compute.SubscriptionID == "": - return nil, errs.New("response missing instance subscription id") + return nil, errors.New("response missing instance subscription id") case metadata.Compute.ResourceGroupName == "": - return nil, errs.New("response missing instance resource group name") + return nil, errors.New("response missing instance resource group name") } return metadata, nil diff --git a/pkg/common/profiling/dumpers.go b/pkg/common/profiling/dumpers.go index e7fa6e5442..5e47b414fa 100644 --- a/pkg/common/profiling/dumpers.go +++ b/pkg/common/profiling/dumpers.go @@ -6,8 +6,6 @@ import ( "runtime/pprof" "runtime/trace" "strings" - - "github.com/zeebo/errs" ) const ( @@ -99,7 +97,7 @@ func (d *traceDumper) Dump(timestamp string, name string) error { d.data.Close() filename := getFilename(timestamp, d.c.Tag, name) if err := os.Rename(getTempFilename(d.c.Tag, traceProfTmpFilename), filename); err != nil { - return errs.Wrap(err) + return err } return d.Prepare() } @@ -133,7 +131,7 @@ func (d *cpuDumper) Dump(timestamp string, name string) error { d.data.Close() filename := getFilename(timestamp, d.c.Tag, name) if err := os.Rename(getTempFilename(d.c.Tag, cpuProfTmpFilename), filename); err != nil { - return errs.Wrap(err) + return err } return d.Prepare() } diff --git a/pkg/common/util/csr.go b/pkg/common/util/csr.go index 089ae61393..bdd98f7d92 100644 --- a/pkg/common/util/csr.go +++ b/pkg/common/util/csr.go @@ -7,7 +7,6 @@ import ( "net/url" "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/zeebo/errs" ) func MakeCSR(privateKey any, spiffeID spiffeid.ID) ([]byte, error) { @@ -33,7 +32,7 @@ func MakeCSRWithoutURISAN(privateKey any) ([]byte, error) { func makeCSR(privateKey any, template *x509.CertificateRequest) ([]byte, error) { csr, err := x509.CreateCertificateRequest(rand.Reader, template, privateKey) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return csr, nil } diff --git a/pkg/server/bundle/client/client.go b/pkg/server/bundle/client/client.go index 2462a0917b..3cd3d9f7fb 100644 --- a/pkg/server/bundle/client/client.go +++ b/pkg/server/bundle/client/client.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig" "github.com/spiffe/spire/pkg/common/bundleutil" "github.com/spiffe/spire/pkg/common/tlspolicy" - "github.com/zeebo/errs" ) type SPIFFEAuthConfig struct { @@ -92,15 +91,15 @@ func (c *client) FetchBundle(context.Context) (*spiffebundle.Bundle, error) { var hostnameError x509.HostnameError if errors.As(err, &hostnameError) && c.c.SPIFFEAuth == nil && len(hostnameError.Certificate.URIs) > 0 { if id, idErr := spiffeid.FromString(hostnameError.Certificate.URIs[0].String()); idErr == nil { - return nil, errs.New("failed to authenticate bundle endpoint using web authentication but the server certificate contains SPIFFE ID %q: maybe use https_spiffe instead of https_web: %v", id, err) + return nil, fmt.Errorf("failed to authenticate bundle endpoint using web authentication but the server certificate contains SPIFFE ID %q: maybe use https_spiffe instead of https_web: %v", id, err) } } - return nil, errs.New("failed to fetch bundle: %v", err) + return nil, fmt.Errorf("failed to fetch bundle: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errs.New("unexpected status %d fetching bundle: %s", resp.StatusCode, tryRead(resp.Body)) + return nil, fmt.Errorf("unexpected status %d fetching bundle: %s", resp.StatusCode, tryRead(resp.Body)) } b, err := bundleutil.Decode(c.c.TrustDomain, resp.Body) diff --git a/pkg/server/bundle/client/manager_test.go b/pkg/server/bundle/client/manager_test.go index b2a4855bfc..e883e1a520 100644 --- a/pkg/server/bundle/client/manager_test.go +++ b/pkg/server/bundle/client/manager_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "errors" + "fmt" "sync" "testing" "time" @@ -17,7 +18,6 @@ import ( "github.com/spiffe/spire/test/fakes/fakedatastore" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zeebo/errs" ) func TestManagerPeriodicBundleRefresh(t *testing.T) { @@ -278,7 +278,7 @@ func newManagerTest(t *testing.T, source TrustDomainConfigSource, localBundles, go func() { defer func() { if r := recover(); r != nil { - errCh <- errs.New("%+v", r) + errCh <- fmt.Errorf("%+v", r) } }() errCh <- test.manager.Run(ctx) diff --git a/pkg/server/bundle/client/updater.go b/pkg/server/bundle/client/updater.go index 3e906d4d62..b268570f0b 100644 --- a/pkg/server/bundle/client/updater.go +++ b/pkg/server/bundle/client/updater.go @@ -10,7 +10,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" "github.com/spiffe/spire/pkg/server/datastore" - "github.com/zeebo/errs" ) type BundleUpdaterConfig struct { @@ -141,7 +140,7 @@ func fetchBundleIfExists(ctx context.Context, ds datastore.DataStore, trustDomai // Load the current bundle and extract the root CA certificates bundle, err := ds.FetchBundle(ctx, trustDomain.IDString()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } if bundle == nil { return nil, nil diff --git a/pkg/server/ca/manager/journal.go b/pkg/server/ca/manager/journal.go index cc280e90cc..0c72789343 100644 --- a/pkg/server/ca/manager/journal.go +++ b/pkg/server/ca/manager/journal.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/spire/pkg/server/catalog" "github.com/spiffe/spire/pkg/server/datastore" "github.com/spiffe/spire/proto/private/server/journal" - "github.com/zeebo/errs" "google.golang.org/protobuf/proto" ) @@ -125,7 +124,7 @@ func (j *Journal) AppendJWTKey(ctx context.Context, slotID string, issuedAt time pkixBytes, err := x509.MarshalPKIXPublicKey(jwtKey.Signer.Public()) if err != nil { - return errs.Wrap(err) + return err } backup := j.entries.JwtKeys @@ -273,7 +272,7 @@ func (j *Journal) findCAJournal(ctx context.Context) (*datastore.CAJournal, erro func (j *Journal) save(ctx context.Context) error { entriesBytes, err := proto.Marshal(j.entries) if err != nil { - return errs.Wrap(err) + return err } caJournalID, err := j.saveInDatastore(ctx, entriesBytes) @@ -315,7 +314,7 @@ func loadJournalFromDS(ctx context.Context, config *journalConfig) (*Journal, er j.caJournalID = caJournal.ID if err := proto.Unmarshal(caJournal.Data, j.entries); err != nil { - return nil, errs.New("unable to unmarshal entries from CA journal record: %v", err) + return nil, fmt.Errorf("unable to unmarshal entries from CA journal record: %v", err) } return j, nil } diff --git a/pkg/server/ca/manager/manager.go b/pkg/server/ca/manager/manager.go index 4aa631c41c..a6c4b9cc07 100644 --- a/pkg/server/ca/manager/manager.go +++ b/pkg/server/ca/manager/manager.go @@ -28,7 +28,6 @@ import ( "github.com/spiffe/spire/pkg/server/plugin/notifier" "github.com/spiffe/spire/proto/private/server/journal" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -456,7 +455,6 @@ func (m *Manager) PruneBundle(ctx context.Context) (err error) { expiresBefore := m.c.Clock.Now().Add(-safetyThresholdBundle) changed, err := ds.PruneBundle(ctx, m.c.TrustDomain.IDString(), expiresBefore) - if err != nil { return fmt.Errorf("unable to prune bundle: %w", err) } @@ -478,7 +476,6 @@ func (m *Manager) PruneCAJournals(ctx context.Context) (err error) { expiresBefore := m.c.Clock.Now().Add(-safetyThresholdCAJournals) err = ds.PruneCAJournals(ctx, expiresBefore.Unix()) - if err != nil { return fmt.Errorf("unable to prune CA journals: %w", err) } @@ -735,17 +732,18 @@ func (m *Manager) notify(ctx context.Context, event string, advise bool, pre fun }(n) } - var allErrs errs.Group + var allErrs error for i := 0; i < len(notifiers); i++ { // don't select on the ctx here as we can rely on the plugins to // respond to context cancellation and return an error. if err := <-errsCh; err != nil { - allErrs.Add(err) + allErrs = errors.Join(allErrs, err) } } - if err := allErrs.Err(); err != nil { - return errs.New("one or more notifiers returned an error: %v", err) + if allErrs != nil { + return fmt.Errorf("one or more notifiers returned an error: %v", allErrs) } + return nil } @@ -755,7 +753,7 @@ func (m *Manager) fetchRequiredBundle(ctx context.Context) (*common.Bundle, erro return nil, err } if bundle == nil { - return nil, errs.New("trust domain bundle is missing") + return nil, errors.New("trust domain bundle is missing") } return bundle, nil } @@ -764,7 +762,7 @@ func (m *Manager) fetchOptionalBundle(ctx context.Context) (*common.Bundle, erro ds := m.c.Catalog.GetDataStore() bundle, err := ds.FetchBundle(ctx, m.c.TrustDomain.IDString()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return bundle, nil } @@ -1052,7 +1050,7 @@ func keyIDFromBytes(choices []byte) string { func publicKeyFromJWTKey(jwtKey *ca.JWTKey) (*common.PublicKey, error) { pkixBytes, err := x509.MarshalPKIXPublicKey(jwtKey.Signer.Public()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return &common.PublicKey{ diff --git a/pkg/server/ca/manager/slot.go b/pkg/server/ca/manager/slot.go index cbfff8c768..eb4ee7a232 100644 --- a/pkg/server/ca/manager/slot.go +++ b/pkg/server/ca/manager/slot.go @@ -19,7 +19,6 @@ import ( "github.com/spiffe/spire/pkg/server/catalog" "github.com/spiffe/spire/proto/private/server/journal" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -274,7 +273,6 @@ func (s *SlotLoader) getJWTKeysSlots(ctx context.Context, entries []*journal.JWT // Instead, we'll rotate into a new one. func (s *SlotLoader) filterInvalidEntries(ctx context.Context, entries *journal.Entries) ([]*journal.JWTKeyEntry, []*journal.X509CAEntry, error) { bundle, err := s.fetchOptionalBundle(ctx) - if err != nil { return nil, nil, err } @@ -314,7 +312,7 @@ func (s *SlotLoader) fetchOptionalBundle(ctx context.Context) (*common.Bundle, e ds := s.Catalog.GetDataStore() bundle, err := ds.FetchBundle(ctx, s.TrustDomain.IDString()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return bundle, nil } @@ -351,14 +349,14 @@ func (s *SlotLoader) loadX509CASlotFromEntry(ctx context.Context, entry *journal cert, err := x509.ParseCertificate(entry.Certificate) if err != nil { - return nil, "", errs.New("unable to parse CA certificate: %v", err) + return nil, "", fmt.Errorf("unable to parse CA certificate: %v", err) } var upstreamChain []*x509.Certificate for _, certDER := range entry.UpstreamChain { cert, err := x509.ParseCertificate(certDER) if err != nil { - return nil, "", errs.New("unable to parse upstream chain certificate: %v", err) + return nil, "", fmt.Errorf("unable to parse upstream chain certificate: %v", err) } upstreamChain = append(upstreamChain, cert) } @@ -421,7 +419,7 @@ func (s *SlotLoader) loadJWTKeySlotFromEntry(ctx context.Context, entry *journal publicKey, err := x509.ParsePKIXPublicKey(entry.PublicKey) if err != nil { - return nil, "", errs.Wrap(err) + return nil, "", err } signer, err := s.makeSigner(ctx, jwtKeyKmKeyID(entry.SlotId)) @@ -460,7 +458,7 @@ func (s *SlotLoader) makeSigner(ctx context.Context, keyID string) (crypto.Signe case codes.NotFound: return nil, nil default: - return nil, errs.Wrap(err) + return nil, err } } diff --git a/pkg/server/ca/rotator/rotator.go b/pkg/server/ca/rotator/rotator.go index 923a020ca7..000f10494e 100644 --- a/pkg/server/ca/rotator/rotator.go +++ b/pkg/server/ca/rotator/rotator.go @@ -11,7 +11,6 @@ import ( "github.com/spiffe/spire/pkg/common/health" "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/pkg/server/ca/manager" - "github.com/zeebo/errs" ) const ( @@ -138,7 +137,7 @@ func (r *Rotator) rotate(ctx context.Context) error { r.c.Log.WithError(jwtKeyErr).Error("Unable to rotate JWT key") } - return errs.Combine(x509CAErr, jwtKeyErr) + return errors.Join(x509CAErr, jwtKeyErr) } func (r *Rotator) rotateJWTKey(ctx context.Context) error { diff --git a/pkg/server/datastore/sqlstore/errors.go b/pkg/server/datastore/sqlstore/errors.go new file mode 100644 index 0000000000..364679f8d6 --- /dev/null +++ b/pkg/server/datastore/sqlstore/errors.go @@ -0,0 +1,118 @@ +package sqlstore + +import ( + "fmt" +) + +const ( + datastoreSQLErrorPrefix = "datastore-sql" + datastoreValidationErrorPrefix = "datastore-validation" +) + +type sqlError struct { + err error + msg string +} + +func newSQLError(fmtMsg string, args ...any) error { + return &sqlError{ + msg: fmt.Sprintf(fmtMsg, args...), + } +} + +func newWrappedSQLError(err error) error { + if err == nil { + return nil + } + + return &sqlError{ + err: err, + } +} + +func (s *sqlError) Error() string { + if s == nil { + return "" + } + + if s.err != nil { + return fmt.Sprintf("%s: %s", datastoreSQLErrorPrefix, s.err) + } + + return fmt.Sprintf("%s: %s", datastoreSQLErrorPrefix, s.msg) +} + +func (s *sqlError) Is(err error) bool { + if s == nil { + return false + } + + sErr, ok := err.(*sqlError) + if !ok { + return false + } + + return s.msg == sErr.msg +} + +func (s *sqlError) Unwrap() error { + if s == nil { + return nil + } + + return s.err +} + +type validationError struct { + err error + msg string +} + +func newValidationError(fmtMsg string, args ...any) error { + return &validationError{ + msg: fmt.Sprintf(fmtMsg, args...), + } +} + +func newWrappedValidationError(err error) error { + if err == nil { + return nil + } + + return &validationError{ + err: err, + } +} + +func (v *validationError) Error() string { + if v == nil { + return "" + } + + if v.err != nil { + return fmt.Sprintf("%s: %s", datastoreValidationErrorPrefix, v.err) + } + + return fmt.Sprintf("%s: %s", datastoreValidationErrorPrefix, v.msg) +} + +func (v *validationError) Is(err error) bool { + if v == nil { + return false + } + + vErr, ok := err.(*validationError) + if !ok { + return false + } + + return v.msg == vErr.msg +} + +func (v *validationError) Unwrap() error { + if v == nil { + return nil + } + + return v.err +} diff --git a/pkg/server/datastore/sqlstore/errors_test.go b/pkg/server/datastore/sqlstore/errors_test.go new file mode 100644 index 0000000000..6de7eb4012 --- /dev/null +++ b/pkg/server/datastore/sqlstore/errors_test.go @@ -0,0 +1,74 @@ +package sqlstore + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSQLError(t *testing.T) { + err := newSQLError("an error with two dynamic fields: %s, %d", "hello", 1) + assert.EqualError(t, err, "datastore-sql: an error with two dynamic fields: hello, 1") + + var sErr *sqlError + assert.True(t, errors.As(err, &sErr)) + + assert.True(t, errors.Is(err, &sqlError{ + msg: "an error with two dynamic fields: hello, 1", + })) +} + +func TestWrappedSQLError(t *testing.T) { + t.Run("nil error", func(t *testing.T) { + err := newWrappedSQLError(nil) + assert.NoError(t, err) + }) + + t.Run("non-nil error", func(t *testing.T) { + wrappedErr := errors.New("foo") + err := newWrappedSQLError(wrappedErr) + + assert.EqualError(t, err, "datastore-sql: foo") + + var sErr *sqlError + assert.True(t, errors.As(err, &sErr)) + + assert.True(t, errors.Is(err, &sqlError{ + err: wrappedErr, + })) + }) +} + +func TestValidationError(t *testing.T) { + err := newValidationError("an error with two dynamic fields: %s, %d", "hello", 1) + assert.EqualError(t, err, "datastore-validation: an error with two dynamic fields: hello, 1") + + var vErr *validationError + assert.True(t, errors.As(err, &vErr)) + + assert.True(t, errors.Is(err, &validationError{ + msg: "an error with two dynamic fields: hello, 1", + })) +} + +func TestWrappedValidationError(t *testing.T) { + t.Run("nil error", func(t *testing.T) { + err := newWrappedValidationError(nil) + assert.NoError(t, err) + }) + + t.Run("non-nil error", func(t *testing.T) { + wrappedErr := errors.New("bar") + err := newWrappedValidationError(wrappedErr) + + assert.EqualError(t, err, "datastore-validation: bar") + + var vErr *validationError + assert.True(t, errors.As(err, &vErr)) + + assert.True(t, errors.Is(err, &validationError{ + err: wrappedErr, + })) + }) +} diff --git a/pkg/server/datastore/sqlstore/migration.go b/pkg/server/datastore/sqlstore/migration.go index c9febb270a..0d8eece2c7 100644 --- a/pkg/server/datastore/sqlstore/migration.go +++ b/pkg/server/datastore/sqlstore/migration.go @@ -271,12 +271,12 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie // version before continuing, and fail if we're not. if codeVersion.Major > 1 { log.Error("Migration code needs updating for current release version") - return sqlError.New("current migration code not compatible with current release version") + return newSQLError("current migration code not compatible with current release version") } isNew := !db.HasTable(&Migration{}) if err := db.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if isNew { @@ -285,12 +285,12 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie // ensure migrations table exists so we can check versioning in all cases if err := db.AutoMigrate(&Migration{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } migration := new(Migration) if err := db.Assign(Migration{}).FirstOrCreate(migration).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } schemaVersion := migration.Version @@ -300,7 +300,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie dbCodeVersion, err := getDBCodeVersion(*migration) if err != nil { log.WithError(err).Error("Error getting DB code version") - return sqlError.New("error getting DB code version: %v", err) + return newSQLError("error getting DB code version: %v", err) } log = log.WithField(telemetry.VersionInfo, dbCodeVersion.String()) @@ -316,7 +316,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie } if err := db.Model(&Migration{}).Updates(newMigration).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } return nil @@ -325,7 +325,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie if disableMigration { if err = isDisabledMigrationAllowed(codeVersion, dbCodeVersion); err != nil { log.WithError(err).Error("Auto-migrate must be enabled") - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } @@ -336,7 +336,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie if schemaVersion > latestSchemaVersion { if !isCompatibleCodeVersion(codeVersion, dbCodeVersion) { log.Error("Incompatible DB schema is too new for code version, upgrade SPIRE Server") - return sqlError.New("incompatible DB schema and code version") + return newSQLError("incompatible DB schema and code version") } log.Warn("DB schema is ahead of code version, upgrading SPIRE Server is recommended") return nil @@ -350,7 +350,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie for schemaVersion < latestSchemaVersion { tx := db.Begin() if err := tx.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } schemaVersion, err = migrateVersion(tx, schemaVersion, log) if err != nil { @@ -358,7 +358,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie return err } if err := tx.Commit().Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } @@ -401,7 +401,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { log.Info("Initializing new database") tx := db.Begin() if err := tx.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } tables := []any{ @@ -421,7 +421,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { if err := tableOptionsForDialect(tx, dbType).AutoMigrate(tables...).Error; err != nil { tx.Rollback() - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Assign(Migration{ @@ -429,7 +429,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { CodeVersion: codeVersion.String(), }).FirstOrCreate(&Migration{}).Error; err != nil { tx.Rollback() - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := addFederatedRegistrationEntriesRegisteredEntryIDIndex(tx); err != nil { @@ -437,7 +437,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { } if err := tx.Commit().Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -461,11 +461,11 @@ func migrateVersion(tx *gorm.DB, currVersion int, log logrus.FieldLogger) (versi Version: nextVersion, CodeVersion: version.Version(), }).Error; err != nil { - return 0, sqlError.Wrap(err) + return 0, newWrappedSQLError(err) } if currVersion < lastMinorReleaseSchemaVersion { - return 0, sqlError.New("migrating from schema version %d requires a previous SPIRE release; please follow the upgrade strategy at doc/upgrading.md", currVersion) + return 0, newSQLError("migrating from schema version %d requires a previous SPIRE release; please follow the upgrade strategy at doc/upgrading.md", currVersion) } // Place all migrations handled by the current minor release here. This @@ -489,7 +489,7 @@ func migrateVersion(tx *gorm.DB, currVersion int, log logrus.FieldLogger) (versi // switch currVersion { //nolint: gocritic // No upgrade required yet, keeping switch for future additions default: - err = sqlError.New("no migration support for unknown schema version %d", currVersion) + err = newSQLError("no migration support for unknown schema version %d", currVersion) } if err != nil { return 0, err @@ -506,7 +506,7 @@ func addFederatedRegistrationEntriesRegisteredEntryIDIndex(tx *gorm.DB) error { // to introduce the index since there is no explicit struct to add tags to // so we have to manually create it. if err := tx.Table("federated_registration_entries").AddIndex("idx_federated_registration_entries_registered_entry_id", "registered_entry_id").Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } diff --git a/pkg/server/datastore/sqlstore/mysql.go b/pkg/server/datastore/sqlstore/mysql.go index 8e626330f1..69b93acf5f 100644 --- a/pkg/server/datastore/sqlstore/mysql.go +++ b/pkg/server/datastore/sqlstore/mysql.go @@ -169,11 +169,11 @@ func hasTLSConfig(cfg *configuration) bool { func validateMySQLConfig(cfg *configuration, isReadOnly bool) error { opts, err := mysql.ParseDSN(getConnectionString(cfg, isReadOnly)) if err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if !opts.ParseTime { - return sqlError.Wrap(errors.New("invalid mysql config: missing parseTime=true param in connection_string")) + return newWrappedSQLError(errors.New("invalid mysql config: missing parseTime=true param in connection_string")) } return nil diff --git a/pkg/server/datastore/sqlstore/sqlite.go b/pkg/server/datastore/sqlstore/sqlite.go index a3e4ff56e2..c911f2920e 100644 --- a/pkg/server/datastore/sqlstore/sqlite.go +++ b/pkg/server/datastore/sqlstore/sqlite.go @@ -55,7 +55,7 @@ func openSQLite3(connString string) (*gorm.DB, error) { } db, err := gorm.Open("sqlite3", embellished) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return db, nil } @@ -74,7 +74,7 @@ func embellishSQLite3ConnString(connectionString string) (string, error) { u, err := url.Parse(connectionString) if err != nil { - return "", sqlError.Wrap(err) + return "", newWrappedSQLError(err) } switch { @@ -88,7 +88,7 @@ func embellishSQLite3ConnString(connectionString string) (string, error) { u.Opaque, u.Path = u.Path, "" case u.Scheme != "file": // only no scheme (i.e. file path) or file scheme is supported - return "", sqlError.New("unsupported scheme %q", u.Scheme) + return "", newSQLError("unsupported scheme %q", u.Scheme) } q := u.Query() diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 76c4f3ed5c..2f631db7b9 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -30,26 +30,21 @@ import ( "github.com/spiffe/spire/pkg/server/datastore" "github.com/spiffe/spire/proto/private/server/journal" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) -var ( - sqlError = errs.Class("datastore-sql") - validationError = errs.Class("datastore-validation") - validEntryIDChars = &unicode.RangeTable{ - R16: []unicode.Range16{ - {0x002d, 0x002e, 1}, // - | . - {0x0030, 0x0039, 1}, // [0-9] - {0x0041, 0x005a, 1}, // [A-Z] - {0x005f, 0x005f, 1}, // _ - {0x0061, 0x007a, 1}, // [a-z] - }, - LatinOffset: 5, - } -) +var validEntryIDChars = &unicode.RangeTable{ + R16: []unicode.Range16{ + {0x002d, 0x002e, 1}, // - | . + {0x0030, 0x0039, 1}, // [0-9] + {0x0041, 0x005a, 1}, // [A-Z] + {0x005f, 0x005f, 1}, // _ + {0x0061, 0x007a, 1}, // [a-z] + }, + LatinOffset: 5, +} const ( PluginName = "sql" @@ -104,7 +99,7 @@ type awsConfig struct { func (a *awsConfig) validate() error { if a.Region == "" { - return sqlError.New("region must be specified") + return newSQLError("region must be specified") } return nil } @@ -288,7 +283,7 @@ func (ds *Plugin) RevokeJWTKey(ctx context.Context, trustDoaminID string, author // CreateAttestedNode stores the given attested node func (ds *Plugin) CreateAttestedNode(ctx context.Context, node *common.AttestedNode) (attestedNode *common.AttestedNode, err error) { if node == nil { - return nil, sqlError.New("invalid request: missing attested node") + return nil, newSQLError("invalid request: missing attested node") } if err = ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { @@ -801,7 +796,7 @@ func (ds *Plugin) PruneCAJournals(ctx context.Context, allAuthoritiesExpireBefor func (ds *Plugin) pruneCAJournals(tx *gorm.DB, allAuthoritiesExpireBefore int64) error { var caJournals []CAJournal if err := tx.Find(&caJournals).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } checkAuthorities: @@ -884,7 +879,7 @@ func (ds *Plugin) openConnection(config *configuration, isReadOnly bool) error { raw := db.DB() if raw == nil { - return sqlError.New("unable to get raw database object") + return newSQLError("unable to get raw database object") } if sqlDb != nil { @@ -919,15 +914,15 @@ func (ds *Plugin) openConnection(config *configuration, isReadOnly bool) error { } func (ds *Plugin) Close() error { - var errs errs.Group + var errs error if ds.db != nil { - errs.Add(ds.db.Close()) + errs = errors.Join(errs, ds.db.Close()) } if ds.roDb != nil { - errs.Add(ds.roDb.Close()) + errs = errors.Join(errs, ds.roDb.Close()) } - return errs.Err() + return errs } // withReadModifyWriteTx wraps the operation in a transaction appropriate for @@ -987,7 +982,7 @@ func (ds *Plugin) withTx(ctx context.Context, op func(tx *gorm.DB) error, readOn tx := db.BeginTx(ctx, nil) if err := tx.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := op(tx); err != nil { @@ -999,9 +994,9 @@ func (ds *Plugin) withTx(ctx context.Context, op func(tx *gorm.DB) error, readOn // rolling back makes sure that functions that are invoked with // withReadTx, and then do writes, will not pass unit tests, since the // writes won't be committed. - return sqlError.Wrap(tx.Rollback().Error) + return newWrappedSQLError(tx.Rollback().Error) } - return sqlError.Wrap(tx.Commit().Error) + return newWrappedSQLError(tx.Commit().Error) } // gormToGRPCStatus takes an error, and converts it to a GRPC error. If the @@ -1020,7 +1015,8 @@ func (ds *Plugin) gormToGRPCStatus(err error) error { } code := codes.Unknown - if validationError.Has(err) { + var vErr *validationError + if errors.As(err, &vErr) { code = codes.InvalidArgument } @@ -1050,12 +1046,12 @@ func (ds *Plugin) openDB(cfg *configuration, isReadOnly bool) (*gorm.DB, string, logger: ds.log, } default: - return nil, "", false, nil, sqlError.New("unsupported database_type: %v", cfg.databaseTypeConfig.databaseType) + return nil, "", false, nil, newSQLError("unsupported database_type: %v", cfg.databaseTypeConfig.databaseType) } db, version, supportsCTE, err := dialect.connect(cfg, isReadOnly) if err != nil { - return nil, "", false, nil, sqlError.Wrap(err) + return nil, "", false, nil, newWrappedSQLError(err) } db.SetLogger(gormLogger{ @@ -1107,7 +1103,7 @@ func createBundle(tx *gorm.DB, bundle *common.Bundle) (*common.Bundle, error) { } if err := tx.Create(model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return bundle, nil @@ -1121,16 +1117,16 @@ func updateBundle(tx *gorm.DB, newBundle *common.Bundle, mask *common.BundleMask model := &Bundle{} if err := tx.Find(model, "trust_domain = ?", newModel.TrustDomain).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } model.Data, newBundle, err = applyBundleMask(model, newBundle, mask) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Save(model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return newBundle, nil @@ -1186,7 +1182,7 @@ func setBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { } return bundle, nil } else if result.Error != nil { - return nil, sqlError.Wrap(result.Error) + return nil, newWrappedSQLError(result.Error) } bundle, err := updateBundle(tx, b, nil) @@ -1212,7 +1208,7 @@ func appendBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { } return bundle, nil } else if result.Error != nil { - return nil, sqlError.Wrap(result.Error) + return nil, newWrappedSQLError(result.Error) } // parse the bundle data and add missing elements @@ -1230,7 +1226,7 @@ func appendBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { } model.Data = newModel.Data if err := tx.Save(model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -1240,14 +1236,14 @@ func appendBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { func deleteBundle(tx *gorm.DB, trustDomainID string, mode datastore.DeleteMode) error { model := new(Bundle) if err := tx.Find(model, "trust_domain = ?", trustDomainID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } // Get a count of associated registration entries entriesAssociation := tx.Model(model).Association("FederatedEntries") entriesCount := entriesAssociation.Count() if err := entriesAssociation.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if entriesCount > 0 { @@ -1261,11 +1257,11 @@ func deleteBundle(tx *gorm.DB, trustDomainID string, mode datastore.DeleteMode) federated_registration_entries WHERE bundle_id = ?)`), model.ID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } case datastore.Dissociate: if err := entriesAssociation.Clear().Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } default: return status.Newf(codes.FailedPrecondition, "datastore-sql: cannot delete bundle; federated with %d registration entries", entriesCount).Err() @@ -1273,7 +1269,7 @@ func deleteBundle(tx *gorm.DB, trustDomainID string, mode datastore.DeleteMode) } if err := tx.Delete(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1287,7 +1283,7 @@ func fetchBundle(tx *gorm.DB, trustDomainID string) (*common.Bundle, error) { case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } bundle, err := modelToBundle(model) @@ -1304,7 +1300,7 @@ func countBundles(tx *gorm.DB) (int32, error) { var count int if err := tx.Count(&count).Error; err != nil { - return 0, sqlError.Wrap(err) + return 0, newWrappedSQLError(err) } return int32(count), nil @@ -1327,7 +1323,7 @@ func listBundles(tx *gorm.DB, req *datastore.ListBundlesRequest) (*datastore.Lis var bundles []Bundle if err := tx.Find(&bundles).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if p != nil { @@ -1546,7 +1542,7 @@ func revokeJWTKey(tx *gorm.DB, trustDomainID string, authorityID string) (*commo func getBundle(tx *gorm.DB, trustDomainID string) (*common.Bundle, error) { model := &Bundle{} if err := tx.Find(model, "trust_domain = ?", trustDomainID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } bundle, err := modelToBundle(model) @@ -1569,7 +1565,7 @@ func createAttestedNode(tx *gorm.DB, node *common.AttestedNode) (*common.Atteste } if err := tx.Create(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(model), nil @@ -1582,7 +1578,7 @@ func fetchAttestedNode(tx *gorm.DB, spiffeID string) (*common.AttestedNode, erro case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(model), nil } @@ -1590,7 +1586,7 @@ func fetchAttestedNode(tx *gorm.DB, spiffeID string) (*common.AttestedNode, erro func countAttestedNodes(tx *gorm.DB) (int32, error) { var count int if err := tx.Model(&AttestedNode{}).Count(&count).Error; err != nil { - return 0, sqlError.Wrap(err) + return 0, newWrappedSQLError(err) } return int32(count), nil @@ -1705,7 +1701,7 @@ func createAttestedNodeEvent(tx *gorm.DB, event *datastore.AttestedNodeEvent) er }, SpiffeID: event.SpiffeID, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1717,15 +1713,15 @@ func listAttestedNodeEvents(tx *gorm.DB, req *datastore.ListAttestedNodeEventsRe if req.GreaterThanEventID != 0 || req.LessThanEventID != 0 { query, id, err := buildListEventsQueryString(req.GreaterThanEventID, req.LessThanEventID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Find(&events, query.String(), id).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } else { if err := tx.Find(&events).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -1742,7 +1738,7 @@ func listAttestedNodeEvents(tx *gorm.DB, req *datastore.ListAttestedNodeEventsRe func pruneAttestedNodeEvents(tx *gorm.DB, olderThan time.Duration) error { if err := tx.Where("created_at < ?", time.Now().Add(-olderThan)).Delete(&AttestedNodeEvent{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1751,7 +1747,7 @@ func pruneAttestedNodeEvents(tx *gorm.DB, olderThan time.Duration) error { func fetchAttestedNodeEvent(db *sqlDB, eventID uint) (*datastore.AttestedNodeEvent, error) { event := AttestedNodeEvent{} if err := db.Find(&event, "id = ?", eventID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return &datastore.AttestedNodeEvent{ @@ -1766,7 +1762,7 @@ func deleteAttestedNodeEvent(tx *gorm.DB, eventID uint) error { ID: eventID, }, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1805,12 +1801,12 @@ func filterNodesBySelectorSet(nodes []*common.AttestedNode, selectors []*common. func listAttestedNodesOnce(ctx context.Context, db *sqlDB, req *datastore.ListAttestedNodesRequest) (*datastore.ListAttestedNodesResponse, error) { query, args, err := buildListAttestedNodesQuery(db.databaseType, db.supportsCTE, req) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -1842,7 +1838,7 @@ func listAttestedNodesOnce(ctx context.Context, db *sqlDB, req *datastore.ListAt pushNode(node) if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } resp := &datastore.ListAttestedNodesResponse{ @@ -1878,7 +1874,7 @@ func buildListAttestedNodesQuery(dbType string, supportsCTE bool, req *datastore } return buildListAttestedNodesQueryMySQL(req) default: - return "", nil, sqlError.New("unsupported db type: %q", dbType) + return "", nil, newSQLError("unsupported db type: %q", dbType) } } @@ -2022,7 +2018,7 @@ SELECT } builder.WriteString(query) if len(req.BySelectorMatch.Selectors) > 1 { - builder.WriteString(fmt.Sprintf(") c_%d\n", i)) + fmt.Fprintf(builder, ") c_%d\n", i) } // First subquery does not need USING(ID) if i > 0 { @@ -2041,7 +2037,7 @@ SELECT } } default: - return "", nil, errs.New("unhandled match behavior %q", req.BySelectorMatch.Match) + return "", nil, fmt.Errorf("unhandled match behavior %q", req.BySelectorMatch.Match) } // Add all selectors as arguments @@ -2206,11 +2202,11 @@ FROM attested_node_entries N builder.WriteString("\t\t\tINNER JOIN\n") builder.WriteString("\t\t\t(") builder.WriteString(query) - builder.WriteString(fmt.Sprintf(") c_%d\n", i+1)) + fmt.Fprintf(builder, ") c_%d\n", i+1) builder.WriteString("\t\t\tUSING(spiffe_id)\n") } default: - return "", nil, errs.New("unhandled match behavior %q", req.BySelectorMatch.Match) + return "", nil, fmt.Errorf("unhandled match behavior %q", req.BySelectorMatch.Match) } for _, selector := range req.BySelectorMatch.Selectors { @@ -2244,7 +2240,7 @@ FROM attested_node_entries N func updateAttestedNode(tx *gorm.DB, n *common.AttestedNode, mask *common.AttestedNodeMask) (*common.AttestedNode, error) { var model AttestedNode if err := tx.Find(&model, "spiffe_id = ?", n.SpiffeId).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if mask == nil { @@ -2268,7 +2264,7 @@ func updateAttestedNode(tx *gorm.DB, n *common.AttestedNode, mask *common.Attest updates["can_reattest"] = n.CanReattest } if err := tx.Model(&model).Updates(updates).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(model), nil @@ -2282,15 +2278,15 @@ func deleteAttestedNodeAndSelectors(tx *gorm.DB, spiffeID string) (*common.Attes // batch delete all associated node selectors if err := tx.Where("spiffe_id = ?", spiffeID).Delete(&nodeSelectorModel).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Find(&nodeModel, "spiffe_id = ?", spiffeID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Delete(&nodeModel).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(nodeModel), nil @@ -2310,11 +2306,11 @@ func setNodeSelectors(tx *gorm.DB, spiffeID string, selectors []*common.Selector // gap locks on the index. var ids []int64 if err := tx.Model(&NodeSelector{}).Where("spiffe_id = ?", spiffeID).Pluck("id", &ids).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if len(ids) > 0 { if err := tx.Where("id IN (?)", ids).Delete(&NodeSelector{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } @@ -2325,7 +2321,7 @@ func setNodeSelectors(tx *gorm.DB, spiffeID string, selectors []*common.Selector Value: selector.Value, } if err := tx.Create(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } @@ -2336,7 +2332,7 @@ func getNodeSelectors(ctx context.Context, db *sqlDB, spiffeID string) ([]*commo query := maybeRebind(db.databaseType, "SELECT type, value FROM node_resolver_map_entries WHERE spiffe_id=? ORDER BY id") rows, err := db.QueryContext(ctx, query, spiffeID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -2344,13 +2340,13 @@ func getNodeSelectors(ctx context.Context, db *sqlDB, spiffeID string) ([]*commo for rows.Next() { selector := new(common.Selector) if err := rows.Scan(&selector.Type, &selector.Value); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } selectors = append(selectors, selector) } if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return selectors, nil @@ -2361,7 +2357,7 @@ func listNodeSelectors(ctx context.Context, db *sqlDB, req *datastore.ListNodeSe query := maybeRebind(db.databaseType, rawQuery) rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -2403,7 +2399,7 @@ func listNodeSelectors(ctx context.Context, db *sqlDB, req *datastore.ListNodeSe push("", nil) if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return resp, nil @@ -2447,7 +2443,7 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com } if err := tx.Create(&newRegisteredEntry).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } federatesWith, err := makeFederatesWith(tx, entry.FederatesWith) @@ -2467,7 +2463,7 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com } if err := tx.Create(&newSelector).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -2478,7 +2474,7 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com } if err := tx.Create(&newDNS).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -2493,12 +2489,12 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com func fetchRegistrationEntry(ctx context.Context, db *sqlDB, entryID string) (*common.RegistrationEntry, error) { query, args, err := buildFetchRegistrationEntryQuery(db.databaseType, db.supportsCTE, entryID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -2518,7 +2514,7 @@ func fetchRegistrationEntry(ctx context.Context, db *sqlDB, entryID string) (*co } if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return entry, nil @@ -2540,7 +2536,7 @@ func buildFetchRegistrationEntryQuery(dbType string, supportsCTE bool, entryID s } return buildFetchRegistrationEntryQueryMySQL(entryID) default: - return "", nil, sqlError.New("unsupported db type: %q", dbType) + return "", nil, newSQLError("unsupported db type: %q", dbType) } } @@ -2857,12 +2853,12 @@ type queryContext interface { func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseType string, supportsCTE bool, req *datastore.ListRegistrationEntriesRequest) (*datastore.ListRegistrationEntriesResponse, error) { query, args, err := buildListRegistrationEntriesQuery(databaseType, supportsCTE, req) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() entries := make([]*common.RegistrationEntry, 0, calculateResultPreallocation(req.Pagination)) @@ -2898,7 +2894,7 @@ func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseT pushEntry(entry) if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } resp := &datastore.ListRegistrationEntriesResponse{ @@ -2933,7 +2929,7 @@ func buildListRegistrationEntriesQuery(dbType string, supportsCTE bool, req *dat } return buildListRegistrationEntriesQueryMySQL(req) default: - return "", nil, sqlError.New("unsupported db type: %q", dbType) + return "", nil, newSQLError("unsupported db type: %q", dbType) } } @@ -3525,7 +3521,7 @@ func appendListRegistrationEntriesFilterQuery(filterExp string, builder *strings }) } default: - return false, nil, errs.New("unhandled selectors match behavior %q", req.BySelectors.Match) + return false, nil, fmt.Errorf("unhandled selectors match behavior %q", req.BySelectors.Match) } for _, selector := range req.BySelectors.Selectors { args = append(args, selector.Type, selector.Value) @@ -3598,7 +3594,7 @@ func appendListRegistrationEntriesFilterQuery(filterExp string, builder *strings args = append(args, len(trustDomains)) default: - return false, nil, errs.New("unhandled federates with match behavior %q", req.ByFederatesWith.Match) + return false, nil, fmt.Errorf("unhandled federates with match behavior %q", req.ByFederatesWith.Match) } root.children = append(root.children, filterNode) } @@ -3689,7 +3685,7 @@ type nodeRow struct { } func scanNodeRow(rs *sql.Rows, r *nodeRow) error { - return sqlError.Wrap(rs.Scan( + return newWrappedSQLError(rs.Scan( &r.EId, &r.SpiffeID, &r.DataType, @@ -3730,7 +3726,7 @@ func fillNodeFromRow(node *common.AttestedNode, r *nodeRow) error { if r.SelectorType.Valid { if !r.SelectorValue.Valid { - return sqlError.New("expected non-nil selector.value value for attested node %s", node.SpiffeId) + return newSQLError("expected non-nil selector.value value for attested node %s", node.SpiffeId) } node.Selectors = append(node.Selectors, &common.Selector{ Type: r.SelectorType.String, @@ -3752,7 +3748,7 @@ type nodeSelectorRow struct { } func scanNodeSelectorRow(rs *sql.Rows, r *nodeSelectorRow) error { - return sqlError.Wrap(rs.Scan( + return newWrappedSQLError(rs.Scan( &r.SpiffeID, &r.Type, &r.Value, @@ -3792,7 +3788,7 @@ type entryRow struct { } func scanEntryRow(rs *sql.Rows, r *entryRow) error { - return sqlError.Wrap(rs.Scan( + return newWrappedSQLError(rs.Scan( &r.EId, &r.EntryID, &r.SpiffeID, @@ -3842,7 +3838,7 @@ func fillEntryFromRow(entry *common.RegistrationEntry, r *entryRow) error { } if r.SelectorType.Valid { if !r.SelectorValue.Valid { - return sqlError.New("expected non-nil selector.value value for entry id %s", entry.EntryId) + return newSQLError("expected non-nil selector.value value for entry id %s", entry.EntryId) } entry.Selectors = append(entry.Selectors, &common.Selector{ Type: r.SelectorType.String, @@ -3896,7 +3892,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com // Get the existing entry entry := RegisteredEntry{} if err := tx.Find(&entry, "entry_id = ?", e.EntryId).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if mask == nil || mask.StoreSvid { entry.StoreSvid = e.StoreSvid @@ -3904,7 +3900,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com if mask == nil || mask.Selectors { // Delete existing selectors - we will write new ones if err := tx.Exec("DELETE FROM selectors WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } selectors := []Selector{} @@ -3921,13 +3917,13 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com // Verify that final selectors contains the same 'type' when entry is used for store SVIDs if entry.StoreSvid && !equalSelectorTypes(entry.Selectors) { - return nil, validationError.New("invalid registration entry: selector types must be the same when store SVID is enabled") + return nil, newValidationError("invalid registration entry: selector types must be the same when store SVID is enabled") } if mask == nil || mask.DnsNames { // Delete existing DNSs - we will write new ones if err := tx.Exec("DELETE FROM dns_names WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } dnsList := []DNSName{} @@ -3970,7 +3966,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com entry.RevisionNumber++ if err := tx.Save(&entry).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if mask == nil || mask.FederatesWith { @@ -3996,7 +3992,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com func deleteRegistrationEntry(tx *gorm.DB, entryID string) (*common.RegistrationEntry, error) { entry := RegisteredEntry{} if err := tx.Find(&entry, "entry_id = ?", entryID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } registrationEntry, err := modelToEntry(tx, entry) @@ -4018,17 +4014,17 @@ func deleteRegistrationEntrySupport(tx *gorm.DB, entry RegisteredEntry) error { } if err := tx.Delete(&entry).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } // Delete existing selectors if err := tx.Exec("DELETE FROM selectors WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } // Delete existing dns_names if err := tx.Exec("DELETE FROM dns_names WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4066,7 +4062,7 @@ func createRegistrationEntryEvent(tx *gorm.DB, event *datastore.RegistrationEntr }, EntryID: event.EntryID, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4075,7 +4071,7 @@ func createRegistrationEntryEvent(tx *gorm.DB, event *datastore.RegistrationEntr func fetchRegistrationEntryEvent(db *sqlDB, eventID uint) (*datastore.RegistrationEntryEvent, error) { event := RegisteredEntryEvent{} if err := db.Find(&event, "id = ?", eventID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return &datastore.RegistrationEntryEvent{ @@ -4090,7 +4086,7 @@ func deleteRegistrationEntryEvent(tx *gorm.DB, eventID uint) error { ID: eventID, }, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4102,15 +4098,15 @@ func listRegistrationEntryEvents(tx *gorm.DB, req *datastore.ListRegistrationEnt if req.GreaterThanEventID != 0 || req.LessThanEventID != 0 { query, id, err := buildListEventsQueryString(req.GreaterThanEventID, req.LessThanEventID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Find(&events, query.String(), id).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } else { if err := tx.Find(&events).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -4127,7 +4123,7 @@ func listRegistrationEntryEvents(tx *gorm.DB, req *datastore.ListRegistrationEnt func pruneRegistrationEntryEvents(tx *gorm.DB, olderThan time.Duration) error { if err := tx.Where("created_at < ?", time.Now().Add(-olderThan)).Delete(&RegisteredEntryEvent{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4160,7 +4156,7 @@ func createJoinToken(tx *gorm.DB, token *datastore.JoinToken) error { } if err := tx.Create(&t).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4172,7 +4168,7 @@ func fetchJoinToken(tx *gorm.DB, token string) (*datastore.JoinToken, error) { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } else if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToJoinToken(model), nil @@ -4181,11 +4177,11 @@ func fetchJoinToken(tx *gorm.DB, token string) (*datastore.JoinToken, error) { func deleteJoinToken(tx *gorm.DB, token string) error { var model JoinToken if err := tx.Find(&model, "token = ?", token).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Delete(&model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4193,7 +4189,7 @@ func deleteJoinToken(tx *gorm.DB, token string) error { func pruneJoinTokens(tx *gorm.DB, expiresBefore time.Time) error { if err := tx.Where("expiry < ?", expiresBefore.Unix()).Delete(&JoinToken{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4219,7 +4215,7 @@ func createFederationRelationship(tx *gorm.DB, fr *datastore.FederationRelations } if err := tx.Create(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return fr, nil @@ -4228,10 +4224,10 @@ func createFederationRelationship(tx *gorm.DB, fr *datastore.FederationRelations func deleteFederationRelationship(tx *gorm.DB, trustDomain spiffeid.TrustDomain) error { model := new(FederatedTrustDomain) if err := tx.Find(model, "trust_domain = ?", trustDomain.Name()).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Delete(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } @@ -4243,7 +4239,7 @@ func fetchFederationRelationship(tx *gorm.DB, trustDomain spiffeid.TrustDomain) case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToFederationRelationship(tx, &model) @@ -4266,7 +4262,7 @@ func listFederationRelationships(tx *gorm.DB, req *datastore.ListFederationRelat var federationRelationships []FederatedTrustDomain if err := tx.Find(&federationRelationships).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if p != nil { @@ -4323,7 +4319,7 @@ func updateFederationRelationship(tx *gorm.DB, fr *datastore.FederationRelations } if err := tx.Save(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToFederationRelationship(tx, &model) @@ -4365,7 +4361,7 @@ func modelToFederationRelationship(tx *gorm.DB, model *FederatedTrustDomain) (*d td, err := spiffeid.TrustDomainFromString(model.TrustDomain) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } fr := &datastore.FederationRelationship{ @@ -4400,7 +4396,7 @@ func modelToFederationRelationship(tx *gorm.DB, model *FederatedTrustDomain) (*d func modelToBundle(model *Bundle) (*common.Bundle, error) { bundle := new(common.Bundle) if err := proto.Unmarshal(model.Data, bundle); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return bundle, nil @@ -4408,11 +4404,11 @@ func modelToBundle(model *Bundle) (*common.Bundle, error) { func validateRegistrationEntry(entry *common.RegistrationEntry) error { if entry == nil { - return validationError.New("invalid request: missing registered entry") + return newValidationError("invalid request: missing registered entry") } if len(entry.Selectors) == 0 { - return validationError.New("invalid registration entry: missing selector list") + return newValidationError("invalid registration entry: missing selector list") } // In case of StoreSvid is set, all entries 'must' be the same type, @@ -4423,31 +4419,31 @@ func validateRegistrationEntry(entry *common.RegistrationEntry) error { tpe := entry.Selectors[0].Type for _, t := range entry.Selectors { if tpe != t.Type { - return validationError.New("invalid registration entry: selector types must be the same when store SVID is enabled") + return newValidationError("invalid registration entry: selector types must be the same when store SVID is enabled") } } } if len(entry.EntryId) > 255 { - return validationError.New("invalid registration entry: entry ID too long") + return newValidationError("invalid registration entry: entry ID too long") } for _, e := range entry.EntryId { if !unicode.In(e, validEntryIDChars) { - return validationError.New("invalid registration entry: entry ID contains invalid characters") + return newValidationError("invalid registration entry: entry ID contains invalid characters") } } if len(entry.SpiffeId) == 0 { - return validationError.New("invalid registration entry: missing SPIFFE ID") + return newValidationError("invalid registration entry: missing SPIFFE ID") } if entry.X509SvidTtl < 0 { - return validationError.New("invalid registration entry: X509SvidTtl is not set") + return newValidationError("invalid registration entry: X509SvidTtl is not set") } if entry.JwtSvidTtl < 0 { - return validationError.New("invalid registration entry: JwtSvidTtl is not set") + return newValidationError("invalid registration entry: JwtSvidTtl is not set") } return nil @@ -4469,26 +4465,26 @@ func equalSelectorTypes(selectors []Selector) bool { func validateRegistrationEntryForUpdate(entry *common.RegistrationEntry, mask *common.RegistrationEntryMask) error { if entry == nil { - return validationError.New("invalid request: missing registered entry") + return newValidationError("invalid request: missing registered entry") } if (mask == nil || mask.Selectors) && len(entry.Selectors) == 0 { - return validationError.New("invalid registration entry: missing selector list") + return newValidationError("invalid registration entry: missing selector list") } if (mask == nil || mask.SpiffeId) && entry.SpiffeId == "" { - return validationError.New("invalid registration entry: missing SPIFFE ID") + return newValidationError("invalid registration entry: missing SPIFFE ID") } if (mask == nil || mask.X509SvidTtl) && (entry.X509SvidTtl < 0) { - return validationError.New("invalid registration entry: X509SvidTtl is not set") + return newValidationError("invalid registration entry: X509SvidTtl is not set") } if (mask == nil || mask.JwtSvidTtl) && (entry.JwtSvidTtl < 0) { - return validationError.New("invalid registration entry: JwtSvidTtl is not set") + return newValidationError("invalid registration entry: JwtSvidTtl is not set") } return nil @@ -4498,11 +4494,11 @@ func validateRegistrationEntryForUpdate(entry *common.RegistrationEntry, mask *c // performs validation, and fully parses certificates to form CACert embedded models. func bundleToModel(pb *common.Bundle) (*Bundle, error) { if pb == nil { - return nil, sqlError.New("missing bundle in request") + return nil, newSQLError("missing bundle in request") } data, err := proto.Marshal(pb) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return &Bundle{ @@ -4514,7 +4510,7 @@ func bundleToModel(pb *common.Bundle) (*Bundle, error) { func modelToEntry(tx *gorm.DB, model RegisteredEntry) (*common.RegistrationEntry, error) { var fetchedSelectors []*Selector if err := tx.Model(&model).Related(&fetchedSelectors).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } selectors := make([]*common.Selector, 0, len(fetchedSelectors)) @@ -4527,7 +4523,7 @@ func modelToEntry(tx *gorm.DB, model RegisteredEntry) (*common.RegistrationEntry var fetchedDNSs []*DNSName if err := tx.Model(&model).Related(&fetchedDNSs).Order("registered_entry_id ASC").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } var dnsList []string @@ -4540,7 +4536,7 @@ func modelToEntry(tx *gorm.DB, model RegisteredEntry) (*common.RegistrationEntry var fetchedBundles []*Bundle if err := tx.Model(&model).Association("FederatesWith").Find(&fetchedBundles).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } var federatesWith []string @@ -4655,11 +4651,11 @@ func bindVarsFn(fn func(int) string, query string) string { func (cfg *configuration) Validate() error { if cfg.databaseTypeConfig.databaseType == "" { - return sqlError.New("database_type must be set") + return newSQLError("database_type must be set") } if cfg.ConnectionString == "" { - return sqlError.New("connection_string must be set") + return newSQLError("connection_string must be set") } if isMySQLDbType(cfg.databaseTypeConfig.databaseType) { @@ -4701,12 +4697,12 @@ func getConnectionString(cfg *configuration, isReadOnly bool) string { func queryVersion(gormDB *gorm.DB, query string) (string, error) { db := gormDB.DB() if db == nil { - return "", sqlError.New("unable to get raw database object") + return "", newSQLError("unable to get raw database object") } var version string if err := db.QueryRow(query).Scan(&version); err != nil { - return "", sqlError.Wrap(err) + return "", newWrappedSQLError(err) } return version, nil } @@ -4762,7 +4758,7 @@ func createCAJournal(tx *gorm.DB, caJournal *datastore.CAJournal) (*datastore.CA } if err := tx.Create(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToCAJournal(model), nil @@ -4775,7 +4771,7 @@ func fetchCAJournal(tx *gorm.DB, activeX509AuthorityID string) (*datastore.CAJou case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToCAJournal(model), nil @@ -4784,7 +4780,7 @@ func fetchCAJournal(tx *gorm.DB, activeX509AuthorityID string) (*datastore.CAJou func listCAJournalsForTesting(tx *gorm.DB) (caJournals []*datastore.CAJournal, err error) { var caJournalsModel []CAJournal if err := tx.Find(&caJournals).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } for _, model := range caJournalsModel { @@ -4797,14 +4793,14 @@ func listCAJournalsForTesting(tx *gorm.DB) (caJournals []*datastore.CAJournal, e func updateCAJournal(tx *gorm.DB, caJournal *datastore.CAJournal) (*datastore.CAJournal, error) { var model CAJournal if err := tx.Find(&model, "id = ?", caJournal.ID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } model.ActiveX509AuthorityID = caJournal.ActiveX509AuthorityID model.Data = caJournal.Data if err := tx.Save(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToCAJournal(model), nil @@ -4821,10 +4817,10 @@ func validateCAJournal(caJournal *datastore.CAJournal) error { func deleteCAJournal(tx *gorm.DB, caJournalID uint) error { model := new(CAJournal) if err := tx.Find(model, "id = ?", caJournalID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Delete(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } diff --git a/pkg/server/datastore/sqlstore/sqlstore_test.go b/pkg/server/datastore/sqlstore/sqlstore_test.go index e0d11ab020..f13beb4d51 100644 --- a/pkg/server/datastore/sqlstore/sqlstore_test.go +++ b/pkg/server/datastore/sqlstore/sqlstore_test.go @@ -250,6 +250,7 @@ func (s *PluginSuite) TestBundleCRUD() { // fetch non-existent fb, err := s.ds.FetchBundle(ctx, "spiffe://foo") + s.T().Logf("err type: %T", err) s.Require().NoError(err) s.Require().Nil(fb) @@ -421,7 +422,8 @@ func (s *PluginSuite) TestListBundlesWithPagination() { PageSize: 2, }, expectedList: []*common.Bundle{bundle1, bundle2}, - expectedPagination: &datastore.Pagination{Token: "2", + expectedPagination: &datastore.Pagination{ + Token: "2", PageSize: 2, }, }, @@ -2858,8 +2860,8 @@ func (s *PluginSuite) testListRegistrationEntries(dataConsistency datastore.Data } var tokensIn []string - var actualEntriesOut = make(map[string]*common.RegistrationEntry) - var expectedEntriesOut = make(map[string]*common.RegistrationEntry) + actualEntriesOut := make(map[string]*common.RegistrationEntry) + expectedEntriesOut := make(map[string]*common.RegistrationEntry) req := &datastore.ListRegistrationEntriesRequest{ Pagination: pagination, ByParentID: tt.byParentID, @@ -3095,111 +3097,160 @@ func (s *PluginSuite) TestUpdateRegistrationEntryWithMask() { result func(*common.RegistrationEntry) err error }{ // SPIFFE ID FIELD -- this field is validated so we check with good and bad data - {name: "Update Spiffe ID, Good Data, Mask True", + { + name: "Update Spiffe ID, Good Data, Mask True", mask: &common.RegistrationEntryMask{SpiffeId: true}, update: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }, - result: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }}, - {name: "Update Spiffe ID, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }, + }, + { + name: "Update Spiffe ID, Good Data, Mask False", mask: &common.RegistrationEntryMask{SpiffeId: false}, update: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update Spiffe ID, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update Spiffe ID, Bad Data, Mask True", mask: &common.RegistrationEntryMask{SpiffeId: true}, update: func(e *common.RegistrationEntry) { e.SpiffeId = badEntry.SpiffeId }, - err: errors.New("invalid registration entry: missing SPIFFE ID")}, - {name: "Update Spiffe ID, Bad Data, Mask False", + err: errors.New("invalid registration entry: missing SPIFFE ID"), + }, + { + name: "Update Spiffe ID, Bad Data, Mask False", mask: &common.RegistrationEntryMask{SpiffeId: false}, update: func(e *common.RegistrationEntry) { e.SpiffeId = badEntry.SpiffeId }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // PARENT ID FIELD -- This field isn't validated so we just check with good data - {name: "Update Parent ID, Good Data, Mask True", + { + name: "Update Parent ID, Good Data, Mask True", mask: &common.RegistrationEntryMask{ParentId: true}, update: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }, - result: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }}, - {name: "Update Parent ID, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }, + }, + { + name: "Update Parent ID, Good Data, Mask False", mask: &common.RegistrationEntryMask{ParentId: false}, update: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // X509 SVID TTL FIELD -- This field is validated so we check with good and bad data - {name: "Update X509 SVID TTL, Good Data, Mask True", + { + name: "Update X509 SVID TTL, Good Data, Mask True", mask: &common.RegistrationEntryMask{X509SvidTtl: true}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = newEntry.X509SvidTtl }, - result: func(e *common.RegistrationEntry) { e.X509SvidTtl = newEntry.X509SvidTtl }}, - {name: "Update X509 SVID TTL, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.X509SvidTtl = newEntry.X509SvidTtl }, + }, + { + name: "Update X509 SVID TTL, Good Data, Mask False", mask: &common.RegistrationEntryMask{X509SvidTtl: false}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = badEntry.X509SvidTtl }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update X509 SVID TTL, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update X509 SVID TTL, Bad Data, Mask True", mask: &common.RegistrationEntryMask{X509SvidTtl: true}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = badEntry.X509SvidTtl }, - err: errors.New("invalid registration entry: X509SvidTtl is not set")}, - {name: "Update X509 SVID TTL, Bad Data, Mask False", + err: errors.New("invalid registration entry: X509SvidTtl is not set"), + }, + { + name: "Update X509 SVID TTL, Bad Data, Mask False", mask: &common.RegistrationEntryMask{X509SvidTtl: false}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = badEntry.X509SvidTtl }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // JWT SVID TTL FIELD -- This field is validated so we check with good and bad data - {name: "Update JWT SVID TTL, Good Data, Mask True", + { + name: "Update JWT SVID TTL, Good Data, Mask True", mask: &common.RegistrationEntryMask{JwtSvidTtl: true}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = newEntry.JwtSvidTtl }, - result: func(e *common.RegistrationEntry) { e.JwtSvidTtl = newEntry.JwtSvidTtl }}, - {name: "Update JWT SVID TTL, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.JwtSvidTtl = newEntry.JwtSvidTtl }, + }, + { + name: "Update JWT SVID TTL, Good Data, Mask False", mask: &common.RegistrationEntryMask{JwtSvidTtl: false}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = badEntry.JwtSvidTtl }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update JWT SVID TTL, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update JWT SVID TTL, Bad Data, Mask True", mask: &common.RegistrationEntryMask{JwtSvidTtl: true}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = badEntry.JwtSvidTtl }, - err: errors.New("invalid registration entry: JwtSvidTtl is not set")}, - {name: "Update JWT SVID TTL, Bad Data, Mask False", + err: errors.New("invalid registration entry: JwtSvidTtl is not set"), + }, + { + name: "Update JWT SVID TTL, Bad Data, Mask False", mask: &common.RegistrationEntryMask{JwtSvidTtl: false}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = badEntry.JwtSvidTtl }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // SELECTORS FIELD -- This field is validated so we check with good and bad data - {name: "Update Selectors, Good Data, Mask True", + { + name: "Update Selectors, Good Data, Mask True", mask: &common.RegistrationEntryMask{Selectors: true}, update: func(e *common.RegistrationEntry) { e.Selectors = newEntry.Selectors }, - result: func(e *common.RegistrationEntry) { e.Selectors = newEntry.Selectors }}, - {name: "Update Selectors, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Selectors = newEntry.Selectors }, + }, + { + name: "Update Selectors, Good Data, Mask False", mask: &common.RegistrationEntryMask{Selectors: false}, update: func(e *common.RegistrationEntry) { e.Selectors = badEntry.Selectors }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update Selectors, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update Selectors, Bad Data, Mask True", mask: &common.RegistrationEntryMask{Selectors: true}, update: func(e *common.RegistrationEntry) { e.Selectors = badEntry.Selectors }, - err: errors.New("invalid registration entry: missing selector list")}, - {name: "Update Selectors, Bad Data, Mask False", + err: errors.New("invalid registration entry: missing selector list"), + }, + { + name: "Update Selectors, Bad Data, Mask False", mask: &common.RegistrationEntryMask{Selectors: false}, update: func(e *common.RegistrationEntry) { e.Selectors = badEntry.Selectors }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // FEDERATESWITH FIELD -- This field isn't validated so we just check with good data - {name: "Update FederatesWith, Good Data, Mask True", + { + name: "Update FederatesWith, Good Data, Mask True", mask: &common.RegistrationEntryMask{FederatesWith: true}, update: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }, - result: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }}, - {name: "Update FederatesWith Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }, + }, + { + name: "Update FederatesWith Good Data, Mask False", mask: &common.RegistrationEntryMask{FederatesWith: false}, update: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // ADMIN FIELD -- This field isn't validated so we just check with good data - {name: "Update Admin, Good Data, Mask True", + { + name: "Update Admin, Good Data, Mask True", mask: &common.RegistrationEntryMask{Admin: true}, update: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }, - result: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }}, - {name: "Update Admin, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }, + }, + { + name: "Update Admin, Good Data, Mask False", mask: &common.RegistrationEntryMask{Admin: false}, update: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // STORESVID FIELD -- This field isn't validated so we just check with good data - {name: "Update StoreSvid, Good Data, Mask True", + { + name: "Update StoreSvid, Good Data, Mask True", mask: &common.RegistrationEntryMask{StoreSvid: true}, update: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }, - result: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }}, - {name: "Update StoreSvid, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }, + }, + { + name: "Update StoreSvid, Good Data, Mask False", mask: &common.RegistrationEntryMask{Admin: false}, update: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update StoreSvid, Invalid selectors, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update StoreSvid, Invalid selectors, Mask True", mask: &common.RegistrationEntryMask{StoreSvid: true, Selectors: true}, update: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid @@ -3208,50 +3259,68 @@ func (s *PluginSuite) TestUpdateRegistrationEntryWithMask() { {Type: "Type2", Value: "Value2"}, } }, - err: validationError.New("invalid registration entry: selector types must be the same when store SVID is enabled"), + err: newValidationError("invalid registration entry: selector types must be the same when store SVID is enabled"), }, // ENTRYEXPIRY FIELD -- This field isn't validated so we just check with good data - {name: "Update EntryExpiry, Good Data, Mask True", + { + name: "Update EntryExpiry, Good Data, Mask True", mask: &common.RegistrationEntryMask{EntryExpiry: true}, update: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }, - result: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }}, - {name: "Update EntryExpiry, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }, + }, + { + name: "Update EntryExpiry, Good Data, Mask False", mask: &common.RegistrationEntryMask{EntryExpiry: false}, update: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // DNSNAMES FIELD -- This field isn't validated so we just check with good data - {name: "Update DnsNames, Good Data, Mask True", + { + name: "Update DnsNames, Good Data, Mask True", mask: &common.RegistrationEntryMask{DnsNames: true}, update: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }, - result: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }}, - {name: "Update DnsNames, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }, + }, + { + name: "Update DnsNames, Good Data, Mask False", mask: &common.RegistrationEntryMask{DnsNames: false}, update: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // DOWNSTREAM FIELD -- This field isn't validated so we just check with good data - {name: "Update DnsNames, Good Data, Mask True", + { + name: "Update DnsNames, Good Data, Mask True", mask: &common.RegistrationEntryMask{Downstream: true}, update: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }, - result: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }}, - {name: "Update DnsNames, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }, + }, + { + name: "Update DnsNames, Good Data, Mask False", mask: &common.RegistrationEntryMask{Downstream: false}, update: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // HINT -- This field isn't validated so we just check with good data - {name: "Update Hint, Good Data, Mask True", + { + name: "Update Hint, Good Data, Mask True", mask: &common.RegistrationEntryMask{Hint: true}, update: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }, - result: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }}, - {name: "Update Hint, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }, + }, + { + name: "Update Hint, Good Data, Mask False", mask: &common.RegistrationEntryMask{Hint: false}, update: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // This should update all fields - {name: "Test With Nil Mask", + { + name: "Test With Nil Mask", mask: nil, update: func(e *common.RegistrationEntry) { proto.Merge(e, oldEntry) }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, } { tt := testcase s.Run(tt.name, func() { @@ -3350,7 +3419,6 @@ func (s *PluginSuite) TestListParentIDEntries() { expectedList []*common.RegistrationEntry }{ { - name: "test_parentID_found", registrationEntries: allEntries, parentID: "spiffe://parent", @@ -4627,7 +4695,8 @@ func (s *PluginSuite) TestListFederationRelationships() { PageSize: 2, }, expectedList: []*datastore.FederationRelationship{fr1, fr2}, - expectedPagination: &datastore.Pagination{Token: "2", + expectedPagination: &datastore.Pagination{ + Token: "2", PageSize: 2, }, }, diff --git a/pkg/server/datastore/sqlstore/stmt_cache.go b/pkg/server/datastore/sqlstore/stmt_cache.go index f3fb354140..a934d2a880 100644 --- a/pkg/server/datastore/sqlstore/stmt_cache.go +++ b/pkg/server/datastore/sqlstore/stmt_cache.go @@ -25,7 +25,7 @@ func (cache *stmtCache) get(ctx context.Context, query string) (*sql.Stmt, error stmt, err := cache.db.PrepareContext(ctx, query) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } value, loaded = cache.stmts.LoadOrStore(query, stmt) if loaded { diff --git a/pkg/server/endpoints/bundle/acme_auth.go b/pkg/server/endpoints/bundle/acme_auth.go index a9d12c5bcc..45e5fbb72b 100644 --- a/pkg/server/endpoints/bundle/acme_auth.go +++ b/pkg/server/endpoints/bundle/acme_auth.go @@ -4,12 +4,12 @@ import ( "context" "crypto" "crypto/tls" + "fmt" "github.com/sirupsen/logrus" "github.com/spiffe/spire/pkg/common/version" "github.com/spiffe/spire/pkg/server/endpoints/bundle/internal/autocert" "github.com/spiffe/spire/pkg/server/plugin/keymanager" - "github.com/zeebo/errs" "golang.org/x/crypto/acme" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -122,7 +122,7 @@ func (ks *acmeKeyStore) NewPrivateKey(ctx context.Context, id string, keyType au case autocert.EC256: kmKeyType = keymanager.ECP256 default: - return nil, errs.New("unsupported key type: %d", keyType) + return nil, fmt.Errorf("unsupported key type: %d", keyType) } key, err := ks.km.GenerateKey(ctx, keyID, kmKeyType) diff --git a/pkg/server/endpoints/bundle/server.go b/pkg/server/endpoints/bundle/server.go index d96490e476..e9c7a39bdf 100644 --- a/pkg/server/endpoints/bundle/server.go +++ b/pkg/server/endpoints/bundle/server.go @@ -11,7 +11,6 @@ import ( "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/spire/pkg/common/bundleutil" - "github.com/zeebo/errs" ) type Getter interface { @@ -57,7 +56,7 @@ func (s *Server) ListenAndServe(ctx context.Context) error { // it gives us the ability to use/inspect an ephemeral port during testing. listener, err := s.c.listen("tcp", s.c.Address) if err != nil { - return errs.Wrap(err) + return err } // Set up the TLS config, setting TLS 1.2 as the minimum. @@ -72,7 +71,7 @@ func (s *Server) ListenAndServe(ctx context.Context) error { errCh := make(chan error, 1) go func() { - errCh <- errs.Wrap(server.ServeTLS(listener, "", "")) + errCh <- server.ServeTLS(listener, "", "") }() select { diff --git a/pkg/server/hostservice/identityprovider/identityprovider.go b/pkg/server/hostservice/identityprovider/identityprovider.go index 79213beff8..7aaf243c36 100644 --- a/pkg/server/hostservice/identityprovider/identityprovider.go +++ b/pkg/server/hostservice/identityprovider/identityprovider.go @@ -13,7 +13,6 @@ import ( "github.com/spiffe/spire/pkg/common/coretypes/jwtkey" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/server/datastore" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -124,7 +123,7 @@ func (v1 *identityProviderV1) FetchX509Identity(ctx context.Context, _ *identity privateKey, err := x509.MarshalPKCS8PrivateKey(x509Identity.PrivateKey) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return &identityproviderv1.FetchX509IdentityResponse{ diff --git a/support/oidc-discovery-provider/config.go b/support/oidc-discovery-provider/config.go index 7df3eaa615..36932449ae 100644 --- a/support/oidc-discovery-provider/config.go +++ b/support/oidc-discovery-provider/config.go @@ -1,13 +1,14 @@ package main import ( + "errors" + "fmt" "net" "net/url" "os" "time" "github.com/hashicorp/hcl" - "github.com/zeebo/errs" ) const ( @@ -188,7 +189,7 @@ type experimentalWorkloadAPIConfig struct { func LoadConfig(path string) (*Config, error) { hclBytes, err := os.ReadFile(path) if err != nil { - return nil, errs.New("unable to load configuration: %v", err) + return nil, fmt.Errorf("unable to load configuration: %v", err) } return ParseConfig(string(hclBytes)) } @@ -196,7 +197,7 @@ func LoadConfig(path string) (*Config, error) { func ParseConfig(hclConfig string) (_ *Config, err error) { c := new(Config) if err := hcl.Decode(c, hclConfig); err != nil { - return nil, errs.New("unable to decode configuration: %v", err) + return nil, fmt.Errorf("unable to decode configuration: %v", err) } if c.LogLevel == "" { @@ -204,7 +205,7 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { } if len(c.Domains) == 0 { - return nil, errs.New("at least one domain must be configured") + return nil, errors.New("at least one domain must be configured") } c.Domains = dedupeList(c.Domains) @@ -215,20 +216,20 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { } switch { case c.InsecureAddr != "": - return nil, errs.New("insecure_addr and the acme section are mutually exclusive") + return nil, errors.New("insecure_addr and the acme section are mutually exclusive") case !c.ACME.ToSAccepted: - return nil, errs.New("tos_accepted must be set to true in the acme configuration section") + return nil, errors.New("tos_accepted must be set to true in the acme configuration section") case c.ACME.Email == "": - return nil, errs.New("email must be configured in the acme configuration section") + return nil, errors.New("email must be configured in the acme configuration section") } } if c.ServingCertFile != nil { if c.ServingCertFile.CertFilePath == "" { - return nil, errs.New("cert_file_path must be configured in the serving_cert_file configuration section") + return nil, errors.New("cert_file_path must be configured in the serving_cert_file configuration section") } if c.ServingCertFile.KeyFilePath == "" { - return nil, errs.New("key_file_path must be configured in the serving_cert_file configuration section") + return nil, errors.New("key_file_path must be configured in the serving_cert_file configuration section") } if c.ServingCertFile.RawAddr == "" { @@ -237,13 +238,13 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { addr, err := net.ResolveTCPAddr("tcp", c.ServingCertFile.RawAddr) if err != nil { - return nil, errs.New("invalid addr in the serving_cert_file configuration section: %v", err) + return nil, fmt.Errorf("invalid addr in the serving_cert_file configuration section: %v", err) } c.ServingCertFile.Addr = addr c.ServingCertFile.FileSyncInterval, err = parseDurationField(c.ServingCertFile.RawFileSyncInterval, defaultFileSyncInterval) if err != nil { - return nil, errs.New("invalid file_sync_interval in the serving_cert_file configuration section: %v", err) + return nil, fmt.Errorf("invalid file_sync_interval in the serving_cert_file configuration section: %v", err) } } @@ -252,18 +253,18 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { if c.ServerAPI != nil { c.ServerAPI.PollInterval, err = parseDurationField(c.ServerAPI.RawPollInterval, defaultPollInterval) if err != nil { - return nil, errs.New("invalid poll_interval in the server_api configuration section: %v", err) + return nil, fmt.Errorf("invalid poll_interval in the server_api configuration section: %v", err) } methodCount++ } if c.WorkloadAPI != nil { if c.WorkloadAPI.TrustDomain == "" { - return nil, errs.New("trust_domain must be configured in the workload_api configuration section") + return nil, errors.New("trust_domain must be configured in the workload_api configuration section") } c.WorkloadAPI.PollInterval, err = parseDurationField(c.WorkloadAPI.RawPollInterval, defaultPollInterval) if err != nil { - return nil, errs.New("invalid poll_interval in the workload_api configuration section: %v", err) + return nil, fmt.Errorf("invalid poll_interval in the workload_api configuration section: %v", err) } methodCount++ } @@ -286,15 +287,15 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { switch methodCount { case 0: - return nil, errs.New("either the server_api or workload_api section must be configured") + return nil, errors.New("either the server_api or workload_api section must be configured") case 1: default: - return nil, errs.New("the server_api and workload_api sections are mutually exclusive") + return nil, errors.New("the server_api and workload_api sections are mutually exclusive") } if c.JWTIssuer != "" { jwtIssuer, err := url.Parse(c.JWTIssuer) if err != nil || jwtIssuer.Scheme == "" || jwtIssuer.Host == "" { - return nil, errs.New("the jwt_issuer url could not be parsed") + return nil, errors.New("the jwt_issuer url could not be parsed") } } return c, nil diff --git a/support/oidc-discovery-provider/main.go b/support/oidc-discovery-provider/main.go index 0e65d9cb68..452bc8a502 100644 --- a/support/oidc-discovery-provider/main.go +++ b/support/oidc-discovery-provider/main.go @@ -3,6 +3,7 @@ package main import ( "context" "crypto/tls" + "errors" "flag" "fmt" "net" @@ -17,7 +18,6 @@ import ( "github.com/spiffe/spire/pkg/common/log" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/version" - "github.com/zeebo/errs" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" ) @@ -49,7 +49,7 @@ func run(configPath string) error { log, err := log.NewLogger(log.WithLevel(config.LogLevel), log.WithFormat(config.LogFormat), log.WithOutputFile(config.LogPath)) if err != nil { - return errs.Wrap(err) + return err } defer log.Close() @@ -157,7 +157,7 @@ func newSource(log logrus.FieldLogger, config *Config) (JWKSSource, error) { case config.WorkloadAPI != nil: workloadAPIAddr, err := config.getWorkloadAPIAddr() if err != nil { - return nil, errs.Wrap(err) + return nil, err } return NewWorkloadAPISource(WorkloadAPISourceConfig{ Log: log, @@ -167,7 +167,7 @@ func newSource(log logrus.FieldLogger, config *Config) (JWKSSource, error) { }) default: // This is defensive; LoadConfig should prevent this from happening. - return nil, errs.New("no source has been configured") + return nil, errors.New("no source has been configured") } } diff --git a/support/oidc-discovery-provider/main_posix.go b/support/oidc-discovery-provider/main_posix.go index d61d1091a6..4e6e75cee8 100644 --- a/support/oidc-discovery-provider/main_posix.go +++ b/support/oidc-discovery-provider/main_posix.go @@ -3,12 +3,12 @@ package main import ( + "errors" "net" "os" "strings" "github.com/spiffe/spire/pkg/common/util" - "github.com/zeebo/errs" ) func (c *Config) getWorkloadAPIAddr() (net.Addr, error) { @@ -23,33 +23,33 @@ func (c *Config) getServerAPITargetName() string { func (c *Config) validateOS() (err error) { switch { case c.ACME == nil && c.ListenSocketPath == "" && c.ServingCertFile == nil && c.InsecureAddr == "": - return errs.New("either acme, serving_cert_file, insecure_addr or listen_socket_path must be configured") + return errors.New("either acme, serving_cert_file, insecure_addr or listen_socket_path must be configured") case c.ACME != nil && c.ServingCertFile != nil: - return errs.New("acme and serving_cert_file are mutually exclusive") + return errors.New("acme and serving_cert_file are mutually exclusive") case c.ACME != nil && c.ListenSocketPath != "": - return errs.New("listen_socket_path and the acme section are mutually exclusive") + return errors.New("listen_socket_path and the acme section are mutually exclusive") case c.ServingCertFile != nil && c.InsecureAddr != "": - return errs.New("serving_cert_file and insecure_addr are mutually exclusive") + return errors.New("serving_cert_file and insecure_addr are mutually exclusive") case c.ServingCertFile != nil && c.ListenSocketPath != "": - return errs.New("serving_cert_file and listen_socket_path are mutually exclusive") + return errors.New("serving_cert_file and listen_socket_path are mutually exclusive") case c.ACME != nil && c.InsecureAddr != "": - return errs.New("acme and insecure_addr are mutually exclusive") + return errors.New("acme and insecure_addr are mutually exclusive") case c.InsecureAddr != "" && c.ListenSocketPath != "": - return errs.New("insecure_addr and listen_socket_path are mutually exclusive") + return errors.New("insecure_addr and listen_socket_path are mutually exclusive") } if c.ServerAPI != nil { if c.ServerAPI.Address == "" { - return errs.New("address must be configured in the server_api configuration section") + return errors.New("address must be configured in the server_api configuration section") } if !strings.HasPrefix(c.ServerAPI.Address, "unix:") { - return errs.New("address must use the unix name system in the server_api configuration section") + return errors.New("address must use the unix name system in the server_api configuration section") } } if c.WorkloadAPI != nil { if c.WorkloadAPI.SocketPath == "" { - return errs.New("socket_path must be configured in the workload_api configuration section") + return errors.New("socket_path must be configured in the workload_api configuration section") } } diff --git a/support/oidc-discovery-provider/main_windows.go b/support/oidc-discovery-provider/main_windows.go index a05b5fd32e..55d24ebdb6 100644 --- a/support/oidc-discovery-provider/main_windows.go +++ b/support/oidc-discovery-provider/main_windows.go @@ -3,6 +3,7 @@ package main import ( + "errors" "fmt" "net" "path/filepath" @@ -10,7 +11,6 @@ import ( "github.com/Microsoft/go-winio" "github.com/spiffe/spire/pkg/common/namedpipe" "github.com/spiffe/spire/pkg/common/sddl" - "github.com/zeebo/errs" ) func (c *Config) getWorkloadAPIAddr() (net.Addr, error) { @@ -25,29 +25,29 @@ func (c *Config) getServerAPITargetName() string { func (c *Config) validateOS() (err error) { switch { case c.ACME == nil && c.Experimental.ListenNamedPipeName == "" && c.ServingCertFile == nil && c.InsecureAddr == "": - return errs.New("either acme, serving_cert_file, insecure_addr or listen_named_pipe_name must be configured") + return errors.New("either acme, serving_cert_file, insecure_addr or listen_named_pipe_name must be configured") case c.ACME != nil && c.ServingCertFile != nil: - return errs.New("acme and serving_cert_file are mutually exclusive") + return errors.New("acme and serving_cert_file are mutually exclusive") case c.ACME != nil && c.Experimental.ListenNamedPipeName != "": - return errs.New("listen_named_pipe_name and the acme section are mutually exclusive") + return errors.New("listen_named_pipe_name and the acme section are mutually exclusive") case c.ACME != nil && c.InsecureAddr != "": - return errs.New("acme and insecure_addr are mutually exclusive") + return errors.New("acme and insecure_addr are mutually exclusive") case c.ServingCertFile != nil && c.InsecureAddr != "": - return errs.New("serving_cert_file and insecure_addr are mutually exclusive") + return errors.New("serving_cert_file and insecure_addr are mutually exclusive") case c.ServingCertFile != nil && c.Experimental.ListenNamedPipeName != "": - return errs.New("serving_cert_file and listen_named_pipe_name are mutually exclusive") + return errors.New("serving_cert_file and listen_named_pipe_name are mutually exclusive") case c.InsecureAddr != "" && c.Experimental.ListenNamedPipeName != "": - return errs.New("insecure_addr and listen_named_pipe_name are mutually exclusive") + return errors.New("insecure_addr and listen_named_pipe_name are mutually exclusive") } if c.ServerAPI != nil { if c.ServerAPI.Experimental.NamedPipeName == "" { - return errs.New("named_pipe_name must be configured in the server_api configuration section") + return errors.New("named_pipe_name must be configured in the server_api configuration section") } } if c.WorkloadAPI != nil { if c.WorkloadAPI.Experimental.NamedPipeName == "" { - return errs.New("named_pipe_name must be configured in the workload_api configuration section") + return errors.New("named_pipe_name must be configured in the workload_api configuration section") } } diff --git a/support/oidc-discovery-provider/server_api.go b/support/oidc-discovery-provider/server_api.go index 74724f1b36..5a9c98c444 100644 --- a/support/oidc-discovery-provider/server_api.go +++ b/support/oidc-discovery-provider/server_api.go @@ -12,7 +12,6 @@ import ( bundlev1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/bundle/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/util" - "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) @@ -51,7 +50,7 @@ func NewServerAPISource(config ServerAPISourceConfig) (*ServerAPISource, error) conn, err := util.GRPCDialContext(context.Background(), config.GRPCTarget) if err != nil { - return nil, errs.Wrap(err) + return nil, err } ctx, cancel := context.WithCancel(context.Background()) diff --git a/support/oidc-discovery-provider/workload_api.go b/support/oidc-discovery-provider/workload_api.go index 8db442d4f3..caaabf9c78 100644 --- a/support/oidc-discovery-provider/workload_api.go +++ b/support/oidc-discovery-provider/workload_api.go @@ -16,7 +16,6 @@ import ( "github.com/spiffe/go-spiffe/v2/workloadapi" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/util" - "github.com/zeebo/errs" ) const ( @@ -56,19 +55,19 @@ func NewWorkloadAPISource(config WorkloadAPISourceConfig) (*WorkloadAPISource, e if config.Addr != nil { o, err := util.GetWorkloadAPIClientOption(config.Addr) if err != nil { - return nil, errs.Wrap(err) + return nil, err } opts = append(opts, o) } trustDomain, err := spiffeid.TrustDomainFromString(config.TrustDomain) if err != nil { - return nil, errs.Wrap(err) + return nil, err } client, err := workloadapi.New(context.Background(), opts...) if err != nil { - return nil, errs.Wrap(err) + return nil, err } ctx, cancel := context.WithCancel(context.Background()) From 3c1d15651e0b2f03a43be6fe4510c6247f2b56bc Mon Sep 17 00:00:00 2001 From: Ryan Turner Date: Mon, 16 Dec 2024 10:13:08 -0800 Subject: [PATCH 2/3] Address linter warnings Signed-off-by: Ryan Turner --- pkg/common/bundleutil/unmarshal.go | 2 +- pkg/common/jwtsvid/validate.go | 6 ++---- pkg/common/jwtutil/keyset.go | 4 ++-- pkg/common/plugin/azure/msi.go | 4 ++-- pkg/server/bundle/client/client.go | 4 ++-- pkg/server/ca/manager/journal.go | 2 +- pkg/server/ca/manager/manager.go | 2 +- pkg/server/ca/manager/slot.go | 4 ++-- support/oidc-discovery-provider/config.go | 12 ++++++------ 9 files changed, 19 insertions(+), 21 deletions(-) diff --git a/pkg/common/bundleutil/unmarshal.go b/pkg/common/bundleutil/unmarshal.go index 4173e44b3e..ff86b79a17 100644 --- a/pkg/common/bundleutil/unmarshal.go +++ b/pkg/common/bundleutil/unmarshal.go @@ -42,7 +42,7 @@ func unmarshal(trustDomain spiffeid.TrustDomain, doc *bundleDoc) (*spiffebundle. return nil, fmt.Errorf("missing key ID in jwt-svid entry %d", i) } if err := bundle.AddJWTAuthority(key.KeyID, key.Key); err != nil { - return nil, fmt.Errorf("failed to add jwt-svid entry %d: %v", i, err) + return nil, fmt.Errorf("failed to add jwt-svid entry %d: %w", i, err) } case "": return nil, fmt.Errorf("missing use for key entry %d", i) diff --git a/pkg/common/jwtsvid/validate.go b/pkg/common/jwtsvid/validate.go index 7ee3f16e17..33e46fa349 100644 --- a/pkg/common/jwtsvid/validate.go +++ b/pkg/common/jwtsvid/validate.go @@ -40,7 +40,7 @@ func (t *keyStore) FindPublicKey(_ context.Context, td spiffeid.TrustDomain, key func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audience []string) (spiffeid.ID, map[string]any, error) { tok, err := jwt.ParseSigned(token, AllowedSignatureAlgorithms) if err != nil { - return spiffeid.ID{}, nil, fmt.Errorf("unable to parse JWT token: %v", err) + return spiffeid.ID{}, nil, fmt.Errorf("unable to parse JWT token: %w", err) } if len(tok.Headers) != 1 { @@ -65,7 +65,7 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc } spiffeID, err := spiffeid.FromString(claims.Subject) if err != nil { - return spiffeid.ID{}, nil, fmt.Errorf("token has in invalid subject claim: %v", err) + return spiffeid.ID{}, nil, fmt.Errorf("token has in invalid subject claim: %w", err) } // Construct the trust domain id from the SPIFFE ID and look up key by ID @@ -92,8 +92,6 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc err = errors.New("token has expired") case errors.Is(err, jwt.ErrInvalidAudience): err = fmt.Errorf("expected audience in %q (audience=%q)", audience, claims.Audience) - default: - err = err } return spiffeid.ID{}, nil, err } diff --git a/pkg/common/jwtutil/keyset.go b/pkg/common/jwtutil/keyset.go index 56f07baf30..a233dc2cf6 100644 --- a/pkg/common/jwtutil/keyset.go +++ b/pkg/common/jwtutil/keyset.go @@ -114,7 +114,7 @@ func DiscoverKeySetURI(ctx context.Context, configURL string) (string, error) { JWKSURI string `json:"jwks_uri"` }{} if err := json.NewDecoder(resp.Body).Decode(config); err != nil { - return "", fmt.Errorf("failed to decode configuration: %v", err) + return "", fmt.Errorf("failed to decode configuration: %w", err) } if config.JWKSURI == "" { return "", errors.New("configuration missing JWKS URI") @@ -141,7 +141,7 @@ func FetchKeySet(ctx context.Context, jwksURI string) (*jose.JSONWebKeySet, erro jwks := new(jose.JSONWebKeySet) if err := json.NewDecoder(resp.Body).Decode(jwks); err != nil { - return nil, fmt.Errorf("failed to decode key set: %v", err) + return nil, fmt.Errorf("failed to decode key set: %w", err) } return jwks, nil diff --git a/pkg/common/plugin/azure/msi.go b/pkg/common/plugin/azure/msi.go index bb5461eb2f..129c4dbdde 100644 --- a/pkg/common/plugin/azure/msi.go +++ b/pkg/common/plugin/azure/msi.go @@ -79,7 +79,7 @@ func FetchMSIToken(cl HTTPClient, resource string) (string, error) { }{} if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { - return "", fmt.Errorf("unable to decode response: %v", err) + return "", fmt.Errorf("unable to decode response: %w", err) } if r.AccessToken == "" { @@ -107,7 +107,7 @@ func FetchInstanceMetadata(cl HTTPClient) (*InstanceMetadata, error) { metadata := new(InstanceMetadata) if err := json.NewDecoder(resp.Body).Decode(metadata); err != nil { - return nil, fmt.Errorf("unable to decode response: %v", err) + return nil, fmt.Errorf("unable to decode response: %w", err) } switch { diff --git a/pkg/server/bundle/client/client.go b/pkg/server/bundle/client/client.go index 3cd3d9f7fb..009dc3721a 100644 --- a/pkg/server/bundle/client/client.go +++ b/pkg/server/bundle/client/client.go @@ -91,10 +91,10 @@ func (c *client) FetchBundle(context.Context) (*spiffebundle.Bundle, error) { var hostnameError x509.HostnameError if errors.As(err, &hostnameError) && c.c.SPIFFEAuth == nil && len(hostnameError.Certificate.URIs) > 0 { if id, idErr := spiffeid.FromString(hostnameError.Certificate.URIs[0].String()); idErr == nil { - return nil, fmt.Errorf("failed to authenticate bundle endpoint using web authentication but the server certificate contains SPIFFE ID %q: maybe use https_spiffe instead of https_web: %v", id, err) + return nil, fmt.Errorf("failed to authenticate bundle endpoint using web authentication but the server certificate contains SPIFFE ID %q: maybe use https_spiffe instead of https_web: %w", id, err) } } - return nil, fmt.Errorf("failed to fetch bundle: %v", err) + return nil, fmt.Errorf("failed to fetch bundle: %w", err) } defer resp.Body.Close() diff --git a/pkg/server/ca/manager/journal.go b/pkg/server/ca/manager/journal.go index 0c72789343..be95fad938 100644 --- a/pkg/server/ca/manager/journal.go +++ b/pkg/server/ca/manager/journal.go @@ -314,7 +314,7 @@ func loadJournalFromDS(ctx context.Context, config *journalConfig) (*Journal, er j.caJournalID = caJournal.ID if err := proto.Unmarshal(caJournal.Data, j.entries); err != nil { - return nil, fmt.Errorf("unable to unmarshal entries from CA journal record: %v", err) + return nil, fmt.Errorf("unable to unmarshal entries from CA journal record: %w", err) } return j, nil } diff --git a/pkg/server/ca/manager/manager.go b/pkg/server/ca/manager/manager.go index a6c4b9cc07..a799f0476f 100644 --- a/pkg/server/ca/manager/manager.go +++ b/pkg/server/ca/manager/manager.go @@ -741,7 +741,7 @@ func (m *Manager) notify(ctx context.Context, event string, advise bool, pre fun } } if allErrs != nil { - return fmt.Errorf("one or more notifiers returned an error: %v", allErrs) + return fmt.Errorf("one or more notifiers returned an error: %w", allErrs) } return nil diff --git a/pkg/server/ca/manager/slot.go b/pkg/server/ca/manager/slot.go index eb4ee7a232..fa0be6a33c 100644 --- a/pkg/server/ca/manager/slot.go +++ b/pkg/server/ca/manager/slot.go @@ -349,14 +349,14 @@ func (s *SlotLoader) loadX509CASlotFromEntry(ctx context.Context, entry *journal cert, err := x509.ParseCertificate(entry.Certificate) if err != nil { - return nil, "", fmt.Errorf("unable to parse CA certificate: %v", err) + return nil, "", fmt.Errorf("unable to parse CA certificate: %w", err) } var upstreamChain []*x509.Certificate for _, certDER := range entry.UpstreamChain { cert, err := x509.ParseCertificate(certDER) if err != nil { - return nil, "", fmt.Errorf("unable to parse upstream chain certificate: %v", err) + return nil, "", fmt.Errorf("unable to parse upstream chain certificate: %w", err) } upstreamChain = append(upstreamChain, cert) } diff --git a/support/oidc-discovery-provider/config.go b/support/oidc-discovery-provider/config.go index 36932449ae..493c76cf14 100644 --- a/support/oidc-discovery-provider/config.go +++ b/support/oidc-discovery-provider/config.go @@ -189,7 +189,7 @@ type experimentalWorkloadAPIConfig struct { func LoadConfig(path string) (*Config, error) { hclBytes, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("unable to load configuration: %v", err) + return nil, fmt.Errorf("unable to load configuration: %w", err) } return ParseConfig(string(hclBytes)) } @@ -197,7 +197,7 @@ func LoadConfig(path string) (*Config, error) { func ParseConfig(hclConfig string) (_ *Config, err error) { c := new(Config) if err := hcl.Decode(c, hclConfig); err != nil { - return nil, fmt.Errorf("unable to decode configuration: %v", err) + return nil, fmt.Errorf("unable to decode configuration: %w", err) } if c.LogLevel == "" { @@ -238,13 +238,13 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { addr, err := net.ResolveTCPAddr("tcp", c.ServingCertFile.RawAddr) if err != nil { - return nil, fmt.Errorf("invalid addr in the serving_cert_file configuration section: %v", err) + return nil, fmt.Errorf("invalid addr in the serving_cert_file configuration section: %w", err) } c.ServingCertFile.Addr = addr c.ServingCertFile.FileSyncInterval, err = parseDurationField(c.ServingCertFile.RawFileSyncInterval, defaultFileSyncInterval) if err != nil { - return nil, fmt.Errorf("invalid file_sync_interval in the serving_cert_file configuration section: %v", err) + return nil, fmt.Errorf("invalid file_sync_interval in the serving_cert_file configuration section: %w", err) } } @@ -253,7 +253,7 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { if c.ServerAPI != nil { c.ServerAPI.PollInterval, err = parseDurationField(c.ServerAPI.RawPollInterval, defaultPollInterval) if err != nil { - return nil, fmt.Errorf("invalid poll_interval in the server_api configuration section: %v", err) + return nil, fmt.Errorf("invalid poll_interval in the server_api configuration section: %w", err) } methodCount++ } @@ -264,7 +264,7 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { } c.WorkloadAPI.PollInterval, err = parseDurationField(c.WorkloadAPI.RawPollInterval, defaultPollInterval) if err != nil { - return nil, fmt.Errorf("invalid poll_interval in the workload_api configuration section: %v", err) + return nil, fmt.Errorf("invalid poll_interval in the workload_api configuration section: %w", err) } methodCount++ } From ad10f2b7f7c47142f1e1f159df3f8740ef979b6a Mon Sep 17 00:00:00 2001 From: Ryan Turner Date: Tue, 7 Jan 2025 15:06:27 -0800 Subject: [PATCH 3/3] Address review comments Signed-off-by: Ryan Turner --- pkg/common/catalog/closers.go | 4 +- pkg/server/datastore/sqlstore/errors.go | 82 +++++++------------ pkg/server/datastore/sqlstore/errors_test.go | 24 +----- pkg/server/datastore/sqlstore/mysql.go | 2 +- .../datastore/sqlstore/sqlstore_test.go | 1 - support/oidc-discovery-provider/config.go | 9 +- .../config_posix_test.go | 4 +- .../oidc-discovery-provider/config_test.go | 4 +- .../config_windows_test.go | 4 +- 9 files changed, 47 insertions(+), 87 deletions(-) diff --git a/pkg/common/catalog/closers.go b/pkg/common/catalog/closers.go index dc7571dc16..d72a186fae 100644 --- a/pkg/common/catalog/closers.go +++ b/pkg/common/catalog/closers.go @@ -14,9 +14,7 @@ func (cs closerGroup) Close() error { // Close in reverse order. var errs error for i := len(cs) - 1; i >= 0; i-- { - if err := cs[i].Close(); err != nil { - errs = errors.Join(errs, err) - } + errs = errors.Join(errs, cs[i].Close()) } return errs diff --git a/pkg/server/datastore/sqlstore/errors.go b/pkg/server/datastore/sqlstore/errors.go index 364679f8d6..1aaf152470 100644 --- a/pkg/server/datastore/sqlstore/errors.go +++ b/pkg/server/datastore/sqlstore/errors.go @@ -14,22 +14,6 @@ type sqlError struct { msg string } -func newSQLError(fmtMsg string, args ...any) error { - return &sqlError{ - msg: fmt.Sprintf(fmtMsg, args...), - } -} - -func newWrappedSQLError(err error) error { - if err == nil { - return nil - } - - return &sqlError{ - err: err, - } -} - func (s *sqlError) Error() string { if s == nil { return "" @@ -42,19 +26,6 @@ func (s *sqlError) Error() string { return fmt.Sprintf("%s: %s", datastoreSQLErrorPrefix, s.msg) } -func (s *sqlError) Is(err error) bool { - if s == nil { - return false - } - - sErr, ok := err.(*sqlError) - if !ok { - return false - } - - return s.msg == sErr.msg -} - func (s *sqlError) Unwrap() error { if s == nil { return nil @@ -68,22 +39,6 @@ type validationError struct { msg string } -func newValidationError(fmtMsg string, args ...any) error { - return &validationError{ - msg: fmt.Sprintf(fmtMsg, args...), - } -} - -func newWrappedValidationError(err error) error { - if err == nil { - return nil - } - - return &validationError{ - err: err, - } -} - func (v *validationError) Error() string { if v == nil { return "" @@ -96,23 +51,42 @@ func (v *validationError) Error() string { return fmt.Sprintf("%s: %s", datastoreValidationErrorPrefix, v.msg) } -func (v *validationError) Is(err error) bool { +func (v *validationError) Unwrap() error { if v == nil { - return false + return nil + } + + return v.err +} + +func newSQLError(fmtMsg string, args ...any) error { + return &sqlError{ + msg: fmt.Sprintf(fmtMsg, args...), } +} - vErr, ok := err.(*validationError) - if !ok { - return false +func newWrappedSQLError(err error) error { + if err == nil { + return nil } - return v.msg == vErr.msg + return &sqlError{ + err: err, + } } -func (v *validationError) Unwrap() error { - if v == nil { +func newValidationError(fmtMsg string, args ...any) error { + return &validationError{ + msg: fmt.Sprintf(fmtMsg, args...), + } +} + +func newWrappedValidationError(err error) error { + if err == nil { return nil } - return v.err + return &validationError{ + err: err, + } } diff --git a/pkg/server/datastore/sqlstore/errors_test.go b/pkg/server/datastore/sqlstore/errors_test.go index 6de7eb4012..5d2079aa81 100644 --- a/pkg/server/datastore/sqlstore/errors_test.go +++ b/pkg/server/datastore/sqlstore/errors_test.go @@ -12,11 +12,7 @@ func TestSQLError(t *testing.T) { assert.EqualError(t, err, "datastore-sql: an error with two dynamic fields: hello, 1") var sErr *sqlError - assert.True(t, errors.As(err, &sErr)) - - assert.True(t, errors.Is(err, &sqlError{ - msg: "an error with two dynamic fields: hello, 1", - })) + assert.ErrorAs(t, err, &sErr) } func TestWrappedSQLError(t *testing.T) { @@ -32,11 +28,7 @@ func TestWrappedSQLError(t *testing.T) { assert.EqualError(t, err, "datastore-sql: foo") var sErr *sqlError - assert.True(t, errors.As(err, &sErr)) - - assert.True(t, errors.Is(err, &sqlError{ - err: wrappedErr, - })) + assert.ErrorAs(t, err, &sErr) }) } @@ -45,11 +37,7 @@ func TestValidationError(t *testing.T) { assert.EqualError(t, err, "datastore-validation: an error with two dynamic fields: hello, 1") var vErr *validationError - assert.True(t, errors.As(err, &vErr)) - - assert.True(t, errors.Is(err, &validationError{ - msg: "an error with two dynamic fields: hello, 1", - })) + assert.ErrorAs(t, err, &vErr) } func TestWrappedValidationError(t *testing.T) { @@ -65,10 +53,6 @@ func TestWrappedValidationError(t *testing.T) { assert.EqualError(t, err, "datastore-validation: bar") var vErr *validationError - assert.True(t, errors.As(err, &vErr)) - - assert.True(t, errors.Is(err, &validationError{ - err: wrappedErr, - })) + assert.ErrorAs(t, err, &vErr) }) } diff --git a/pkg/server/datastore/sqlstore/mysql.go b/pkg/server/datastore/sqlstore/mysql.go index 69b93acf5f..a7ee2faeff 100644 --- a/pkg/server/datastore/sqlstore/mysql.go +++ b/pkg/server/datastore/sqlstore/mysql.go @@ -173,7 +173,7 @@ func validateMySQLConfig(cfg *configuration, isReadOnly bool) error { } if !opts.ParseTime { - return newWrappedSQLError(errors.New("invalid mysql config: missing parseTime=true param in connection_string")) + return newSQLError("invalid mysql config: missing parseTime=true param in connection_string") } return nil diff --git a/pkg/server/datastore/sqlstore/sqlstore_test.go b/pkg/server/datastore/sqlstore/sqlstore_test.go index f13beb4d51..c570a93bf4 100644 --- a/pkg/server/datastore/sqlstore/sqlstore_test.go +++ b/pkg/server/datastore/sqlstore/sqlstore_test.go @@ -250,7 +250,6 @@ func (s *PluginSuite) TestBundleCRUD() { // fetch non-existent fb, err := s.ds.FetchBundle(ctx, "spiffe://foo") - s.T().Logf("err type: %T", err) s.Require().NoError(err) s.Require().Nil(fb) diff --git a/support/oidc-discovery-provider/config.go b/support/oidc-discovery-provider/config.go index 5ced8d7dfd..f0cb81ed69 100644 --- a/support/oidc-discovery-provider/config.go +++ b/support/oidc-discovery-provider/config.go @@ -299,8 +299,13 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { } if c.JWTIssuer != "" { jwtIssuer, err := url.Parse(c.JWTIssuer) - if err != nil || jwtIssuer.Scheme == "" || jwtIssuer.Host == "" { - return nil, errors.New("the jwt_issuer url could not be parsed") + switch { + case err != nil: + return nil, fmt.Errorf("the jwt_issuer url could not be parsed: %w", err) + case jwtIssuer.Scheme == "": + return nil, errors.New("the jwt_issuer url must contain a scheme") + case jwtIssuer.Host == "": + return nil, errors.New("the jwt_issuer url must contain a host") } } return c, nil diff --git a/support/oidc-discovery-provider/config_posix_test.go b/support/oidc-discovery-provider/config_posix_test.go index bba9706483..3401c6818f 100644 --- a/support/oidc-discovery-provider/config_posix_test.go +++ b/support/oidc-discovery-provider/config_posix_test.go @@ -697,7 +697,7 @@ func parseConfigCasesOS() []parseConfigCase { address = "unix:///some/socket/path" } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a scheme", }, { name: "JWT issuer with missing host", @@ -712,7 +712,7 @@ func parseConfigCasesOS() []parseConfigCase { address = "unix:///some/socket/path" } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a host", }, { name: "JWT issuer is invalid", diff --git a/support/oidc-discovery-provider/config_test.go b/support/oidc-discovery-provider/config_test.go index ebe3e6b1a4..f194d3fdc6 100644 --- a/support/oidc-discovery-provider/config_test.go +++ b/support/oidc-discovery-provider/config_test.go @@ -27,7 +27,7 @@ func TestLoadConfig(t *testing.T) { require.Error(err) require.Contains(err.Error(), "unable to load configuration:") - err = os.WriteFile(confPath, []byte(minimalEnvServerAPIConfig), 0600) + err = os.WriteFile(confPath, []byte(minimalEnvServerAPIConfig), 0o600) require.NoError(err) os.Setenv("SPIFFE_TRUST_DOMAIN", "domain.test") @@ -45,7 +45,7 @@ func TestLoadConfig(t *testing.T) { ServerAPI: serverAPIConfig, }, config) - err = os.WriteFile(confPath, []byte(minimalServerAPIConfig), 0600) + err = os.WriteFile(confPath, []byte(minimalServerAPIConfig), 0o600) require.NoError(err) config, err = LoadConfig(confPath, false) diff --git a/support/oidc-discovery-provider/config_windows_test.go b/support/oidc-discovery-provider/config_windows_test.go index 728b81f440..7fd2efc266 100644 --- a/support/oidc-discovery-provider/config_windows_test.go +++ b/support/oidc-discovery-provider/config_windows_test.go @@ -645,7 +645,7 @@ func parseConfigCasesOS() []parseConfigCase { } } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a scheme", }, { name: "JWT issuer with missing host", @@ -663,7 +663,7 @@ func parseConfigCasesOS() []parseConfigCase { } } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a host", }, { name: "JWT issuer is invalid",