diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..5acc3a9 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,22 @@ +name: CI + +on: push + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + name: Run tests + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.22" + - name: Run test script + run: | + chmod +x ./scripts/test.sh && ./scripts/test.sh diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..616978e --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env + +.idea \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..06aa68a --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module github.com/nejdetkadir/statemachine + +go 1.22.2 + +require ( + github.com/jedib0t/go-pretty/v6 v6.5.9 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + golang.org/x/sys v0.17.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6d90beb --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jedib0t/go-pretty/v6 v6.5.9 h1:ACteMBRrrmm1gMsXe9PSTOClQ63IXDUt03H5U+UV8OU= +github.com/jedib0t/go-pretty/v6 v6.5.9/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100755 index 0000000..aec7715 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +output=$(go test ./...) +result=$? + +if [ $result -ne 0 ]; then + echo "Some tests failed:" + echo "$output" + exit 1 +else + go test ./... -v + exit 0 +fi diff --git a/statemachine.go b/statemachine.go new file mode 100644 index 0000000..453109a --- /dev/null +++ b/statemachine.go @@ -0,0 +1,178 @@ +package statemachine + +import ( + "errors" + "fmt" + "github.com/jedib0t/go-pretty/v6/table" + "os" + "slices" +) + +type ( + Context struct { + states []string + initialState string + currentState string + events []Event + beforeAll func(event string, from string, to string) + afterAll func(event string, from string, to string) + } + Event struct { + name string + to string + from []string + before func() + after func() + validate func(from string, to string) error + } + StateMachine interface { + CurrentState() string + Fire(event string) error + RegisterEvent(event Event) error + RegisterEvents(events []Event) error + BeforeAll(before func(event string, from string, to string)) + AfterAll(after func(event string, from string, to string)) + RenderGraph() + Context() *Context + SetCurrentState(state string) error + } +) + +func New(states []string, initialState string) (StateMachine, error) { + if slices.Contains(states, initialState) == false { + return nil, errors.New("initial state must be one of the states") + } + + return &Context{ + states: states, + initialState: initialState, + currentState: initialState, + }, nil +} + +func (c *Context) CurrentState() string { + return c.currentState +} + +func (c *Context) Context() *Context { + return c +} + +func (c *Context) RegisterEvent(event Event) error { + if slices.Contains(c.states, event.to) == false { + return errors.New(fmt.Sprintf("to state must be one of: %v", c.states)) + } + + if slices.ContainsFunc(event.from, func(s string) bool { + return slices.Contains(c.states, s) == false + }) { + return errors.New(fmt.Sprintf("from states must be one of: %v", c.states)) + } + + if slices.Contains(event.from, event.to) { + return errors.New("from and to states cannot be the same") + } + + if slices.ContainsFunc(c.events, func(e Event) bool { + return e.name == event.name + }) { + return errors.New("event name must be unique") + } + + c.events = append(c.events, event) + + return nil +} + +func (c *Context) RegisterEvents(events []Event) error { + var err error + + for _, e := range events { + err = c.RegisterEvent(e) + } + + if err != nil { + c.events = []Event{} + + return err + } + + return nil +} + +func (c *Context) BeforeAll(before func(event string, from string, to string)) { + c.beforeAll = before +} + +func (c *Context) AfterAll(after func(event string, from string, to string)) { + c.afterAll = after +} + +func (c *Context) RenderGraph() { + t := table.NewWriter() + t.SetOutputMirror(os.Stdout) + t.AppendHeader(table.Row{"Event", "From", "To"}) + + for _, e := range c.events { + t.AppendRow([]interface{}{e.name, e.from, e.to}) + } + + t.Render() +} + +func (c *Context) Fire(event string) error { + var currentEvent *Event + + for _, e := range c.events { + if e.name == event { + currentEvent = &e + break + } + } + + if currentEvent == nil { + return errors.New(fmt.Sprintf("%s event is not registered", event)) + } + + if slices.Contains(currentEvent.from, c.currentState) == false { + return errors.New(fmt.Sprintf("cannot fire the %s event from the %s state", currentEvent.name, c.currentState)) + } + + if c.beforeAll != nil { + c.beforeAll(currentEvent.name, c.currentState, currentEvent.to) + } + + if currentEvent.before != nil { + currentEvent.before() + } + + if currentEvent.validate != nil { + err := currentEvent.validate(c.currentState, currentEvent.to) + + if err != nil { + return err + } + } + + c.currentState = currentEvent.to + + if currentEvent.after != nil { + currentEvent.after() + } + + if c.afterAll != nil { + c.afterAll(currentEvent.name, c.currentState, currentEvent.to) + } + + return nil +} + +func (c *Context) SetCurrentState(state string) error { + if !slices.Contains(c.states, state) { + return errors.New(fmt.Sprintf("state must be one of: %v", c.states)) + } + + c.currentState = state + + return nil +} diff --git a/statemachine_test.go b/statemachine_test.go new file mode 100644 index 0000000..7541a71 --- /dev/null +++ b/statemachine_test.go @@ -0,0 +1,509 @@ +package statemachine + +import ( + "errors" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNew(t *testing.T) { + t.Run("should return error if initial state is not in states", func(t *testing.T) { + _, err := New([]string{"A", "B"}, "C") + + assert.Error(t, err) + assert.Equal(t, "initial state must be one of the states", err.Error()) + }) + + t.Run("should return a new state machine", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + assert.NotNil(t, sm) + assert.Equal(t, states, sm.Context().states) + assert.Equal(t, initialState, sm.CurrentState()) + }) +} + +func TestStateMachine_CurrentState(t *testing.T) { + t.Run("should return the current state", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + assert.Equal(t, initialState, sm.CurrentState()) + }) + + t.Run("should return the current state after firing an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "B"}) + + assert.NoError(t, err) + + err = sm.Fire("event") + + assert.NoError(t, err) + assert.Equal(t, "B", sm.CurrentState()) + }) +} + +func TestStateMachine_Context(t *testing.T) { + t.Run("should return the context", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + assert.Equal(t, sm.Context(), sm.(*Context)) + assert.Equal(t, states, sm.Context().states) + assert.Equal(t, initialState, sm.Context().currentState) + assert.Empty(t, sm.Context().events) + assert.Nil(t, sm.Context().beforeAll) + assert.Nil(t, sm.Context().afterAll) + }) +} + +func TestStateMachine_RegisterEvent(t *testing.T) { + t.Run("should register an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "B"}) + + assert.NoError(t, err) + assert.Len(t, sm.Context().events, 1) + assert.Equal(t, "event", sm.Context().events[0].name) + assert.Equal(t, []string{"A"}, sm.Context().events[0].from) + assert.Equal(t, "B", sm.Context().events[0].to) + assert.Nil(t, sm.Context().events[0].before) + assert.Nil(t, sm.Context().events[0].after) + }) + + t.Run("should return error if to state is not in states", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "C"}) + + assert.Error(t, err) + assert.Equal(t, "to state must be one of: [A B]", err.Error()) + }) + + t.Run("should return error if from states are not in states", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A", "C"}, to: "B"}) + + assert.Error(t, err) + assert.Equal(t, "from states must be one of: [A B]", err.Error()) + }) + + t.Run("should return error if from and to states are the same", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "A"}) + + assert.Error(t, err) + assert.Equal(t, "from and to states cannot be the same", err.Error()) + }) + + t.Run("should return error if event name is not unique", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "B"}) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "B"}) + + assert.Error(t, err) + assert.Equal(t, "event name must be unique", err.Error()) + }) +} + +func TestStateMachine_RegisterEvents(t *testing.T) { + t.Run("should register multiple events", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvents([]Event{ + { + name: "event1", + from: []string{"A"}, + to: "B", + }, + { + name: "event2", + from: []string{"B"}, + to: "A", + }, + }) + + assert.NoError(t, err) + assert.Len(t, sm.Context().events, 2) + assert.Equal(t, "event1", sm.Context().events[0].name) + assert.Equal(t, "event2", sm.Context().events[1].name) + }) + + t.Run("should return error if one of the events is invalid", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvents([]Event{ + { + name: "event1", + from: []string{"A"}, + to: "B", + }, + { + name: "event2", + from: []string{"C"}, + to: "A", + }, + }) + + assert.Error(t, err) + assert.Equal(t, "from states must be one of: [A B]", err.Error()) + assert.Empty(t, sm.Context().events) + }) + + t.Run("should return error if one of the events is invalid and clear all events", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvents([]Event{ + { + name: "event1", + from: []string{"A"}, + to: "B", + }, + { + name: "event2", + from: []string{"C"}, + to: "A", + }, + }) + + assert.Error(t, err) + assert.Equal(t, "from states must be one of: [A B]", err.Error()) + assert.Empty(t, sm.Context().events) + }) +} + +func TestStateMachine_BeforeAll(t *testing.T) { + t.Run("should run before all function before firing an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + beforeAll := false + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + sm.BeforeAll(func(event string, from string, to string) { + beforeAll = true + }) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "B"}) + + assert.NoError(t, err) + + err = sm.Fire("event") + + assert.NoError(t, err) + assert.True(t, beforeAll) + }) +} + +func TestStateMachine_AfterAll(t *testing.T) { + t.Run("should run after all function after firing an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + afterAll := false + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + sm.AfterAll(func(event string, from string, to string) { + afterAll = true + }) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "B"}) + + assert.NoError(t, err) + + err = sm.Fire("event") + + assert.NoError(t, err) + assert.True(t, afterAll) + }) +} + +func TestStateMachine_RenderGraph(t *testing.T) { + t.Run("should render the graph", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "event", from: []string{"A"}, to: "B"}) + + assert.NoError(t, err) + + sm.RenderGraph() + }) + + t.Run("should render the graph with multiple events", func(t *testing.T) { + states := []string{"A", "B", "C"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvents([]Event{ + { + name: "event1", + from: []string{"A"}, + to: "B", + }, + { + name: "event2", + from: []string{"B"}, + to: "C", + }, + }) + + assert.NoError(t, err) + + sm.RenderGraph() + }) +} + +func TestStateMachine_Fire(t *testing.T) { + t.Run("should return error if event name is not registered", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.Fire("test1") + + assert.Error(t, err) + assert.Equal(t, "test1 event is not registered", err.Error()) + }) + + t.Run("should return error if event is not allowed in the current state", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "test1", from: []string{"B"}, to: "A"}) + + assert.NoError(t, err) + + err = sm.Fire("test1") + + assert.Error(t, err) + assert.Equal(t, "cannot fire the test1 event from the A state", err.Error()) + }) + + t.Run("should return error if validate function returns an error", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{ + name: "test1", + from: []string{"A"}, + to: "B", + validate: func(from string, to string) error { + return errors.New("test1 event is not allowed") + }, + }) + + assert.NoError(t, err) + + err = sm.Fire("test1") + + assert.Error(t, err) + assert.Equal(t, "test1 event is not allowed", err.Error()) + }) + + t.Run("should run before function before firing an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + before := false + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{ + name: "test1", + from: []string{"A"}, + to: "B", + before: func() { + before = true + }, + }) + + assert.NoError(t, err) + + err = sm.Fire("test1") + + assert.NoError(t, err) + assert.True(t, before) + }) + + t.Run("should run after function after firing an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + after := false + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{ + name: "test1", + from: []string{"A"}, + to: "B", + after: func() { + after = true + }, + }) + + assert.NoError(t, err) + + err = sm.Fire("test1") + + assert.NoError(t, err) + assert.True(t, after) + }) + + t.Run("should run before and after functions before and after firing an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + before := false + after := false + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{ + name: "test1", + from: []string{"A"}, + to: "B", + before: func() { + before = true + }, + after: func() { + after = true + }, + }) + + assert.NoError(t, err) + + err = sm.Fire("test1") + + assert.NoError(t, err) + assert.True(t, before) + assert.True(t, after) + }) + + t.Run("should change the current state after firing an event", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.RegisterEvent(Event{name: "test1", from: []string{"A"}, to: "B"}) + + assert.NoError(t, err) + + err = sm.Fire("test1") + + assert.NoError(t, err) + assert.Equal(t, "B", sm.CurrentState()) + }) +} + +func TestStateMachine_SetCurrentState(t *testing.T) { + t.Run("should return error if state is not in states", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.SetCurrentState("C") + + assert.Error(t, err) + assert.Equal(t, "state must be one of: [A B]", err.Error()) + }) + + t.Run("should set the current state", func(t *testing.T) { + states := []string{"A", "B"} + initialState := "A" + + sm, err := New(states, initialState) + + assert.NoError(t, err) + + err = sm.SetCurrentState("B") + + assert.NoError(t, err) + assert.Equal(t, "B", sm.CurrentState()) + }) +}