Skip to content

Commit

Permalink
Add set state in Go (#22919)
Browse files Browse the repository at this point in the history
* Add set state in Go

* Update sdks/go/pkg/beam/core/state/state_test.go

Co-authored-by: Ritesh Ghorse <riteshghorse@gmail.com>

* Update sdks/go/pkg/beam/core/state/state_test.go

Co-authored-by: Ritesh Ghorse <riteshghorse@gmail.com>

* Remove unneccessary conversion

Co-authored-by: Ritesh Ghorse <riteshghorse@gmail.com>
  • Loading branch information
damccorm and riteshghorse authored Aug 29, 2022
1 parent e9089dd commit 4a66829
Show file tree
Hide file tree
Showing 7 changed files with 345 additions and 16 deletions.
4 changes: 2 additions & 2 deletions sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,10 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
"unique per DoFn", k, orig, s)
}
t := s.StateType()
if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeMap {
if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeSet && t != state.TypeMap {
err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
"types are state.Value, state.Combining, state.Bag, and state.Map", t, s)
"types are state.Value, state.Combining, state.Bag, state.Set, and state.Map", t, s)
}
stateKeys[k] = s
}
Expand Down
9 changes: 9 additions & 0 deletions sdks/go/pkg/beam/core/graph/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func TestNewDoFn(t *testing.T) {
return a * b
})}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn4{State1: state.MakeMapState[string, int]("state1")}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn5{State1: state.MakeSetState[string]("state1")}, opt: NumMainInputs(MainKv)},
}

for _, test := range tests {
Expand Down Expand Up @@ -1116,6 +1117,14 @@ func (fn *GoodStatefulDoFn4) ProcessElement(state.Provider, int, int) int {
return 0
}

type GoodStatefulDoFn5 struct {
State1 state.Set[string]
}

func (fn *GoodStatefulDoFn5) ProcessElement(state.Provider, int, int) int {
return 0
}

// Examples of incorrect SDF signatures.
// Examples with missing methods.

Expand Down
18 changes: 14 additions & 4 deletions sdks/go/pkg/beam/core/runtime/exec/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,21 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
} else if ms := spec.GetMapSpec(); ms != nil {
cID = ms.ValueCoderId
kcID = ms.KeyCoderId
} else if ss := spec.GetSetSpec(); ss != nil {
kcID = ss.ElementCoderId
} else {
return nil, errors.Errorf("Unrecognized state type %v", spec)
}
c, err := b.coders.Coder(cID)
if err != nil {
return nil, err
if cID != "" {
c, err := b.coders.Coder(cID)
if err != nil {
return nil, err
}
stateIDToCoder[key] = c
} else {
// If no value coder is provided, we are in a keyed state with no values (aka a set).
// We represent a set as an element mapping to a bool representing if it is present or not.
stateIDToCoder[key] = &coder.Coder{Kind: coder.Bool}
}
if kcID != "" {
kc, err := b.coders.Coder(kcID)
Expand All @@ -507,7 +518,6 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
}
stateIDToKeyCoder[key] = kc
}
stateIDToCoder[key] = c
sid := StreamID{
Port: Port{URL: b.desc.GetStateApiServiceDescriptor().GetUrl()},
PtransformID: id.to,
Expand Down
25 changes: 20 additions & 5 deletions sdks/go/pkg/beam/core/runtime/graphx/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,18 +467,22 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
m.requirements[URNRequiresStatefulProcessing] = true
stateSpecs := make(map[string]*pipepb.StateSpec)
for _, ps := range edge.Edge.DoFn.PipelineState() {
coderID, err := m.coders.Add(edge.Edge.StateCoders[UserStateCoderId(ps)])
if err != nil {
return handleErr(err)
coderID := ""
c, ok := edge.Edge.StateCoders[UserStateCoderId(ps)]
if ok {
coderID, err = m.coders.Add(c)
if err != nil {
return handleErr(err)
}
}
keyCoderID := ""
if c, ok := edge.Edge.StateCoders[UserStateKeyCoderId(ps)]; ok {
keyCoderID, err = m.coders.Add(c)
if err != nil {
return handleErr(err)
}
} else if ps.StateType() == state.TypeMap {
return nil, errors.Errorf("Map type %v must have a key coder type, none detected", ps)
} else if ps.StateType() == state.TypeMap || ps.StateType() == state.TypeSet {
return nil, errors.Errorf("set or map state type %v must have a key coder type, none detected", ps)
}
switch ps.StateType() {
case state.TypeValue:
Expand Down Expand Up @@ -547,6 +551,17 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
Urn: URNMultiMapUserState,
},
}
case state.TypeSet:
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
Spec: &pipepb.StateSpec_SetSpec{
SetSpec: &pipepb.SetStateSpec{
ElementCoderId: keyCoderID,
},
},
Protocol: &pipepb.FunctionSpec{
Urn: URNMultiMapUserState,
},
}
default:
return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
}
Expand Down
125 changes: 125 additions & 0 deletions sdks/go/pkg/beam/core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ const (
TypeCombining TypeEnum = 2
// TypeMap represents a map state
TypeMap TypeEnum = 3
// TypeSet represents a set state
TypeSet TypeEnum = 4
)

