Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add set state in Go #22919

Merged
merged 4 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,10 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
"unique per DoFn", k, orig, s)
}
t := s.StateType()
if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeMap {
if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeSet && t != state.TypeMap {
err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
"types are state.Value, state.Combining, state.Bag, and state.Map", t, s)
"types are state.Value, state.Combining, state.Bag, state.Set, and state.Map", t, s)
}
stateKeys[k] = s
}
Expand Down
9 changes: 9 additions & 0 deletions sdks/go/pkg/beam/core/graph/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func TestNewDoFn(t *testing.T) {
return a * b
})}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn4{State1: state.MakeMapState[string, int]("state1")}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn5{State1: state.MakeSetState[string]("state1")}, opt: NumMainInputs(MainKv)},
}

for _, test := range tests {
Expand Down Expand Up @@ -1116,6 +1117,14 @@ func (fn *GoodStatefulDoFn4) ProcessElement(state.Provider, int, int) int {
return 0
}

type GoodStatefulDoFn5 struct {
State1 state.Set[string]
}

func (fn *GoodStatefulDoFn5) ProcessElement(state.Provider, int, int) int {
return 0
}

// Examples of incorrect SDF signatures.
// Examples with missing methods.

Expand Down
18 changes: 14 additions & 4 deletions sdks/go/pkg/beam/core/runtime/exec/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,21 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
} else if ms := spec.GetMapSpec(); ms != nil {
cID = ms.ValueCoderId
kcID = ms.KeyCoderId
} else if ss := spec.GetSetSpec(); ss != nil {
kcID = ss.ElementCoderId
} else {
return nil, errors.Errorf("Unrecognized state type %v", spec)
}
c, err := b.coders.Coder(cID)
if err != nil {
return nil, err
if cID != "" {
c, err := b.coders.Coder(cID)
if err != nil {
return nil, err
}
stateIDToCoder[key] = c
} else {
// If no value coder is provided, we are in a keyed state with no values (aka a set).
// We represent a set as an element mapping to a bool representing if it is present or not.
stateIDToCoder[key] = &coder.Coder{Kind: coder.Bool}
}
if kcID != "" {
kc, err := b.coders.Coder(kcID)
Expand All @@ -507,7 +518,6 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
}
stateIDToKeyCoder[key] = kc
}
stateIDToCoder[key] = c
sid := StreamID{
Port: Port{URL: b.desc.GetStateApiServiceDescriptor().GetUrl()},
PtransformID: id.to,
Expand Down
25 changes: 20 additions & 5 deletions sdks/go/pkg/beam/core/runtime/graphx/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,18 +467,22 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
m.requirements[URNRequiresStatefulProcessing] = true
stateSpecs := make(map[string]*pipepb.StateSpec)
for _, ps := range edge.Edge.DoFn.PipelineState() {
coderID, err := m.coders.Add(edge.Edge.StateCoders[UserStateCoderId(ps)])
if err != nil {
return handleErr(err)
coderID := ""
c, ok := edge.Edge.StateCoders[UserStateCoderId(ps)]
if ok {
coderID, err = m.coders.Add(c)
if err != nil {
return handleErr(err)
}
}
keyCoderID := ""
if c, ok := edge.Edge.StateCoders[UserStateKeyCoderId(ps)]; ok {
keyCoderID, err = m.coders.Add(c)
if err != nil {
return handleErr(err)
}
} else if ps.StateType() == state.TypeMap {
return nil, errors.Errorf("Map type %v must have a key coder type, none detected", ps)
} else if ps.StateType() == state.TypeMap || ps.StateType() == state.TypeSet {
return nil, errors.Errorf("set or map state type %v must have a key coder type, none detected", ps)
}
switch ps.StateType() {
case state.TypeValue:
Expand Down Expand Up @@ -547,6 +551,17 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
Urn: URNMultiMapUserState,
},
}
case state.TypeSet:
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
Spec: &pipepb.StateSpec_SetSpec{
SetSpec: &pipepb.SetStateSpec{
ElementCoderId: keyCoderID,
},
},
Protocol: &pipepb.FunctionSpec{
Urn: URNMultiMapUserState,
},
}
default:
return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
}
Expand Down
125 changes: 125 additions & 0 deletions sdks/go/pkg/beam/core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ const (
TypeCombining TypeEnum = 2
// TypeMap represents a map state
TypeMap TypeEnum = 3
// TypeSet represents a set state
TypeSet TypeEnum = 4
)

