diff --git a/pkg/cosign/env/env.go b/pkg/cosign/env/env.go index 599df2e0f4a..5c26d4f5168 100644 --- a/pkg/cosign/env/env.go +++ b/pkg/cosign/env/env.go @@ -44,12 +44,13 @@ func (v Variable) String() string { const ( // Cosign environment variables - VariableExperimental Variable = "COSIGN_EXPERIMENTAL" - VariableDockerMediaTypes Variable = "COSIGN_DOCKER_MEDIA_TYPES" - VariablePassword Variable = "COSIGN_PASSWORD" - VariablePKCS11Pin Variable = "COSIGN_PKCS11_PIN" - VariablePKCS11ModulePath Variable = "COSIGN_PKCS11_MODULE_PATH" - VariableRepository Variable = "COSIGN_REPOSITORY" + VariableExperimental Variable = "COSIGN_EXPERIMENTAL" + VariableDockerMediaTypes Variable = "COSIGN_DOCKER_MEDIA_TYPES" + VariablePassword Variable = "COSIGN_PASSWORD" + VariablePKCS11Pin Variable = "COSIGN_PKCS11_PIN" + VariablePKCS11ModulePath Variable = "COSIGN_PKCS11_MODULE_PATH" + VariablePKCS11IgnoreCertificate Variable = "COSIGN_PKCS11_IGNORE_CERTIFICATE" + VariableRepository Variable = "COSIGN_REPOSITORY" // Sigstore environment variables VariableSigstoreCTLogPublicKeyFile Variable = "SIGSTORE_CT_LOG_PUBLIC_KEY_FILE" @@ -102,6 +103,11 @@ var ( Expects: "string with a module-path", Sensitive: false, }, + VariablePKCS11IgnoreCertificate: { + Description: "disables loading certificates with PKCS11", + Expects: "1 if loading certificates should be disabled (0 by default)", + Sensitive: false, + }, VariableRepository: { Description: "can be used to store signatures in an alternate location", Expects: "string with a repository", diff --git a/pkg/cosign/pkcs11key/pkcs11key.go b/pkg/cosign/pkcs11key/pkcs11key.go index 0c735c82f80..c034a3f4fac 100644 --- a/pkg/cosign/pkcs11key/pkcs11key.go +++ b/pkg/cosign/pkcs11key/pkcs11key.go @@ -177,10 +177,15 @@ func GetKeyWithURIConfig(config *Pkcs11UriConfig, askForPinIfNeeded bool) (*Key, // Key's corresponding cert might not exist, // therefore, we do not fail if it is the case. var cert *x509.Certificate - if len(config.KeyID) != 0 { - cert, _ = ctx.FindCertificate(config.KeyID, nil, nil) - } else if len(config.KeyLabel) != 0 { - cert, _ = ctx.FindCertificate(nil, config.KeyLabel, nil) + + ignoreCert := env.Getenv(env.VariablePKCS11IgnoreCertificate) == "1" + + if !ignoreCert { + if len(config.KeyID) != 0 { + cert, _ = ctx.FindCertificate(config.KeyID, nil, nil) + } else if len(config.KeyLabel) != 0 { + cert, _ = ctx.FindCertificate(nil, config.KeyLabel, nil) + } } return &Key{ctx: ctx, signer: signer, cert: cert}, nil diff --git a/test/pkcs11_test.go b/test/pkcs11_test.go index 9b65e6c229c..bfe031b9d1c 100644 --- a/test/pkcs11_test.go +++ b/test/pkcs11_test.go @@ -45,6 +45,7 @@ import ( "encoding/hex" "encoding/pem" "fmt" + "io" "math/big" "os" "strings" @@ -214,6 +215,64 @@ func TestListKeysUrisCmd(t *testing.T) { } } +func TestCertificateIgnored(t *testing.T) { + ctx := context.Background() + + tokens, err := GetTokens(ctx, modulePath) + if err != nil { + t.Fatal(err) + } + + bTokenFound := false + var slotID uint + for _, token := range tokens { + if token.TokenInfo.Label == tokenLabel { + bTokenFound = true + slotID = token.Slot + break + } + } + if !bTokenFound { + t.Fatalf("token with label '%s' not found", tokenLabel) + } + + err = importKey(slotID) + if err != nil { + t.Fatal(err) + } + defer deleteKey(slotID) + + pkcs11UriConfig := pkcs11key.NewPkcs11UriConfig() + err = pkcs11UriConfig.Parse(uri) + if err != nil { + t.Fatal(err) + } + + const envvar = "COSIGN_PKCS11_IGNORE_CERTIFICATE" + + if err := os.Setenv(envvar, "1"); err != nil { + t.Fatal(err) + } + + defer os.Setenv(envvar, "") + + sk, err := pkcs11key.GetKeyWithURIConfig(pkcs11UriConfig, true) + if err != nil { + t.Fatal(err) + } + + defer sk.Close() + + cert, err := sk.Certificate() + if err != nil { + t.Fatal(err) + } + + if cert != nil { + t.Fatalf("expected certificate to be ignored while loading") + } +} + func TestSignAndVerify(t *testing.T) { ctx := context.Background() @@ -350,7 +409,7 @@ func importKey(slotID uint) error { keyLabelBytes := []byte(keyLabel) r := strings.NewReader(rsaPrivKey) - pemBytes, err = os.ReadAll(r) + pemBytes, err = io.ReadAll(r) if err != nil { return fmt.Errorf("unable to read pem") }