Skip to content

Commit

Permalink
Use generics to stub a single value with auto-reset (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
prashantv authored Oct 21, 2024
1 parent 72a167a commit c8bd471
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
1 change: 1 addition & 0 deletions gostub.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
// Stub replaces the value stored at varToStub with stubVal.
// varToStub must be a pointer to the variable. stubVal should have a type
// that is assignable to the variable.
// When stubbing a single value, prefer `Value`.
func Stub(varToStub interface{}, stubVal interface{}) *Stubs {
return New().Stub(varToStub, stubVal)
}
Expand Down
30 changes: 27 additions & 3 deletions gostub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,38 @@ func TestStub(t *testing.T) {

stubs := Stub(&v1, 1)

if v1 != 1 {
t.Errorf("expected")
}
expectVal(t, v1, 1)
stubs.Reset()
expectVal(t, v1, 100)
}

func TestValue(t *testing.T) {
resetVars()

t.Run("test", func(t *testing.T) {
reset1 := Value(t, &v1, 1)
reset2 := Value(t, &v2, 2)
expectVal(t, v1, 1)
expectVal(t, v2, 2)
reset1()
expectVal(t, v1, 100)
expectVal(t, v2, 2)
reset2()
expectVal(t, v1, 100)
expectVal(t, v2, 200)

Value(t, &v1, 0)
Value(t, &v2, 0)
Value(t, &v3, 0)
})

t.Run("verify Cleanup", func(t *testing.T) {
expectVal(t, v1, 100)
expectVal(t, v2, 200)
expectVal(t, v3, 300)
})
}

func TestRestub(t *testing.T) {
resetVars()

Expand Down
20 changes: 20 additions & 0 deletions value.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package gostub

// TestingT is a subset of the testing.TB interface used by gostub.
type TestingT interface {
Cleanup(func())
}

// Value replaces the value at varPtr with stubVal.
// The original value is reset at the end of the test via t.Cleanup
// or can be reset using the returned function.
func Value[T any](t TestingT, varPtr *T, stubVal T) (reset func()) {
orig := *varPtr
*varPtr = stubVal

reset = func() {
*varPtr = orig
}
t.Cleanup(reset)
return reset
}

0 comments on commit c8bd471

Please sign in to comment.