var (
Expand Down Expand Up @@ -517,3 +519,126 @@ func MakeMapState[K comparable, V any](k string) Map[K, V] {
Key: k,
}
}

// Set is used to read and write global pipeline state representing a Set.
// Key represents the key used to lookup this state (not the key of Set entries).
type Set[K comparable] struct {
Key string
}

// Add is used to write a key to this instance of global Set state.
func (s *Set[K]) Add(p Provider, key K) error {
return p.WriteMapState(Transaction{
Key: s.Key,
Type: TransactionTypeSet,
MapKey: key,
Val: true,
})
}

// Keys is used to read the keys of this set state.
// When a value is not found, returns an empty list and false.
func (s *Set[K]) Keys(p Provider) ([]K, bool, error) {
// This replays any writes that have happened to this value since we last read
// For more detail, see "State Transactionality" below for buffered transactions
initialValue, bufferedTransactions, err := p.ReadMapStateKeys(s.Key)
if err != nil {
return []K{}, false, err
}
cur := []K{}
for _, v := range initialValue {
cur = append(cur, v.(K))
}
for _, t := range bufferedTransactions {
switch t.Type {
case TransactionTypeSet:
seen := false
mk := t.MapKey.(K)
for _, k := range cur {
if k == mk {
seen = true
}
}
if !seen {
cur = append(cur, mk)
}
case TransactionTypeClear:
if t.MapKey == nil {
cur = []K{}
} else {
k := t.MapKey.(K)
for idx, v := range cur {
if v == k {
// Remove this key since its been cleared
cur[idx] = cur[len(cur)-1]
cur = cur[:len(cur)-1]
break
}
}
}
}
}
if len(cur) == 0 {
return cur, false, nil
}
return cur, true, nil
}

// Contains is used to determine if a given a key exists in the set.
func (s *Set[K]) Contains(p Provider, key K) (bool, error) {
// This replays any writes that have happened to this value since we last read
// For more detail, see "State Transactionality" below for buffered transactions
cur, bufferedTransactions, err := p.ReadMapStateValue(s.Key, key)
if err != nil {
return false, err
}
for _, t := range bufferedTransactions {
switch t.Type {
case TransactionTypeSet:
if t.MapKey.(K) == key {
cur = t.Val
}
case TransactionTypeClear:
if t.MapKey == nil || t.MapKey.(K) == key {
cur = nil
}
}
}
if cur == nil {
return false, nil
}
return true, nil
}

// StateKey returns the key for this pipeline state entry.
func (s Set[K]) StateKey() string {
if s.Key == "" {
// TODO(#22736) - infer the state from the member variable name during pipeline construction.
panic("Value state exists on struct but has not been initialized with a key.")
}
return s.Key
}

// KeyCoderType returns the type of the value state which should be used for a coder for set keys.
func (s Set[K]) KeyCoderType() reflect.Type {
var k K
return reflect.TypeOf(k)
}

// CoderType returns the type of the coder used for values, in this case nil since there are no values associated with a set.
func (s Set[K]) CoderType() reflect.Type {
// A bool coder is used later, but it does not need to be passed around or visible to users.
return nil
}

// StateType returns the type of the state (in this case always Set).
func (s Set[K]) StateType() TypeEnum {
return TypeSet
}

// MakeSetState is a factory function to create an instance of SetState with the given key.
func MakeSetState[K comparable](k string) Set[K] {
return Set[K]{
Key: k,
}
}
Loading