Skip to content

Commit d0821c2

Browse files
committed
Add wrapper to store and retrieve values in context
1 parent 24ec2f9 commit d0821c2

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

context.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package pipeline
2+
3+
import (
4+
"context"
5+
"errors"
6+
"sync"
7+
)
8+
9+
type contextKey struct{}
10+
11+
// VariableContext adds a map to the given context that can be used to store intermediate values in the context.
12+
// It uses sync.Map under the hood.
13+
//
14+
// See also AddToContext() and ValueFromContext.
15+
func VariableContext(parent context.Context) context.Context {
16+
return context.WithValue(parent, contextKey{}, &sync.Map{})
17+
}
18+
19+
// AddToContext adds the given key and value to ctx.
20+
// Any keys or values added during pipeline execution is available in the next steps, provided the pipeline runs synchronously.
21+
// In parallel executed pipelines you may encounter race conditions.
22+
// Use ValueFromContext to retrieve values.
23+
//
24+
// Note: This method is thread-safe, but panics if ctx has not been set up with VariableContext first.
25+
func AddToContext(ctx context.Context, key, value interface{}) {
26+
m := ctx.Value(contextKey{})
27+
if m == nil {
28+
panic(errors.New("context was not set up with VariableContext()"))
29+
}
30+
m.(*sync.Map).Store(key, value)
31+
}
32+
33+
// ValueFromContext returns the value from the given context with the given key.
34+
// It returns the value and true, or nil and false if the key doesn't exist.
35+
// It may return nil and true if the key exists, but the value actually is nil.
36+
// Use AddToContext to store values.
37+
//
38+
// Note: This method is thread-safe, but panics if the ctx has not been set up with VariableContext first.
39+
func ValueFromContext(ctx context.Context, key interface{}) (interface{}, bool) {
40+
m := ctx.Value(contextKey{})
41+
if m == nil {
42+
panic(errors.New("context was not set up with VariableContext()"))
43+
}
44+
mp := m.(*sync.Map)
45+
val, found := mp.Load(key)
46+
return val, found
47+
}

context_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package pipeline
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestContext(t *testing.T) {
12+
tests := map[string]struct {
13+
givenKey interface{}
14+
givenValue interface{}
15+
expectedValue interface{}
16+
expectedFound bool
17+
}{
18+
"GivenNonExistentKey_ThenExpectNilAndFalse": {
19+
givenKey: nil,
20+
expectedValue: nil,
21+
},
22+
"GivenKeyWithNilValue_ThenExpectNilAndTrue": {
23+
givenKey: "key",
24+
givenValue: nil,
25+
expectedValue: nil,
26+
expectedFound: true,
27+
},
28+
"GivenKeyWithValue_ThenExpectValueAndTrue": {
29+
givenKey: "key",
30+
givenValue: "value",
31+
expectedValue: "value",
32+
expectedFound: true,
33+
},
34+
}
35+
for name, tc := range tests {
36+
t.Run(name, func(t *testing.T) {
37+
ctx := VariableContext(context.Background())
38+
if tc.givenKey != nil {
39+
AddToContext(ctx, tc.givenKey, tc.givenValue)
40+
}
41+
result, found := ValueFromContext(ctx, tc.givenKey)
42+
assert.Equal(t, tc.expectedValue, result, "value")
43+
assert.Equal(t, tc.expectedFound, found, "value found")
44+
})
45+
}
46+
}
47+
48+
func TestContextPanics(t *testing.T) {
49+
assert.PanicsWithError(t, "context was not set up with VariableContext()", func() {
50+
AddToContext(context.Background(), "key", "value")
51+
}, "AddToContext")
52+
assert.PanicsWithError(t, "context was not set up with VariableContext()", func() {
53+
ValueFromContext(context.Background(), "key")
54+
}, "ValueFromContext")
55+
}
56+
57+
func ExampleVariableContext() {
58+
ctx := VariableContext(context.Background())
59+
p := NewPipeline().WithSteps(
60+
NewStepFromFunc("store value", func(ctx context.Context) error {
61+
AddToContext(ctx, "key", "value")
62+
return nil
63+
}),
64+
NewStepFromFunc("retrieve value", func(ctx context.Context) error {
65+
value, _ := ValueFromContext(ctx, "key")
66+
fmt.Println(value)
67+
return nil
68+
}),
69+
)
70+
p.RunWithContext(ctx)
71+
// Output: value
72+
}

0 commit comments

Comments
 (0)