var (
Expand Down Expand Up @@ -517,3 +519,126 @@ func MakeMapState[K comparable, V any](k string) Map[K, V] {
Key: k,
}
}

// Set is used to read and write global pipeline state representing a Set.
// Key represents the key used to lookup this state (not the key of Set entries).
type Set[K comparable] struct {
Key string
}

// Add is used to write a key to this instance of global Set state.
func (s *Set[K]) Add(p Provider, key K) error {
return p.WriteMapState(Transaction{
Key: s.Key,
Type: TransactionTypeSet,
MapKey: key,
Val: true,
})
}

// Keys is used to read the keys of this set state.
// When a value is not found, returns an empty list and false.
func (s *Set[K]) Keys(p Provider) ([]K, bool, error) {
// This replays any writes that have happened to this value since we last read
// For more detail, see "State Transactionality" below for buffered transactions
initialValue, bufferedTransactions, err := p.ReadMapStateKeys(s.Key)
if err != nil {
return []K{}, false, err
}
cur := []K{}
for _, v := range initialValue {
cur = append(cur, v.(K))
}
for _, t := range bufferedTransactions {
switch t.Type {
case TransactionTypeSet:
seen := false
mk := t.MapKey.(K)
for _, k := range cur {
if k == mk {
seen = true
}
}
if !seen {
cur = append(cur, mk)
}
case TransactionTypeClear:
if t.MapKey == nil {
cur = []K{}
} else {
k := t.MapKey.(K)
for idx, v := range cur {
if v == k {
// Remove this key since its been cleared
cur[idx] = cur[len(cur)-1]
cur = cur[:len(cur)-1]
break
}
}
}
}
}
if len(cur) == 0 {
return cur, false, nil
}
return cur, true, nil
}

// Contains is used to determine if a given a key exists in the set.
func (s *Set[K]) Contains(p Provider, key K) (bool, error) {
// This replays any writes that have happened to this value since we last read
// For more detail, see "State Transactionality" below for buffered transactions
cur, bufferedTransactions, err := p.ReadMapStateValue(s.Key, key)
if err != nil {
return false, err
}
for _, t := range bufferedTransactions {
switch t.Type {
case TransactionTypeSet:
if t.MapKey.(K) == key {
cur = t.Val
}
case TransactionTypeClear:
if t.MapKey == nil || t.MapKey.(K) == key {
cur = nil
}
}
}
if cur == nil {
return false, nil
}
return true, nil
}

// StateKey returns the key for this pipeline state entry.
func (s Set[K]) StateKey() string {
if s.Key == "" {
// TODO(#22736) - infer the state from the member variable name during pipeline construction.
panic("Value state exists on struct but has not been initialized with a key.")
}
return s.Key
}

// KeyCoderType returns the type of the value state which should be used for a coder for set keys.
func (s Set[K]) KeyCoderType() reflect.Type {
var k K
return reflect.TypeOf(k)
}

// CoderType returns the type of the coder used for values, in this case nil since there are no values associated with a set.
func (s Set[K]) CoderType() reflect.Type {
// A bool coder is used later, but it does not need to be passed around or visible to users.
return nil
}

// StateType returns the type of the state (in this case always Set).
func (s Set[K]) StateType() TypeEnum {
return TypeSet
}

// MakeSetState is a factory function to create an instance of SetState with the given key.
func MakeSetState[K comparable](k string) Set[K] {
return Set[K]{
Key: k,
}
}
Loading

0 comments on commit 4a66829

Please sign in to comment.