Skip to content

Commit

Permalink
Prevent azkeys tests incorrectly targeting MHSM (#18318)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jun 6, 2022
1 parent 7fc129d commit 30ff870
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 76 deletions.
90 changes: 24 additions & 66 deletions sdk/keyvault/azkeys/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ func TestConstructor(t *testing.T) {
func TestCreateKeyRSA(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand All @@ -68,8 +66,7 @@ func TestCreateKeyRSA(t *testing.T) {
}

func TestCreateKeyRSATags(t *testing.T) {
stop := startTest(t)
defer stop()
startTest(t, REGULARTEST)

client, err := createClient(t, REGULARTEST)
require.NoError(t, err)
Expand Down Expand Up @@ -98,9 +95,7 @@ func TestCreateKeyRSATags(t *testing.T) {
func TestCreateECKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand All @@ -124,9 +119,7 @@ func TestCreateECKey(t *testing.T) {
func TestCreateOCTKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -154,9 +147,7 @@ func TestCreateOCTKey(t *testing.T) {
func TestListKeys(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -193,9 +184,7 @@ func TestListKeys(t *testing.T) {
func TestGetKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand All @@ -221,9 +210,7 @@ func TestGetKey(t *testing.T) {
func TestDeleteKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -271,8 +258,7 @@ func TestDeleteKey(t *testing.T) {
}

func TestBeginDeleteKeyRehydrate(t *testing.T) {
stop := startTest(t)
defer stop()
startTest(t, REGULARTEST)

client, err := createClient(t, testTypes[0])
require.NoError(t, err)
Expand Down Expand Up @@ -321,9 +307,7 @@ func TestBeginDeleteKeyRehydrate(t *testing.T) {
func TestBackupKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -385,9 +369,7 @@ func TestBackupKey(t *testing.T) {
func TestRecoverDeletedKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -428,9 +410,7 @@ func TestRecoverDeletedKey(t *testing.T) {
func TestUpdateKeyProperties(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)
err := recording.SetBodilessMatcher(t, nil)
require.NoError(t, err)

Expand Down Expand Up @@ -466,8 +446,7 @@ func TestUpdateKeyProperties(t *testing.T) {
func TestUpdateKeyPropertiesImmutable(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -519,9 +498,7 @@ func TestUpdateKeyPropertiesImmutable(t *testing.T) {
func TestListDeletedKeys(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -575,9 +552,7 @@ func TestListDeletedKeys(t *testing.T) {
func TestListKeyVersions(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -608,9 +583,7 @@ func TestListKeyVersions(t *testing.T) {
func TestImportKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -645,9 +618,7 @@ func TestGetRandomBytes(t *testing.T) {
if testType == REGULARTEST {
t.Skip("Managed HSM Only")
}
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand All @@ -665,9 +636,7 @@ func TestGetRandomBytes(t *testing.T) {
func TestGetDeletedKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -696,9 +665,7 @@ func TestGetDeletedKey(t *testing.T) {
func TestRotateKey(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -730,9 +697,7 @@ func TestRotateKey(t *testing.T) {
func TestGetKeyRotationPolicy(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand All @@ -759,9 +724,7 @@ func TestReleaseKey(t *testing.T) {
tn += "_latest"
}
t.Run(tn, func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -823,9 +786,7 @@ func TestReleaseKey(t *testing.T) {
func TestUpdateKeyRotationPolicy(t *testing.T) {
for _, testType := range testTypes {
t.Run(fmt.Sprintf("%s_%s", t.Name(), testType), func(t *testing.T) {
skipHSM(t, testType)
stop := startTest(t)
defer stop()
startTest(t, testType)

client, err := createClient(t, testType)
require.NoError(t, err)
Expand Down Expand Up @@ -858,8 +819,7 @@ func TestUpdateKeyRotationPolicy(t *testing.T) {
}

func TestClient_EncryptDecrypt(t *testing.T) {
stop := startTest(t)
defer stop()
startTest(t, REGULARTEST)

keyName, err := createRandomName(t, "key")
require.NoError(t, err)
Expand All @@ -881,8 +841,7 @@ func TestClient_EncryptDecrypt(t *testing.T) {
}

func TestClient_WrapUnwrap(t *testing.T) {
stop := startTest(t)
defer stop()
startTest(t, REGULARTEST)

keyName, err := createRandomName(t, "key")
require.NoError(t, err)
Expand All @@ -908,8 +867,7 @@ func TestClient_WrapUnwrap(t *testing.T) {
}

func TestClient_SignVerify(t *testing.T) {
stop := startTest(t)
defer stop()
startTest(t, REGULARTEST)

keyName, err := createRandomName(t, "key")
require.NoError(t, err)
Expand Down
16 changes: 6 additions & 10 deletions sdk/keyvault/azkeys/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,16 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func startTest(t *testing.T) func() {
func startTest(t *testing.T, testType string) {
if recording.GetRecordMode() != recording.PlaybackMode && testType == HSMTEST && !enableHSM {
t.Skip("set AZURE_MANAGEDHSM_URL to run this test")
}
err := recording.Start(t, pathToPackage, nil)
require.NoError(t, err)
return func() {
t.Cleanup(func() {
err := recording.Stop(t, nil)
require.NoError(t, err)
}
}

// skipHSM skips live MHSM tests when AZURE_MANAGEDHSM_URL has no value
func skipHSM(t *testing.T, testType string) {
if recording.GetRecordMode() != recording.PlaybackMode && testType == HSMTEST && !enableHSM {
t.Skip("set AZURE_MANAGEDHSM_URL to run this test")
}
})
}

func createRandomName(t *testing.T, prefix string) (string, error) {
Expand Down

0 comments on commit 30ff870

Please sign in to comment.