diff --git a/featuregate/testing/feature_gate.go b/featuregate/testing/feature_gate.go index 907906ee..9884df0a 100644 --- a/featuregate/testing/feature_gate.go +++ b/featuregate/testing/feature_gate.go @@ -18,18 +18,35 @@ package testing import ( "fmt" + "strings" + "sync" "testing" "k8s.io/component-base/featuregate" ) -// SetFeatureGateDuringTest sets the specified gate to the specified value, and returns a function that restores the original value. -// Failures to set or restore cause the test to fail. +var ( + overrideLock sync.Mutex + featureFlagOverride map[featuregate.Feature]string +) + +func init() { + featureFlagOverride = map[featuregate.Feature]string{} +} + +// SetFeatureGateDuringTest sets the specified gate to the specified value for duration of the test. +// Fails when it detects second call to the same flag or is unable to set or restore feature flag. +// Returns empty cleanup function to maintain the old function signature that uses defer. +// TODO: Remove defer from calls to SetFeatureGateDuringTest and update hack/verify-test-featuregates.sh when we can do large scale code change. +// +// WARNING: Can leak set variable when called in test calling t.Parallel(), however second attempt to set the same feature flag will cause fatal. // // Example use: // // defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features., true)() func SetFeatureGateDuringTest(tb testing.TB, gate featuregate.FeatureGate, f featuregate.Feature, value bool) func() { + tb.Helper() + detectParallelOverrideCleanup := detectParallelOverride(tb, f) originalValue := gate.Enabled(f) // Specially handle AllAlpha and AllBeta @@ -50,12 +67,41 @@ func SetFeatureGateDuringTest(tb testing.TB, gate featuregate.FeatureGate, f fea tb.Errorf("error setting %s=%v: %v", f, value, err) } - return func() { + tb.Cleanup(func() { + tb.Helper() + detectParallelOverrideCleanup() if err := gate.(featuregate.MutableFeatureGate).Set(fmt.Sprintf("%s=%v", f, originalValue)); err != nil { tb.Errorf("error restoring %s=%v: %v", f, originalValue, err) } for _, cleanup := range cleanups { cleanup() } + }) + return func() {} +} + +func detectParallelOverride(tb testing.TB, f featuregate.Feature) func() { + tb.Helper() + overrideLock.Lock() + defer overrideLock.Unlock() + beforeOverrideTestName := featureFlagOverride[f] + if beforeOverrideTestName != "" && !sameTestOrSubtest(tb, beforeOverrideTestName) { + tb.Fatalf("Detected parallel setting of a feature gate by both %q and %q", beforeOverrideTestName, tb.Name()) + } + featureFlagOverride[f] = tb.Name() + + return func() { + tb.Helper() + overrideLock.Lock() + defer overrideLock.Unlock() + if afterOverrideTestName := featureFlagOverride[f]; afterOverrideTestName != tb.Name() { + tb.Fatalf("Detected parallel setting of a feature gate between both %q and %q", afterOverrideTestName, tb.Name()) + } + featureFlagOverride[f] = beforeOverrideTestName } } + +func sameTestOrSubtest(tb testing.TB, testName string) bool { + // Assumes that "/" is not used in test names. + return tb.Name() == testName || strings.HasPrefix(tb.Name(), testName+"/") +} diff --git a/featuregate/testing/feature_gate_test.go b/featuregate/testing/feature_gate_test.go index d16d2106..a8ea8782 100644 --- a/featuregate/testing/feature_gate_test.go +++ b/featuregate/testing/feature_gate_test.go @@ -19,6 +19,8 @@ package testing import ( gotest "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "k8s.io/component-base/featuregate" ) @@ -67,8 +69,11 @@ func TestSpecialGates(t *gotest.T) { "stable_default_off_set_on": true, } expect(t, gate, before) + t.Cleanup(func() { + expect(t, gate, before) + }) - cleanupAlpha := SetFeatureGateDuringTest(t, gate, "AllAlpha", true) + SetFeatureGateDuringTest(t, gate, "AllAlpha", true) expect(t, gate, map[featuregate.Feature]bool{ "AllAlpha": true, "AllBeta": false, @@ -89,7 +94,7 @@ func TestSpecialGates(t *gotest.T) { "stable_default_off_set_on": true, }) - cleanupBeta := SetFeatureGateDuringTest(t, gate, "AllBeta", true) + SetFeatureGateDuringTest(t, gate, "AllBeta", true) expect(t, gate, map[featuregate.Feature]bool{ "AllAlpha": true, "AllBeta": true, @@ -109,11 +114,6 @@ func TestSpecialGates(t *gotest.T) { "stable_default_off": false, "stable_default_off_set_on": true, }) - - // run cleanups in reverse order like defer would - cleanupBeta() - cleanupAlpha() - expect(t, gate, before) } func expect(t *gotest.T, gate featuregate.FeatureGate, expect map[featuregate.Feature]bool) { @@ -124,3 +124,127 @@ func expect(t *gotest.T, gate featuregate.FeatureGate, expect map[featuregate.Fe } } } + +func TestSetFeatureGateInTest(t *gotest.T) { + gate := featuregate.NewFeatureGate() + err := gate.Add(map[featuregate.Feature]featuregate.FeatureSpec{ + "feature": {PreRelease: featuregate.Alpha, Default: false}, + }) + require.NoError(t, err) + + assert.False(t, gate.Enabled("feature")) + defer SetFeatureGateDuringTest(t, gate, "feature", true)() + defer SetFeatureGateDuringTest(t, gate, "feature", true)() + + assert.True(t, gate.Enabled("feature")) + t.Run("Subtest", func(t *gotest.T) { + assert.True(t, gate.Enabled("feature")) + }) + + t.Run("ParallelSubtest", func(t *gotest.T) { + assert.True(t, gate.Enabled("feature")) + // Calling t.Parallel in subtest will resume the main test body + t.Parallel() + assert.True(t, gate.Enabled("feature")) + }) + assert.True(t, gate.Enabled("feature")) + + t.Run("OverwriteInSubtest", func(t *gotest.T) { + defer SetFeatureGateDuringTest(t, gate, "feature", false)() + assert.False(t, gate.Enabled("feature")) + }) + assert.True(t, gate.Enabled("feature")) +} + +func TestDetectLeakToMainTest(t *gotest.T) { + t.Cleanup(func() { + featureFlagOverride = map[featuregate.Feature]string{} + }) + gate := featuregate.NewFeatureGate() + err := gate.Add(map[featuregate.Feature]featuregate.FeatureSpec{ + "feature": {PreRelease: featuregate.Alpha, Default: false}, + }) + require.NoError(t, err) + + // Subtest setting feature gate and calling parallel will leak it out + t.Run("LeakingSubtest", func(t *gotest.T) { + fakeT := &ignoreFatalT{T: t} + defer SetFeatureGateDuringTest(fakeT, gate, "feature", true)() + // Calling t.Parallel in subtest will resume the main test body + t.Parallel() + // Leaked false from main test + assert.False(t, gate.Enabled("feature")) + }) + // Leaked true from subtest + assert.True(t, gate.Enabled("feature")) + fakeT := &ignoreFatalT{T: t} + defer SetFeatureGateDuringTest(fakeT, gate, "feature", false)() + assert.True(t, fakeT.fatalRecorded) +} + +func TestDetectLeakToOtherSubtest(t *gotest.T) { + t.Cleanup(func() { + featureFlagOverride = map[featuregate.Feature]string{} + }) + gate := featuregate.NewFeatureGate() + err := gate.Add(map[featuregate.Feature]featuregate.FeatureSpec{ + "feature": {PreRelease: featuregate.Alpha, Default: false}, + }) + require.NoError(t, err) + + subtestName := "Subtest" + // Subtest setting feature gate and calling parallel will leak it out + t.Run(subtestName, func(t *gotest.T) { + fakeT := &ignoreFatalT{T: t} + defer SetFeatureGateDuringTest(fakeT, gate, "feature", true)() + t.Parallel() + }) + // Add suffix to name to prevent tests with the same prefix. + t.Run(subtestName+"Suffix", func(t *gotest.T) { + // Leaked true + assert.True(t, gate.Enabled("feature")) + + fakeT := &ignoreFatalT{T: t} + defer SetFeatureGateDuringTest(fakeT, gate, "feature", false)() + assert.True(t, fakeT.fatalRecorded) + }) +} + +func TestCannotDetectLeakFromSubtest(t *gotest.T) { + t.Cleanup(func() { + featureFlagOverride = map[featuregate.Feature]string{} + }) + gate := featuregate.NewFeatureGate() + err := gate.Add(map[featuregate.Feature]featuregate.FeatureSpec{ + "feature": {PreRelease: featuregate.Alpha, Default: false}, + }) + require.NoError(t, err) + + defer SetFeatureGateDuringTest(t, gate, "feature", false)() + // Subtest setting feature gate and calling parallel will leak it out + t.Run("Subtest", func(t *gotest.T) { + defer SetFeatureGateDuringTest(t, gate, "feature", true)() + t.Parallel() + }) + // Leaked true + assert.True(t, gate.Enabled("feature")) +} + +type ignoreFatalT struct { + *gotest.T + fatalRecorded bool +} + +func (f *ignoreFatalT) Fatal(args ...any) { + f.T.Helper() + f.fatalRecorded = true + newArgs := []any{"[IGNORED]"} + newArgs = append(newArgs, args...) + f.T.Log(newArgs...) +} + +func (f *ignoreFatalT) Fatalf(format string, args ...any) { + f.T.Helper() + f.fatalRecorded = true + f.T.Logf("[IGNORED] "+format, args...) +}