diff --git a/sdk/keyvault/azkeys/client_test.go b/sdk/keyvault/azkeys/client_test.go index a27e854a759f..747111b7caee 100644 --- a/sdk/keyvault/azkeys/client_test.go +++ b/sdk/keyvault/azkeys/client_test.go @@ -39,9 +39,7 @@ func TestConstructor(t *testing.T) { func TestCreateKeyRSA(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -68,8 +66,7 @@ func TestCreateKeyRSA(t *testing.T) { } func TestCreateKeyRSATags(t *testing.T) { - stop := startTest(t) - defer stop() + startTest(t, REGULARTEST) client, err := createClient(t, REGULARTEST) require.NoError(t, err) @@ -98,9 +95,7 @@ func TestCreateKeyRSATags(t *testing.T) { func TestCreateECKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -124,9 +119,7 @@ func TestCreateECKey(t *testing.T) { func TestCreateOCTKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -154,9 +147,7 @@ func TestCreateOCTKey(t *testing.T) { func TestListKeys(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -193,9 +184,7 @@ func TestListKeys(t *testing.T) { func TestGetKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -221,9 +210,7 @@ func TestGetKey(t *testing.T) { func TestDeleteKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -271,8 +258,7 @@ func TestDeleteKey(t *testing.T) { } func TestBeginDeleteKeyRehydrate(t *testing.T) { - stop := startTest(t) - defer stop() + startTest(t, REGULARTEST) client, err := createClient(t, testTypes[0]) require.NoError(t, err) @@ -321,9 +307,7 @@ func TestBeginDeleteKeyRehydrate(t *testing.T) { func TestBackupKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -385,9 +369,7 @@ func TestBackupKey(t *testing.T) { func TestRecoverDeletedKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -428,9 +410,7 @@ func TestRecoverDeletedKey(t *testing.T) { func TestUpdateKeyProperties(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) err := recording.SetBodilessMatcher(t, nil) require.NoError(t, err) @@ -466,8 +446,7 @@ func TestUpdateKeyProperties(t *testing.T) { func TestUpdateKeyPropertiesImmutable(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -519,9 +498,7 @@ func TestUpdateKeyPropertiesImmutable(t *testing.T) { func TestListDeletedKeys(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -575,9 +552,7 @@ func TestListDeletedKeys(t *testing.T) { func TestListKeyVersions(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -608,9 +583,7 @@ func TestListKeyVersions(t *testing.T) { func TestImportKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -645,9 +618,7 @@ func TestGetRandomBytes(t *testing.T) { if testType == REGULARTEST { t.Skip("Managed HSM Only") } - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -665,9 +636,7 @@ func TestGetRandomBytes(t *testing.T) { func TestGetDeletedKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -696,9 +665,7 @@ func TestGetDeletedKey(t *testing.T) { func TestRotateKey(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -730,9 +697,7 @@ func TestRotateKey(t *testing.T) { func TestGetKeyRotationPolicy(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -759,9 +724,7 @@ func TestReleaseKey(t *testing.T) { tn += "_latest" } t.Run(tn, func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -823,9 +786,7 @@ func TestReleaseKey(t *testing.T) { func TestUpdateKeyRotationPolicy(t *testing.T) { for _, testType := range testTypes { t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) { - skipHSM(t, testType) - stop := startTest(t) - defer stop() + startTest(t, testType) client, err := createClient(t, testType) require.NoError(t, err) @@ -858,8 +819,7 @@ func TestUpdateKeyRotationPolicy(t *testing.T) { } func TestClient_EncryptDecrypt(t *testing.T) { - stop := startTest(t) - defer stop() + startTest(t, REGULARTEST) keyName, err := createRandomName(t, "key") require.NoError(t, err) @@ -881,8 +841,7 @@ func TestClient_EncryptDecrypt(t *testing.T) { } func TestClient_WrapUnwrap(t *testing.T) { - stop := startTest(t) - defer stop() + startTest(t, REGULARTEST) keyName, err := createRandomName(t, "key") require.NoError(t, err) @@ -908,8 +867,7 @@ func TestClient_WrapUnwrap(t *testing.T) { } func TestClient_SignVerify(t *testing.T) { - stop := startTest(t) - defer stop() + startTest(t, REGULARTEST) keyName, err := createRandomName(t, "key") require.NoError(t, err) diff --git a/sdk/keyvault/azkeys/utils_test.go b/sdk/keyvault/azkeys/utils_test.go index 23c6ac203f20..898ff8d7df91 100644 --- a/sdk/keyvault/azkeys/utils_test.go +++ b/sdk/keyvault/azkeys/utils_test.go @@ -121,20 +121,16 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func startTest(t *testing.T) func() { +func startTest(t *testing.T, testType string) { + if recording.GetRecordMode() != recording.PlaybackMode && testType == HSMTEST && !enableHSM { + t.Skip("set AZURE_MANAGEDHSM_URL to run this test") + } err := recording.Start(t, pathToPackage, nil) require.NoError(t, err) - return func() { + t.Cleanup(func() { err := recording.Stop(t, nil) require.NoError(t, err) - } -} - -// skipHSM skips live MHSM tests when AZURE_MANAGEDHSM_URL has no value -func skipHSM(t *testing.T, testType string) { - if recording.GetRecordMode() != recording.PlaybackMode && testType == HSMTEST && !enableHSM { - t.Skip("set AZURE_MANAGEDHSM_URL to run this test") - } + }) } func createRandomName(t *testing.T, prefix string) (string, error) {