From 104880c26b44648bea6e5670996e7fcdeb9f3be1 Mon Sep 17 00:00:00 2001 From: Ivan Pushkin Date: Sun, 20 Jun 2021 06:39:43 +0300 Subject: [PATCH] Add MVP --- cmd/snake-bot/main.go | 70 +++ embed.go | 6 + go.mod | 18 + go.sum | 167 +++++++ internal/bot/bot.go | 150 +++++++ internal/bot/bot_score.go | 67 +++ internal/bot/engine/area.go | 142 ++++++ internal/bot/engine/area_test.go | 229 ++++++++++ internal/bot/engine/backtrack.go | 15 + internal/bot/engine/backtrack_test.go | 8 + internal/bot/engine/cacher_discoverer.go | 82 ++++ internal/bot/engine/cacher_discoverer_test.go | 85 ++++ internal/bot/engine/dijkstras_discoverer.go | 86 ++++ internal/bot/engine/discoverer.go | 8 + internal/bot/engine/hashmap_sight.go | 94 ++++ .../bot/engine/hashmap_sight_bench_test.go | 151 +++++++ internal/bot/engine/hashmap_sight_test.go | 192 ++++++++ internal/bot/engine/map.go | 83 ++++ internal/bot/engine/queue.go | 63 +++ internal/bot/engine/sight.go | 135 ++++++ internal/bot/engine/sight_test.go | 409 ++++++++++++++++++ internal/bot/game.go | 99 +++++ internal/config/config.go | 178 ++++++++ internal/config/config_test.go | 211 +++++++++ internal/config/std_config.go | 12 + internal/connect/connection.go | 80 ++++ internal/connect/connector.go | 84 ++++ internal/core/bot_operator.go | 129 ++++++ internal/core/core.go | 114 +++++ internal/core/diff.go | 27 ++ internal/core/diff_test.go | 85 ++++ internal/parser/parser.go | 168 +++++++ internal/parser/parser_mocks_test.go | 55 +++ internal/parser/parser_test.go | 39 ++ internal/secure/secure.go | 74 ++++ .../server/handlers/apply_state_handler.go | 131 ++++++ internal/server/handlers/openapi_handler.go | 17 + .../server/handlers/read_state_handler.go | 45 ++ internal/server/handlers/welcome_handler.go | 18 + internal/server/middlewares/logger.go | 68 +++ internal/server/middlewares/token_auth.go | 22 + internal/server/models/error.go | 6 + internal/server/models/game.go | 6 + internal/server/models/games.go | 27 ++ internal/server/server.go | 134 ++++++ internal/types/types.go | 243 +++++++++++ internal/types/types_test.go | 165 +++++++ internal/utils/client_name.go | 7 + internal/utils/context.go | 24 + internal/utils/logger.go | 34 ++ internal/utils/printer.go | 28 ++ 51 files changed, 4590 insertions(+) create mode 100644 cmd/snake-bot/main.go create mode 100644 embed.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/bot/bot.go create mode 100644 internal/bot/bot_score.go create mode 100644 internal/bot/engine/area.go create mode 100644 internal/bot/engine/area_test.go create mode 100644 internal/bot/engine/backtrack.go create mode 100644 internal/bot/engine/backtrack_test.go create mode 100644 internal/bot/engine/cacher_discoverer.go create mode 100644 internal/bot/engine/cacher_discoverer_test.go create mode 100644 internal/bot/engine/dijkstras_discoverer.go create mode 100644 internal/bot/engine/discoverer.go create mode 100644 internal/bot/engine/hashmap_sight.go create mode 100644 internal/bot/engine/hashmap_sight_bench_test.go create mode 100644 internal/bot/engine/hashmap_sight_test.go create mode 100644 internal/bot/engine/map.go create mode 100644 internal/bot/engine/queue.go create mode 100644 internal/bot/engine/sight.go create mode 100644 internal/bot/engine/sight_test.go create mode 100644 internal/bot/game.go create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go create mode 100644 internal/config/std_config.go create mode 100644 internal/connect/connection.go create mode 100644 internal/connect/connector.go create mode 100644 internal/core/bot_operator.go create mode 100644 internal/core/core.go create mode 100644 internal/core/diff.go create mode 100644 internal/core/diff_test.go create mode 100644 internal/parser/parser.go create mode 100644 internal/parser/parser_mocks_test.go create mode 100644 internal/parser/parser_test.go create mode 100644 internal/secure/secure.go create mode 100644 internal/server/handlers/apply_state_handler.go create mode 100644 internal/server/handlers/openapi_handler.go create mode 100644 internal/server/handlers/read_state_handler.go create mode 100644 internal/server/handlers/welcome_handler.go create mode 100644 internal/server/middlewares/logger.go create mode 100644 internal/server/middlewares/token_auth.go create mode 100644 internal/server/models/error.go create mode 100644 internal/server/models/game.go create mode 100644 internal/server/models/games.go create mode 100644 internal/server/server.go create mode 100644 internal/types/types.go create mode 100644 internal/types/types_test.go create mode 100644 internal/utils/client_name.go create mode 100644 internal/utils/context.go create mode 100644 internal/utils/logger.go create mode 100644 internal/utils/printer.go diff --git a/cmd/snake-bot/main.go b/cmd/snake-bot/main.go new file mode 100644 index 0000000..731dc36 --- /dev/null +++ b/cmd/snake-bot/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "context" + "math/rand" + "os" + "os/signal" + "time" + + "github.com/sirupsen/logrus" + + "github.com/ivan1993spb/snake-bot/internal/config" + "github.com/ivan1993spb/snake-bot/internal/connect" + "github.com/ivan1993spb/snake-bot/internal/core" + "github.com/ivan1993spb/snake-bot/internal/secure" + "github.com/ivan1993spb/snake-bot/internal/server" + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +const ApplicationName = "Snake-Bot" + +var ( + Version = "dev" + Build = "dev" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + cfg, err := config.StdConfig() + + { + logger := utils.NewLogger(cfg.Log) + ctx = utils.LogContext(ctx, logger) + } + + if err != nil { + utils.Log(ctx).WithError(err).Fatal("config fail") + } + + utils.Log(ctx).WithFields(logrus.Fields{ + "version": Version, + "build": Build, + }).Info("Welcome to Snake-Bot!") + + sec := secure.NewSecure() + if err := sec.GenerateToken(os.Stdout); err != nil { + utils.Log(ctx).WithError(err).Fatal("security fail") + } + utils.Log(ctx).Warn("auth token successfully generated") + + headerAppInfo := utils.FormatAppInfoHeader(ApplicationName, Version, Build) + + connector := connect.NewConnector(cfg.Target, headerAppInfo) + + c := core.NewCore(ctx, connector, cfg.Bots.Limit) + + serv := server.NewServer(ctx, cfg.Server, headerAppInfo, c, sec) + + if err := serv.ListenAndServe(ctx); err != nil { + utils.Log(ctx).WithError(err).Fatal("server error") + } + + utils.Log(ctx).Info("buh bye!") +} diff --git a/embed.go b/embed.go new file mode 100644 index 0000000..0aaf407 --- /dev/null +++ b/embed.go @@ -0,0 +1,6 @@ +package snakebot + +import _ "embed" + +//go:embed api/openapi.yaml +var OpenAPISpec string diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0aa2646 --- /dev/null +++ b/go.mod @@ -0,0 +1,18 @@ +module github.com/ivan1993spb/snake-bot + +go 1.16 + +require ( + github.com/go-chi/chi/v5 v5.0.5 + github.com/go-chi/cors v1.2.0 + github.com/gorilla/websocket v1.4.2 + github.com/kr/text v0.2.0 // indirect + github.com/pkg/errors v0.9.1 + github.com/prometheus/client_golang v1.11.0 + github.com/sirupsen/logrus v1.8.1 + github.com/stretchr/testify v1.7.0 + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + lukechampine.com/noescape v0.0.0-20191006153127-214c369a3d1b +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ca64032 --- /dev/null +++ b/go.sum @@ -0,0 +1,167 @@ +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi/v5 v5.0.5 h1:l3RJ8T8TAqLsXFfah+RA6N4pydMbPwSdvNM+AFWvLUM= +github.com/go-chi/chi/v5 v5.0.5/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/cors v1.2.0 h1:tV1g1XENQ8ku4Bq3K9ub2AtgG+p16SmzeMSGTwrOKdE= +github.com/go-chi/cors v1.2.0/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ= +github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/common v0.26.0 h1:iMAkS2TDoNWnKM+Kopnx/8tnEStIfpYA0ur0xQzzhMQ= +github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= +github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40 h1:JWgyZ1qgdTaF3N3oxC+MdTV7qvEEgHo3otj+HB5CM7Q= +golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1 h1:7QnIQpGRHE5RnLKnESfDoxm2dTapTZua5a0kS0A+VXQ= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/noescape v0.0.0-20191006153127-214c369a3d1b h1:P9+2fy9NHwuZ/vkYOU47qdN+QwN5HHlQSBH5XLakJNE= +lukechampine.com/noescape v0.0.0-20191006153127-214c369a3d1b/go.mod h1:t3zg9hUEgo5ticiCFz1Q9KuGGhJgIJ1iPJzWukPtEh8= diff --git a/internal/bot/bot.go b/internal/bot/bot.go new file mode 100644 index 0000000..fcf97b1 --- /dev/null +++ b/internal/bot/bot.go @@ -0,0 +1,150 @@ +package bot + +import ( + "context" + "math/rand" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + + "github.com/ivan1993spb/snake-bot/internal/bot/engine" + "github.com/ivan1993spb/snake-bot/internal/types" + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +const botTickTime = time.Millisecond * 200 + +const lookupDistance uint8 = 50 + +const directionExpireTime = time.Millisecond * 10 + +const ( + stateWait = iota + stateExplore +) + +type World interface { + LookAround(sight engine.Sight) *engine.HashmapSight + GetArea() engine.Area + GetObject(id uint32) (*types.Object, bool) +} + +type Bot struct { + state uint32 + + myId uint32 + world World + + discoverer engine.Discoverer + + lastPosition types.Dot + lastDirection types.Direction +} + +func NewBot(world World, discoverer engine.Discoverer) *Bot { + return &Bot{ + world: world, + discoverer: discoverer, + } +} + +func NewDijkstrasBot(world World) *Bot { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + discoverer := engine.NewDijkstrasDiscoverer(r) + cacher := engine.NewCacherDiscoverer(discoverer) + return NewBot(world, cacher) +} + +func (b *Bot) Run(ctx context.Context, stop <-chan struct{}) <-chan types.Direction { + chout := make(chan types.Direction) + + go func() { + defer close(chout) + ticker := time.NewTicker(botTickTime) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + direction, ok := b.operate(ctx) + if ok { + ctx, cancel := context.WithTimeout(ctx, directionExpireTime) + select { + case chout <- direction: + case <-ctx.Done(): + } + cancel() + } + case <-stop: + return + case <-ctx.Done(): + return + } + } + }() + + return chout +} + +func (b *Bot) getState() uint32 { + return atomic.LoadUint32(&b.state) +} + +func (b *Bot) getMe() uint32 { + return atomic.LoadUint32(&b.myId) +} + +func (b *Bot) operate(ctx context.Context) (types.Direction, bool) { + if b.getState() != stateExplore { + return types.DirectionZero, false + } + + me, ok := b.world.GetObject(b.getMe()) + if !ok { + return types.DirectionZero, false + } + + area := b.world.GetArea() + objectDots := me.GetDots() + if len(objectDots) == 0 { + return types.DirectionZero, false + } + head := objectDots[0] + sight := engine.NewSight(area, head, lookupDistance) + + objects := b.world.LookAround(sight) + scores := b.score(objects) + + path := b.discoverer.Discover(head, area, sight, scores) + if len(path) == 0 { + return types.DirectionZero, false + } + direction := area.FindDirection(head, path[0]) + + if area.FindDirection(objectDots[1], objectDots[0]) == direction { + // the same direction + return types.DirectionZero, false + } + if b.lastDirection != direction || b.lastPosition != head { + utils.Log(ctx).WithFields(logrus.Fields{ + "head": head, + "direction": direction, + }).Debugln("change direction") + b.lastDirection = direction + b.lastPosition = head + + return direction, true + } + return types.DirectionZero, false +} + +func (b *Bot) Countdown(sec int) { + // TODO: Shut down the snake if the stateWait + atomic.StoreUint32(&b.state, stateWait) +} + +func (b *Bot) Me(id uint32) { + atomic.StoreUint32(&b.myId, id) + atomic.StoreUint32(&b.state, stateExplore) +} diff --git a/internal/bot/bot_score.go b/internal/bot/bot_score.go new file mode 100644 index 0000000..d27ef36 --- /dev/null +++ b/internal/bot/bot_score.go @@ -0,0 +1,67 @@ +package bot + +import ( + "github.com/ivan1993spb/snake-bot/internal/bot/engine" + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type scoreType uint8 + +const ( + scoreTypeCollapse scoreType = iota + scoreTypeFoodApple + scoreTypeFoodCorpse + scoreTypeFoodWatermelon + scoreTypeFoodMouse + scoreTypeHunt +) + +var defaultBehavior = map[scoreType]int{ + scoreTypeCollapse: -1000, + scoreTypeFoodApple: 1, + scoreTypeFoodCorpse: 2, + scoreTypeFoodWatermelon: 5, + scoreTypeFoodMouse: 15, + scoreTypeHunt: 30, +} + +func (b *Bot) score(objects *engine.HashmapSight) *engine.HashmapSight { + // TODO: Estimate the real prey length. + const preyLen = 3 + + scores := objects.Reflect() + + objects.ForEach(func(dot types.Dot, v interface{}) { + object, ok := v.(*types.Object) + if !ok { + return + } + if object.Id == b.myId { + scores.Assign(dot, defaultBehavior[scoreTypeCollapse]) + } else if object.Type == types.ObjectTypeSnake { + if len(object.Dots) == preyLen { + scores.Assign(dot, defaultBehavior[scoreTypeHunt]) + } else { + scores.Assign(dot, defaultBehavior[scoreTypeCollapse]) + } + } else if object.Type == types.ObjectTypeWall { + scores.Assign(dot, defaultBehavior[scoreTypeCollapse]) + } else { + switch object.Type { + case types.ObjectTypeApple: + scores.Assign(dot, defaultBehavior[scoreTypeFoodApple]) + case types.ObjectTypeCorpse: + scores.Assign(dot, defaultBehavior[scoreTypeFoodCorpse]) + case types.ObjectTypeWatermelon: + scores.Assign(dot, defaultBehavior[scoreTypeFoodWatermelon]) + case types.ObjectTypeMouse: + scores.Assign(dot, defaultBehavior[scoreTypeFoodMouse]) + default: + // Avoid unknown objects. + scores.Assign(dot, defaultBehavior[scoreTypeCollapse]) + } + } + }) + + return scores +} diff --git a/internal/bot/engine/area.go b/internal/bot/engine/area.go new file mode 100644 index 0000000..1426b7f --- /dev/null +++ b/internal/bot/engine/area.go @@ -0,0 +1,142 @@ +package engine + +import ( + "fmt" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type Area struct { + Width uint8 + Height uint8 +} + +func NewArea(width, height uint8) Area { + return Area{ + Width: width, + Height: height, + } +} + +func (a Area) Navigate(dot types.Dot) []types.Dot { + dots := make([]types.Dot, 0, 4) + + // North + if dot.Y > 0 { + dots = append(dots, types.Dot{ + X: dot.X, + Y: dot.Y - 1, + }) + } else { + dots = append(dots, types.Dot{ + X: dot.X, + Y: a.Height - 1, + }) + } + + // South + dots = append(dots, types.Dot{ + X: dot.X, + Y: (dot.Y + 1) % a.Height, + }) + + // East + dots = append(dots, types.Dot{ + X: (dot.X + 1) % a.Width, + Y: dot.Y, + }) + + // West + if dot.X > 0 { + dots = append(dots, types.Dot{ + X: dot.X - 1, + Y: dot.Y, + }) + } else { + dots = append(dots, types.Dot{ + X: a.Width - 1, + Y: dot.Y, + }) + } + + return dots +} + +func (a Area) FindDirection(from, to types.Dot) types.Direction { + if from == to || !a.Fits(from) || !a.Fits(to) { + return types.DirectionNorth + } + + var ( + minDiffX, minDiffY uint8 + dirX, dirY types.Direction + ) + + if from.X > to.X { + minDiffX = from.X - to.X + dirX = types.DirectionWest + + if diff := a.Width - from.X + to.X; diff < minDiffX { + minDiffX = diff + dirX = types.DirectionEast + } + } else { + minDiffX = to.X - from.X + dirX = types.DirectionEast + + if diff := a.Width - to.X + from.X; diff < minDiffX { + minDiffX = diff + dirX = types.DirectionWest + } + } + + if from.Y > to.Y { + minDiffY = from.Y - to.Y + dirY = types.DirectionNorth + + if diff := a.Height - from.Y + to.Y; diff < minDiffY { + minDiffY = diff + dirY = types.DirectionSouth + } + } else { + minDiffY = to.Y - from.Y + dirY = types.DirectionSouth + + if diff := a.Height - to.Y + from.Y; diff < minDiffY { + minDiffY = diff + dirY = types.DirectionNorth + } + } + + if minDiffX > minDiffY { + return dirX + } + + return dirY +} + +func (a Area) FitDistance(distance, divisor, gap uint8) uint8 { + if distance >= a.Width/divisor { + distance = a.Width / divisor + if gap > 0 && a.Width%divisor == 0 { + distance -= gap + } + } + + if distance >= a.Height/divisor { + distance = a.Height / divisor + if gap > 0 && a.Height%divisor == 0 { + distance -= gap + } + } + + return distance +} + +func (a Area) Fits(dot types.Dot) bool { + return a.Width > dot.X && a.Height > dot.Y +} + +func (a Area) String() string { + return fmt.Sprintf("Area[%d, %d]", a.Width, a.Height) +} diff --git a/internal/bot/engine/area_test.go b/internal/bot/engine/area_test.go new file mode 100644 index 0000000..c189cb7 --- /dev/null +++ b/internal/bot/engine/area_test.go @@ -0,0 +1,229 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +func Test_NewArea(t *testing.T) { + const ( + w = 132 + h = 43 + ) + a := NewArea(w, h) + assert.Equal(t, uint8(w), a.Width) + assert.Equal(t, uint8(h), a.Height) +} + +func Test_Area_Navigate(t *testing.T) { + const msgFormat = "%d: %s.Navigate(%s) = %s" + + tests := []struct { + area Area + input types.Dot + output []types.Dot + }{ + { + area: Area{ + Width: 100, + Height: 101, + }, + input: types.Dot{ + X: 0, + Y: 0, + }, + output: []types.Dot{ + { + X: 0, + Y: 100, + }, + { + X: 0, + Y: 1, + }, + { + X: 1, + Y: 0, + }, + { + X: 99, + Y: 0, + }, + }, + }, + { + area: Area{ + Width: 100, + Height: 101, + }, + input: types.Dot{ + X: 99, + Y: 0, + }, + output: []types.Dot{ + { + X: 99, + Y: 100, + }, + { + X: 99, + Y: 1, + }, + { + X: 0, + Y: 0, + }, + { + X: 98, + Y: 0, + }, + }, + }, + { + area: Area{ + Width: 100, + Height: 101, + }, + input: types.Dot{ + X: 99, + Y: 100, + }, + output: []types.Dot{ + { + X: 99, + Y: 99, + }, + { + X: 99, + Y: 0, + }, + { + X: 0, + Y: 100, + }, + { + X: 98, + Y: 100, + }, + }, + }, + { + area: Area{ + Width: 55, + Height: 24, + }, + input: types.Dot{ + X: 0, + Y: 23, + }, + output: []types.Dot{ + { + X: 0, + Y: 22, + }, + { + X: 0, + Y: 0, + }, + { + X: 1, + Y: 23, + }, + { + X: 54, + Y: 23, + }, + }, + }, + { + area: Area{ + Width: 55, + Height: 24, + }, + input: types.Dot{ + X: 12, + Y: 4, + }, + output: []types.Dot{ + { + X: 12, + Y: 3, + }, + { + X: 12, + Y: 5, + }, + { + X: 13, + Y: 4, + }, + { + X: 11, + Y: 4, + }, + }, + }, + } + + for i, test := range tests { + dots := test.area.Navigate(test.input) + assert.Equalf(t, test.output, dots, msgFormat, + i+1, test.area, test.input, dots) + } +} + +func Test_Area_FindDirection(t *testing.T) { + const msgFormat = "%d: %s.FindDirection(%s, %s) = %s" + + tests := []struct { + area Area + from types.Dot + to types.Dot + expect types.Direction + }{ + { + area: Area{ + Width: 10, + Height: 11, + }, + from: types.Dot{}, + to: types.Dot{}, + expect: types.DirectionNorth, + }, + { + area: Area{ + Width: 10, + Height: 11, + }, + from: types.Dot{}, + to: types.Dot{ + X: 12, + Y: 23, + }, + expect: types.DirectionNorth, + }, + { + area: Area{ + Width: 10, + Height: 11, + }, + from: types.Dot{ + X: 12, + Y: 23, + }, + to: types.Dot{}, + expect: types.DirectionNorth, + }, + + // TODO: Add cases. + } + + for i, test := range tests { + dir := test.area.FindDirection(test.from, test.to) + assert.Equalf(t, test.expect, dir, msgFormat, + i+1, test.area, test.from, test.to, dir) + } +} diff --git a/internal/bot/engine/backtrack.go b/internal/bot/engine/backtrack.go new file mode 100644 index 0000000..9b923e9 --- /dev/null +++ b/internal/bot/engine/backtrack.go @@ -0,0 +1,15 @@ +package engine + +import "github.com/ivan1993spb/snake-bot/internal/types" + +func backtrack(from, to types.Dot, distance int, + prev map[types.Dot]types.Dot) []types.Dot { + path := make([]types.Dot, distance) + current := to + for from != current { + distance-- + path[distance] = current + current = prev[current] + } + return path +} diff --git a/internal/bot/engine/backtrack_test.go b/internal/bot/engine/backtrack_test.go new file mode 100644 index 0000000..bacaff6 --- /dev/null +++ b/internal/bot/engine/backtrack_test.go @@ -0,0 +1,8 @@ +package engine + +import "testing" + +func Test_backtrack(t *testing.T) { + t.Skip() + // TODO: Implement. +} diff --git a/internal/bot/engine/cacher_discoverer.go b/internal/bot/engine/cacher_discoverer.go new file mode 100644 index 0000000..ecbe5cf --- /dev/null +++ b/internal/bot/engine/cacher_discoverer.go @@ -0,0 +1,82 @@ +package engine + +import ( + "github.com/ivan1993spb/snake-bot/internal/types" +) + +// TODO: Consider using cache size. + +var _ Discoverer = (*CacherDiscoverer)(nil) + +type CacherDiscoverer struct { + discoverer Discoverer + + path []types.Dot + scores []int + + expected int +} + +func NewCacherDiscoverer(discoverer Discoverer) Discoverer { + return &CacherDiscoverer{ + discoverer: discoverer, + } +} + +func (d *CacherDiscoverer) Discover(head types.Dot, area Area, + sight Sight, scores *HashmapSight) []types.Dot { + + d.update(head) + + if d.expired(scores) { + d.path = d.discoverer.Discover(head, area, sight, scores) + d.score(scores) + } + + return d.path +} + +func (d *CacherDiscoverer) update(head types.Dot) { + if len(d.path) == 0 { + return + } + if i := d.index(head); i > -1 { + d.path = d.path[i+1:] + d.scores = d.scores[i+1:] + } +} + +func (d *CacherDiscoverer) index(head types.Dot) int { + for i, dot := range d.path { + if dot == head { + return i + } + } + return -1 +} + +const visibilityThreshold = 10 + +func (d *CacherDiscoverer) expired(scores *HashmapSight) bool { + if len(d.path) < visibilityThreshold || d.expected <= 0 { + return true + } + for i, dot := range d.path { + score, _ := scores.AccessDefault(dot, 0).(int) + if d.scores[i] != score { + return true + } + } + return false +} + +func (d *CacherDiscoverer) score(scores *HashmapSight) { + d.expected = 0 + d.scores = make([]int, len(d.path)) + + for i, dot := range d.path { + score, _ := scores.AccessDefault(dot, 0).(int) + d.scores[i] = score + d.expected += d.scores[i] + } +} diff --git a/internal/bot/engine/cacher_discoverer_test.go b/internal/bot/engine/cacher_discoverer_test.go new file mode 100644 index 0000000..f3dc94a --- /dev/null +++ b/internal/bot/engine/cacher_discoverer_test.go @@ -0,0 +1,85 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +func TestCacherDiscoverer_update_ReturnsIfPathEmpty(t *testing.T) { + d := &CacherDiscoverer{ + path: []types.Dot{}, + scores: []int{}, + } + + head := types.Dot{1, 3} + + d.update(head) +} + +func TestCacherDiscoverer_update_CutsPath(t *testing.T) { + d := &CacherDiscoverer{ + path: []types.Dot{ + {1, 0}, + {1, 1}, + {1, 2}, + {1, 3}, + {1, 4}, + {1, 5}, + }, + scores: []int{ + 1, + 0, + 3, + 1, + 0, + 0, + }, + } + + head := types.Dot{1, 3} + + d.update(head) + + expectPath := []types.Dot{ + {1, 4}, + {1, 5}, + } + expectScores := []int{ + 0, + 0, + } + + assert.Equal(t, expectPath, d.path) + assert.Equal(t, expectScores, d.scores) +} + +func TestCacherDiscoverer_update_RemovesCache(t *testing.T) { + d := &CacherDiscoverer{ + path: []types.Dot{ + {1, 0}, + {1, 1}, + {1, 2}, + {1, 3}, + {1, 4}, + {1, 5}, + }, + scores: []int{ + 1, + 0, + 3, + 1, + 0, + 0, + }, + } + + head := types.Dot{1, 5} + + d.update(head) + + assert.Empty(t, d.path) + assert.Empty(t, d.scores) +} diff --git a/internal/bot/engine/dijkstras_discoverer.go b/internal/bot/engine/dijkstras_discoverer.go new file mode 100644 index 0000000..7124590 --- /dev/null +++ b/internal/bot/engine/dijkstras_discoverer.go @@ -0,0 +1,86 @@ +package engine + +import ( + "container/heap" + "math/rand" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +const maxPathLength = 50 + +var _ Discoverer = (*DijkstrasDiscoverer)(nil) + +// DijkstrasDiscoverer discoves paths. It is based on +// Dijkstra's algorithm. +// +// Link: https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm +type DijkstrasDiscoverer struct { + random *rand.Rand +} + +func NewDijkstrasDiscoverer(r *rand.Rand) Discoverer { + return &DijkstrasDiscoverer{ + random: r, + } +} + +func (d *DijkstrasDiscoverer) Discover(head types.Dot, area Area, + sight Sight, scores *HashmapSight) []types.Dot { + visited := make(map[types.Dot]struct{}) + prev := make(map[types.Dot]types.Dot) + + // Start with a negative score! In this case the algorithm + // won't conclude that the current position is the best + // even if all the observed dots are empty. + best := &Position{ + dot: head, + score: -1, + distance: 0, + } + queue := NewQueue() + queue.Push(best) + heap.Init(queue) + + for queue.Len() > 0 { + position := heap.Pop(queue).(*Position) + dots := area.Navigate(position.dot) + d.shuffle(dots) + + for _, dot := range dots { + if !sight.Seen(dot) || queue.Enqueued(dot) { + continue + } + if _, ok := visited[dot]; ok { + continue + } + score, _ := scores.AccessDefault(dot, 0).(int) + if score >= 0 { + prev[dot] = position.dot + distance := position.distance + 1 + + if distance < maxPathLength { + heap.Push(queue, &Position{ + dot: dot, + score: position.score + score, + distance: distance, + }) + } + } + } + + visited[position.dot] = struct{}{} + if best.score < position.score { + best = position + } + } + + path := backtrack(head, best.dot, best.distance, prev) + return path +} + +func (d *DijkstrasDiscoverer) shuffle(dots []types.Dot) { + d.random.Shuffle(len(dots), func(i, j int) { + dots[i], dots[j] = dots[j], dots[i] + }) +} diff --git a/internal/bot/engine/discoverer.go b/internal/bot/engine/discoverer.go new file mode 100644 index 0000000..3dd96b0 --- /dev/null +++ b/internal/bot/engine/discoverer.go @@ -0,0 +1,8 @@ +package engine + +import "github.com/ivan1993spb/snake-bot/internal/types" + +type Discoverer interface { + Discover(head types.Dot, area Area, sight Sight, + scores *HashmapSight) []types.Dot +} diff --git a/internal/bot/engine/hashmap_sight.go b/internal/bot/engine/hashmap_sight.go new file mode 100644 index 0000000..59f1b5f --- /dev/null +++ b/internal/bot/engine/hashmap_sight.go @@ -0,0 +1,94 @@ +package engine + +import ( + "sync/atomic" + "unsafe" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type HashmapSight struct { + sight Sight + values []unsafe.Pointer +} + +func NewHashmapSight(s Sight) *HashmapSight { + w := int(s.zeroedBottomRight.X) + 1 + h := int(s.zeroedBottomRight.Y) + 1 + values := make([]unsafe.Pointer, w*h) + + return &HashmapSight{ + sight: s, + values: values, + } +} + +func (h *HashmapSight) hash(dot types.Dot) int { + if !h.sight.Seen(dot) { + return -1 + } + + x, y := h.sight.Relative(dot) + hash := int(y)*(int(h.sight.zeroedBottomRight.X)+1) + int(x) + + return hash +} + +func (h *HashmapSight) unhash(i int) types.Dot { + if i < 0 { + return types.Dot{} + } + + x := uint8(i % (int(h.sight.zeroedBottomRight.X) + 1)) + y := uint8(i / (int(h.sight.zeroedBottomRight.X) + 1)) + + return h.sight.Absolute(x, y) +} + +func (h *HashmapSight) Access(dot types.Dot) (interface{}, bool) { + if hash := h.hash(dot); hash > -1 { + p := atomic.LoadPointer(&h.values[hash]) + if uintptr(p) == uintptr(0) { + return nil, false + } + return *(*interface{})(p), true + } + return nil, false +} + +func (h *HashmapSight) AccessDefault(dot types.Dot, + defaultValue interface{}) interface{} { + if v, ok := h.Access(dot); ok { + return v + } + return defaultValue + +} + +func (h *HashmapSight) Assign(dot types.Dot, v interface{}) { + if hash := h.hash(dot); hash > -1 { + atomic.SwapPointer(&h.values[hash], unsafe.Pointer(&v)) + } +} + +func (h *HashmapSight) Flush() { + for i := 0; i < len(h.values); i++ { + atomic.SwapPointer(&h.values[i], unsafe.Pointer(uintptr(0))) + } +} + +func (h *HashmapSight) ForEach(f func(dot types.Dot, v interface{})) { + for i := 0; i < len(h.values); i++ { + p := atomic.LoadPointer(&h.values[i]) + if uintptr(p) != uintptr(0) { + f(h.unhash(i), *(*interface{})(p)) + } + } +} + +func (h *HashmapSight) Reflect() *HashmapSight { + return &HashmapSight{ + sight: h.sight, + values: make([]unsafe.Pointer, len(h.values)), + } +} diff --git a/internal/bot/engine/hashmap_sight_bench_test.go b/internal/bot/engine/hashmap_sight_bench_test.go new file mode 100644 index 0000000..b032c6d --- /dev/null +++ b/internal/bot/engine/hashmap_sight_bench_test.go @@ -0,0 +1,151 @@ +package engine + +import ( + "testing" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +func Benchmark_HashmapSight_Assign(b *testing.B) { + const ( + width = 255 + height = 255 + + distance = 100 + ) + + a := NewArea(width, height) + pos := types.Dot{ + X: 0, + Y: 0, + } + s := NewSight(a, pos, distance) + dots := s.Dots() + l := len(dots) + + h := NewHashmapSight(s) + + val := struct { + a, b, c int + d, e, f interface{} + g, h, i uint8 + }{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dot := dots[i%l] + h.Assign(dot, val) + } +} + +func Benchmark_Map_Assign(b *testing.B) { + const ( + width = 255 + height = 255 + + distance = 100 + ) + + a := NewArea(width, height) + pos := types.Dot{ + X: 0, + Y: 0, + } + s := NewSight(a, pos, distance) + dots := s.Dots() + l := len(dots) + + h := make(map[types.Dot]interface{}, l) + + val := struct { + a, b, c int + d, e, f interface{} + g, h, i uint8 + }{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dot := dots[i%l] + h[dot] = val + } +} + +func Benchmark_HashmapSight_Access(b *testing.B) { + const ( + width = 255 + height = 255 + + distance = 100 + ) + + a := NewArea(width, height) + pos := types.Dot{ + X: 0, + Y: 0, + } + s := NewSight(a, pos, distance) + dots := s.Dots() + l := len(dots) + + h := NewHashmapSight(s) + + val := struct { + a, b, c int + d, e, f interface{} + g, h, i uint8 + }{} + + for _, dot := range dots { + h.Assign(dot, val) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dot := dots[i%l] + _, _ = h.Access(dot) + } +} + +func Benchmark_Map_Access(b *testing.B) { + const ( + width = 255 + height = 255 + + distance = 100 + ) + + a := NewArea(width, height) + pos := types.Dot{ + X: 0, + Y: 0, + } + s := NewSight(a, pos, distance) + dots := s.Dots() + l := len(dots) + + h := make(map[types.Dot]interface{}, l) + + val := struct { + a, b, c int + d, e, f interface{} + g, h, i uint8 + }{} + + for _, dot := range dots { + h[dot] = val + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dot := dots[i%l] + _, _ = h[dot] + } +} diff --git a/internal/bot/engine/hashmap_sight_test.go b/internal/bot/engine/hashmap_sight_test.go new file mode 100644 index 0000000..fa274ff --- /dev/null +++ b/internal/bot/engine/hashmap_sight_test.go @@ -0,0 +1,192 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +func TestHashmapSight_hash(t *testing.T) { + const msgFormat = "hash(%s) = %d" + const ( + width = 30 + height = 22 + ) + + a := NewArea(width, height) + d := uint8(4) + for posX := uint8(0); posX < a.Width; posX++ { + for posY := uint8(0); posY < a.Height; posY++ { + pos := types.Dot{ + X: posX, + Y: posY, + } + s := NewSight(a, pos, d) + h := NewHashmapSight(s) + + for x := uint8(0); x < a.Width; x++ { + for y := uint8(0); y < a.Height; y++ { + dot := types.Dot{ + X: x, + Y: y, + } + res := h.hash(dot) + require.Truef(t, res == -1 || res < len(h.values), + msgFormat, dot, res) + } + } + } + } +} + +func TestHashmapSight_UniqHashes(t *testing.T) { + const msgFormat = "hash(%s) = %d (the same %s)" + + const ( + width = 255 + height = 255 + ) + + a := NewArea(width, height) + d := uint8(1) + pos := types.Dot{ + X: 0, + Y: 0, + } + s := NewSight(a, pos, d) + h := NewHashmapSight(s) + + check := map[int]types.Dot{} + + for _, dot := range s.Dots() { + hash := h.hash(dot) + dot1, ok := check[hash] + assert.Falsef(t, ok, msgFormat, dot, hash, dot1) + check[hash] = dot + } +} + +func TestHashmapSight_AssignAccessFlush(t *testing.T) { + const ( + width = 30 + height = 22 + ) + + a := NewArea(width, height) + d := uint8(4) + pos := types.Dot{ + X: 2, + Y: 4, + } + pos1 := types.Dot{ + X: 3, + Y: 4, + } + s := NewSight(a, pos, d) + h := NewHashmapSight(s) + h.Assign(pos, 2) + + val, ok := h.Access(pos) + require.True(t, ok) + require.Equal(t, 2, val) + + val1, ok := h.Access(pos1) + require.False(t, ok) + require.Nil(t, val1) + + h.Flush() + + val2, ok := h.Access(pos) + require.False(t, ok) + require.Nil(t, val2) +} + +func TestHashmapSight_hash_unhash(t *testing.T) { + const msgFormat = "hash(%s) = %d; unhash(%d) = %s" + + const ( + width = 30 + height = 22 + ) + + a := NewArea(width, height) + d := uint8(4) + pos := types.Dot{ + X: 1, + Y: 0, + } + s := NewSight(a, pos, d) + h := NewHashmapSight(s) + + for _, dot := range s.Dots() { + hash := h.hash(dot) + unhash := h.unhash(hash) + require.Equalf(t, dot, unhash, msgFormat, + dot, hash, hash, unhash) + } +} + +func TestHashmapSight_AccessReturnsFalseIfEmpty(t *testing.T) { + const msgFormat = "case number %d" + + const ( + width = 220 + height = 147 + ) + + a := NewArea(width, height) + d := uint8(50) + pos := types.Dot{ + X: 1, + Y: 0, + } + s := NewSight(a, pos, d) + h := NewHashmapSight(s) + + for i, dot := range s.Dots() { + _, ok := h.Access(dot) + require.Falsef(t, ok, msgFormat, i+1) + h.Assign(dot, "ok") + } +} + +func TestHashmapSight_CompareWithMap(t *testing.T) { + const ( + width = 220 + height = 147 + ) + + a := NewArea(width, height) + d := uint8(50) + pos := types.Dot{ + X: 1, + Y: 0, + } + s := NewSight(a, pos, d) + h := NewHashmapSight(s) + + m := map[types.Dot]interface{}{} + + for i, dot := range s.Dots() { + _, ok := h.Access(dot) + assert.False(t, ok) + + m[dot] = i + h.Assign(dot, i) + } + + h.ForEach(func(dot types.Dot, v interface{}) { + v1, ok := m[dot] + assert.True(t, ok) + assert.Equal(t, v, v1) + }) + + for dot, v := range m { + v1, ok := h.Access(dot) + assert.True(t, ok) + assert.Equal(t, v, v1) + } +} diff --git a/internal/bot/engine/map.go b/internal/bot/engine/map.go new file mode 100644 index 0000000..5f0e2a8 --- /dev/null +++ b/internal/bot/engine/map.go @@ -0,0 +1,83 @@ +package engine + +import ( + "sync/atomic" + "unsafe" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type Map struct { + fields [][]unsafe.Pointer + area Area +} + +func NewMap(a Area) *Map { + m := make([][]unsafe.Pointer, a.Height) + + for y := uint8(0); y < a.Height; y++ { + m[y] = make([]unsafe.Pointer, a.Width) + + for x := uint8(0); x < a.Width; x++ { + m[y][x] = unsafe.Pointer(uintptr(0)) + } + } + + return &Map{ + fields: m, + area: a, + } +} + +func (m *Map) LookAround(s Sight) *HashmapSight { + if m.area != s.area { + return nil + } + objects := NewHashmapSight(s) + for _, dot := range s.Dots() { + if object, ok := m.GetObject(dot); ok { + objects.Assign(dot, object) + } + } + return objects +} + +func (m *Map) GetObject(dot types.Dot) (*types.Object, bool) { + if !m.area.Fits(dot) { + return nil, false + } + + p := atomic.LoadPointer(&m.fields[dot.Y][dot.X]) + + if fieldIsEmpty(p) { + return nil, false + } + + container := (*types.Object)(p) + + return container, true +} + +// fieldIsEmpty returns true if the pointer p is empty +func fieldIsEmpty(p unsafe.Pointer) bool { + return uintptr(p) == uintptr(0) +} + +func (m *Map) SaveObject(object *types.Object) { + for _, dot := range object.GetDots() { + if m.area.Fits(dot) { + atomic.SwapPointer(&m.fields[dot.Y][dot.X], unsafe.Pointer(object)) + } + } +} + +func (m *Map) Clear(dots []types.Dot) { + for _, dot := range dots { + if m.area.Fits(dot) { + atomic.SwapPointer( + &m.fields[dot.Y][dot.X], + unsafe.Pointer(uintptr(0)), + ) + } + } +} diff --git a/internal/bot/engine/queue.go b/internal/bot/engine/queue.go new file mode 100644 index 0000000..d24a32d --- /dev/null +++ b/internal/bot/engine/queue.go @@ -0,0 +1,63 @@ +package engine + +import "github.com/ivan1993spb/snake-bot/internal/types" + +type Position struct { + dot types.Dot + score int + distance int +} + +type Queue struct { + positions []*Position + enqueued map[types.Dot]struct{} +} + +const queueLength = 1024 + +func NewQueue() *Queue { + return &Queue{ + positions: make([]*Position, 0, queueLength), + enqueued: make(map[types.Dot]struct{}), + } +} + +func (q *Queue) Len() int { + return len(q.positions) +} + +func (q *Queue) Less(i, j int) bool { + if q.positions[i].score > q.positions[j].score { + return true + } + if q.positions[i].score < q.positions[j].score { + return false + } + return q.positions[i].distance < q.positions[j].distance + +} + +func (q *Queue) Swap(i, j int) { + q.positions[i], q.positions[j] = q.positions[j], q.positions[i] +} + +func (q *Queue) Push(x interface{}) { + item := x.(*Position) + q.positions = append(q.positions, item) + q.enqueued[item.dot] = struct{}{} +} + +func (q *Queue) Pop() interface{} { + n := len(q.positions) + item := q.positions[n-1] + q.positions[n-1] = nil + q.positions = q.positions[0 : n-1] + delete(q.enqueued, item.dot) + return item +} + +func (q *Queue) Enqueued(dot types.Dot) bool { + _, ok := q.enqueued[dot] + return ok + +} diff --git a/internal/bot/engine/sight.go b/internal/bot/engine/sight.go new file mode 100644 index 0000000..ade5236 --- /dev/null +++ b/internal/bot/engine/sight.go @@ -0,0 +1,135 @@ +package engine + +import ( + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type Sight struct { + area Area + topLeft types.Dot + + zeroedBottomRight types.Dot + + width uint8 + height uint8 +} + +// sightDivisor defines how many intervals of a given length +// we need to be able to fit within an area. In the 2D space +// the divisor has to be 2: +// ___ +// ^ +// | +// | +// distance +// | +// | +// v +// |<---distance--->GAP<---distance--->| +// ^ +// | +// | +// distance +// | +// | +// v +// --- +const sightDivisor = 2 + +const sightGap = 1 + +func NewSight(a Area, pos types.Dot, distance uint8) Sight { + distance = a.FitDistance(distance, sightDivisor, sightGap) + + topLeft := types.Dot{ + X: pos.X - distance, + Y: pos.Y - distance, + } + if pos.X < distance { + topLeft.X += a.Width + } + if pos.Y < distance { + topLeft.Y += a.Height + } + + bottomRight := types.Dot{ + X: (pos.X + distance) % a.Width, + Y: (pos.Y + distance) % a.Height, + } + + zeroedBottomRight := types.Dot{ + X: bottomRight.X - topLeft.X, + Y: bottomRight.Y - topLeft.Y, + } + if bottomRight.X < topLeft.X { + zeroedBottomRight.X += a.Width + } + if bottomRight.Y < topLeft.Y { + zeroedBottomRight.Y += a.Height + } + + return Sight{ + area: a, + topLeft: topLeft, + + zeroedBottomRight: zeroedBottomRight, + + // TODO: Check on overflow. + //width: zeroedBottomRight.X + 1, + //height: zeroedBottomRight.Y + 1, + } +} + +func (s Sight) Absolute(relX, relY uint8) types.Dot { + x := uint16(s.topLeft.X) + uint16(relX) + y := uint16(s.topLeft.Y) + uint16(relY) + return types.Dot{ + X: uint8(x % uint16(s.area.Width)), + Y: uint8(y % uint16(s.area.Height)), + } +} + +func (s Sight) Relative(dot types.Dot) (uint8, uint8) { + x := dot.X - s.topLeft.X + if s.topLeft.X > dot.X { + x += s.area.Width + } + y := dot.Y - s.topLeft.Y + if s.topLeft.Y > dot.Y { + y += s.area.Height + } + return x, y +} + +func (s Sight) Seen(dot types.Dot) bool { + if zeroedX := dot.X - s.topLeft.X; dot.X >= s.topLeft.X { + if zeroedX > s.zeroedBottomRight.X { + return false + } + } else if zeroedX+s.area.Width > s.zeroedBottomRight.X { + return false + } + if zeroedY := dot.Y - s.topLeft.Y; dot.Y >= s.topLeft.Y { + if zeroedY > s.zeroedBottomRight.Y { + return false + } + } else if zeroedY+s.area.Height > s.zeroedBottomRight.Y { + return false + } + return true +} + +func (s Sight) Dots() []types.Dot { + w := (int(s.zeroedBottomRight.X) + 1) + h := (int(s.zeroedBottomRight.Y) + 1) + + dots := make([]types.Dot, 0, w*h) + + for x := uint8(0); x <= s.zeroedBottomRight.X; x++ { + for y := uint8(0); y <= s.zeroedBottomRight.Y; y++ { + dots = append(dots, s.Absolute(x, y)) + } + } + + return dots +} diff --git a/internal/bot/engine/sight_test.go b/internal/bot/engine/sight_test.go new file mode 100644 index 0000000..a7dceec --- /dev/null +++ b/internal/bot/engine/sight_test.go @@ -0,0 +1,409 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +func Test_NewSight(t *testing.T) { + const format = "failed at test number %d" + + tests := []struct { + area Area + pos types.Dot + distance uint8 + expected Sight + }{ + { + area: Area{ + Width: 100, + Height: 50, + }, + pos: types.Dot{ + X: 10, + Y: 13, + }, + distance: 3, + expected: Sight{ + area: Area{ + Width: 100, + Height: 50, + }, + topLeft: types.Dot{ + X: 7, + Y: 10, + }, + zeroedBottomRight: types.Dot{ + X: 6, + Y: 6, + }, + }, + }, + { + area: Area{ + Width: 3, + Height: 3, + }, + pos: types.Dot{ + X: 1, + Y: 1, + }, + distance: 1, + expected: Sight{ + area: Area{ + Width: 3, + Height: 3, + }, + topLeft: types.Dot{ + X: 0, + Y: 0, + }, + zeroedBottomRight: types.Dot{ + X: 2, + Y: 2, + }, + }, + }, + { + area: Area{ + Width: 4, + Height: 4, + }, + pos: types.Dot{ + X: 1, + Y: 1, + }, + distance: 2, + expected: Sight{ + area: Area{ + Width: 4, + Height: 4, + }, + topLeft: types.Dot{ + X: 0, + Y: 0, + }, + zeroedBottomRight: types.Dot{ + X: 2, + Y: 2, + }, + }, + }, + { + area: Area{ + Width: 10, + Height: 4, + }, + pos: types.Dot{ + X: 1, + Y: 1, + }, + distance: 2, + expected: Sight{ + area: Area{ + Width: 10, + Height: 4, + }, + topLeft: types.Dot{ + X: 0, + Y: 0, + }, + zeroedBottomRight: types.Dot{ + X: 2, + Y: 2, + }, + }, + }, + { + area: Area{ + Width: 13, + Height: 32, + }, + pos: types.Dot{ + X: 1, + Y: 2, + }, + distance: 5, + expected: Sight{ + area: Area{ + Width: 13, + Height: 32, + }, + topLeft: types.Dot{ + X: 9, + Y: 29, + }, + zeroedBottomRight: types.Dot{ + X: 10, + Y: 10, + }, + }, + }, + } + + for i, test := range tests { + sight := NewSight(test.area, test.pos, test.distance) + assert.Equal(t, test.expected, sight, format, i+1) + } +} + +func Test_Sight_Absolute(t *testing.T) { + const msgFormat = "%d: x%d y%d => %s instead of %s" + + tests := []struct { + relX, relY uint8 + sight Sight + expect types.Dot + }{ + { + relX: 0, + relY: 0, + sight: Sight{ + area: Area{ + Width: 13, + Height: 32, + }, + topLeft: types.Dot{ + X: 1, + Y: 3, + }, + zeroedBottomRight: types.Dot{ + X: 10, + Y: 10, + }, + }, + expect: types.Dot{ + X: 1, + Y: 3, + }, + }, + { + relX: 4, + relY: 5, + sight: Sight{ + area: Area{ + Width: 20, + Height: 20, + }, + topLeft: types.Dot{ + X: 10, + Y: 3, + }, + zeroedBottomRight: types.Dot{ + X: 9, + Y: 9, + }, + }, + expect: types.Dot{ + X: 14, + Y: 8, + }, + }, + { + relX: 7, + relY: 1, + sight: Sight{ + area: Area{ + Width: 20, + Height: 20, + }, + topLeft: types.Dot{ + X: 15, + Y: 5, + }, + zeroedBottomRight: types.Dot{ + X: 9, + Y: 9, + }, + }, + expect: types.Dot{ + X: 2, + Y: 6, + }, + }, + { + relX: 68, + relY: 89, + sight: Sight{ + area: Area{ + Width: 255, + Height: 255, + }, + topLeft: types.Dot{ + X: 205, + Y: 178, + }, + zeroedBottomRight: types.Dot{ + X: 100, + Y: 100, + }, + }, + expect: types.Dot{ + X: 18, + Y: 12, + }, + }, + } + + for i, test := range tests { + dot := test.sight.Absolute(test.relX, test.relY) + assert.Equalf(t, test.expect, dot, msgFormat, + i+1, test.relX, test.relY, dot, test.expect) + } +} + +func Test_Sight_Relative(t *testing.T) { + // TODO: test this. + t.Skip() +} + +func Test_Sight_Relative_To_Absolute(t *testing.T) { + s := Sight{ + area: Area{ + Width: 255, + Height: 255, + }, + topLeft: types.Dot{ + X: 250, + Y: 249, + }, + zeroedBottomRight: types.Dot{ + X: 100, + Y: 100, + }, + } + + for _, dot := range s.Dots() { + x, y := s.Relative(dot) + res := s.Absolute(x, y) + assert.Equal(t, dot, res) + } +} + +func Test_Sight_Seen(t *testing.T) { + const format = "failed at test number %d" + + tests := []struct { + sight Sight + dot types.Dot + expected bool + }{ + { + sight: Sight{ + area: Area{ + Width: 13, + Height: 32, + }, + topLeft: types.Dot{ + X: 9, + Y: 29, + }, + zeroedBottomRight: types.Dot{ + X: 10, + Y: 10, + }, + }, + dot: types.Dot{ + X: 0, + Y: 0, + }, + expected: true, + }, + { + sight: Sight{ + area: Area{ + Width: 100, + Height: 50, + }, + topLeft: types.Dot{ + X: 7, + Y: 10, + }, + zeroedBottomRight: types.Dot{ + X: 6, + Y: 6, + }, + }, + dot: types.Dot{ + X: 9, + Y: 10, + }, + expected: true, + }, + { + sight: Sight{ + area: Area{ + Width: 13, + Height: 32, + }, + topLeft: types.Dot{ + X: 9, + Y: 29, + }, + zeroedBottomRight: types.Dot{ + X: 10, + Y: 10, + }, + }, + dot: types.Dot{ + X: 12, + Y: 30, + }, + expected: true, + }, + { + sight: Sight{ + area: Area{ + Width: 13, + Height: 32, + }, + topLeft: types.Dot{ + X: 9, + Y: 29, + }, + zeroedBottomRight: types.Dot{ + X: 10, + Y: 10, + }, + }, + dot: types.Dot{ + X: 21, + Y: 2, + }, + expected: false, + }, + } + + for i, test := range tests { + assert.Equal(t, test.expected, test.sight.Seen(test.dot), format, i+1) + } +} + +func Test_Sight_DotsDoesntRepeat(t *testing.T) { + const msgFormat = "index %d; dot %s" + + const ( + width = 255 + height = 255 + ) + + check := map[types.Dot]struct{}{} + + a := NewArea(width, height) + d := uint8(100) + pos := types.Dot{ + X: 0, + Y: 0, + } + s := NewSight(a, pos, d) + + for i, dot := range s.Dots() { + if _, ok := check[dot]; ok { + assert.Falsef(t, ok, msgFormat, i, dot) + } + check[dot] = struct{}{} + } +} diff --git a/internal/bot/game.go b/internal/bot/game.go new file mode 100644 index 0000000..cfe3063 --- /dev/null +++ b/internal/bot/game.go @@ -0,0 +1,99 @@ +package bot + +import ( + "sync" + + "github.com/ivan1993spb/snake-bot/internal/bot/engine" + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type gameState uint8 + +const ( + gameStateWait gameState = iota + gameStateReady +) + +type Game struct { + state gameState + mux *sync.RWMutex + objects map[uint32]*types.Object + area engine.Area + _map *engine.Map +} + +func NewGame() *Game { + return &Game{ + state: gameStateWait, + mux: &sync.RWMutex{}, + objects: make(map[uint32]*types.Object), + } +} + +func (g *Game) Create(object *types.Object) { + g.mux.Lock() + defer g.mux.Unlock() + g.objects[object.Id] = object + if g.state == gameStateReady { + g._map.SaveObject(object) + } +} + +func (g *Game) Update(object *types.Object) { + g.mux.Lock() + defer g.mux.Unlock() + old := g.objects[object.Id] + g.objects[object.Id] = object + if g.state == gameStateReady { + if old != nil { + g._map.Clear(old.GetDots()) + } + g._map.SaveObject(object) + } +} + +func (g *Game) Delete(object *types.Object) { + g.mux.Lock() + defer g.mux.Unlock() + delete(g.objects, object.Id) + if g.state == gameStateReady { + g._map.Clear(object.GetDots()) + } +} + +func (g *Game) GetObject(id uint32) (*types.Object, bool) { + g.mux.RLock() + defer g.mux.RUnlock() + if object, ok := g.objects[id]; ok { + return object, true + } + return nil, false +} + +func (g *Game) LookAround(sight engine.Sight) *engine.HashmapSight { + g.mux.RLock() + defer g.mux.RUnlock() + + if g.state == gameStateReady { + return g._map.LookAround(sight) + } + + return nil +} + +func (g *Game) Size(width, height uint8) { + g.mux.Lock() + defer g.mux.Unlock() + g.area = engine.Area{ + Width: width, + Height: height, + } + g._map = engine.NewMap(g.area) + g.state = gameStateReady +} + +func (g *Game) GetArea() engine.Area { + g.mux.RLock() + defer g.mux.RUnlock() + return g.area +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..fa3b743 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,178 @@ +package config + +import ( + "flag" + "fmt" +) + +// Default values for the server's settings +const ( + defaultAddress = ":8080" + defaultForbidCORS = false + defaultDebug = false + + defaultSnakeServer = "localhost:8080" + defaultWSS = false + + defaultBotsLimit = 100 + + defaultLogEnableJSON = false + defaultLogLevel = "info" +) + +// Flag labels +const ( + flagLabelAddress = "address" + flagLabelForbidCORS = "forbid-cors" + flagLabelDebug = "debug" + + flagLabelSnakeServer = "snake-server" + flagLabelWSS = "wss" + + flagLabelBotsLimit = "bots-limit" + + flagLabelLogEnableJSON = "log-json" + flagLabelLogLevel = "log-level" +) + +// Flag usage descriptions +const ( + flagUsageAddress = "address to listen to" + flagUsageForbidCORS = "forbid cross-origin resource sharing" + flagUsageDebug = "add profiling routes" + + flagUsageSnakeServer = "snake server's address: host:port" + flagUsageWSS = "use secure web-socket connection" + + flagUsageBotsLimit = "overall bots limit" + + flagUsageLogEnableJSON = "use json logging format" + flagUsageLogLevel = "log level: panic, fatal, error, warning, info or debug" +) + +// Label names +const ( + fieldLabelAddress = "address" + fieldLabelForbidCORS = "forbid-cors" + fieldLabelDebug = "debug" + + fieldLabelSnakeServer = "snake server" + fieldLabelWSS = "wss" + + fieldLabelBotsLimit = "bots-limit" + + fieldLabelLogEnableJSON = "log-json" + fieldLabelLogLevel = "log-level" +) + +// Server structure contains configurations for the server +type Server struct { + Address string + ForbidCORS bool + Debug bool +} + +type Target struct { + Address string + WSS bool +} + +type Bots struct { + Limit int +} + +// Log structure defines preferences for logging +type Log struct { + EnableJSON bool + Level string +} + +// Config is a base server configuration structure +type Config struct { + Server Server + Target Target + Log Log + Bots Bots +} + +// Fields returns a map of all configurations +func (c Config) Fields() map[string]interface{} { + return map[string]interface{}{ + fieldLabelAddress: c.Server.Address, + fieldLabelForbidCORS: c.Server.ForbidCORS, + fieldLabelDebug: c.Server.Debug, + + fieldLabelSnakeServer: c.Target.Address, + fieldLabelWSS: c.Target.WSS, + + fieldLabelBotsLimit: c.Bots.Limit, + + fieldLabelLogEnableJSON: c.Log.EnableJSON, + fieldLabelLogLevel: c.Log.Level, + } +} + +// Default settings +var defaultConfig = Config{ + Server: Server{ + Address: defaultAddress, + ForbidCORS: defaultForbidCORS, + Debug: defaultDebug, + }, + + Target: Target{ + Address: defaultSnakeServer, + WSS: defaultWSS, + }, + + Bots: Bots{ + Limit: defaultBotsLimit, + }, + + Log: Log{ + EnableJSON: defaultLogEnableJSON, + Level: defaultLogLevel, + }, +} + +// DefaultConfig returns configuration by default +func DefaultConfig() Config { + return defaultConfig +} + +// ParseFlags parses flags and returns a config based on the default configuration +func ParseFlags(flagSet *flag.FlagSet, args []string, defaults Config) (Config, error) { + if flagSet.Parsed() { + panic("program composition error: the provided FlagSet has been parsed") + } + + config := defaults + + // Address + flagSet.StringVar(&config.Server.Address, flagLabelAddress, + defaults.Server.Address, flagUsageAddress) + flagSet.BoolVar(&config.Server.ForbidCORS, flagLabelForbidCORS, + defaults.Server.ForbidCORS, flagUsageForbidCORS) + flagSet.BoolVar(&config.Server.Debug, flagLabelDebug, + defaults.Server.Debug, flagUsageDebug) + + flagSet.StringVar(&config.Target.Address, flagLabelSnakeServer, + defaults.Target.Address, flagUsageSnakeServer) + flagSet.BoolVar(&config.Target.WSS, flagLabelWSS, + defaults.Target.WSS, flagUsageWSS) + + flagSet.IntVar(&config.Bots.Limit, flagLabelBotsLimit, + defaults.Bots.Limit, flagUsageBotsLimit) + + // Logging + flagSet.BoolVar(&config.Log.EnableJSON, flagLabelLogEnableJSON, + defaults.Log.EnableJSON, flagUsageLogEnableJSON) + flagSet.StringVar(&config.Log.Level, flagLabelLogLevel, + defaults.Log.Level, flagUsageLogLevel) + + if err := flagSet.Parse(args); err != nil { + return defaults, fmt.Errorf("cannot parse flags: %s", err) + } + + return config, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..46e84da --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,211 @@ +package config + +import ( + "flag" + "fmt" + "io/ioutil" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_DefaultConfig_ReturnsDefaultConfig(t *testing.T) { + require.Equal(t, defaultConfig, DefaultConfig()) +} + +func Test_ParseFlags_ParsesFlagsCorrectly(t *testing.T) { + const flagSetName = "test" + + type Test struct { + msg string + + args []string + defaults Config + + expectConfig Config + expectErr bool + } + + var tests = make([]*Test, 0) + + // Test case 1 + tests = append(tests, &Test{ + msg: "run without arguments", + + args: []string{}, + defaults: defaultConfig, + + expectConfig: defaultConfig, + expectErr: false, + }) + + // Test case 2 + configTest2 := defaultConfig + configTest2.Server.Address = ":7070" + + tests = append(tests, &Test{ + msg: "change address", + + args: []string{ + "-address", ":7070", + }, + defaults: defaultConfig, + + expectConfig: configTest2, + expectErr: false, + }) + + // Test case 3 + configTest3 := defaultConfig + configTest3.Server.Address = "localhost:6670" + configTest3.Target.Address = "any-snake-dot.com" + + tests = append(tests, &Test{ + msg: "change address and server address", + + args: []string{ + "-address", "localhost:6670", + "-snake-server", "any-snake-dot.com", + }, + defaults: defaultConfig, + + expectConfig: configTest3, + expectErr: false, + }) + + // Test case 4 + configTest4 := defaultConfig + configTest4.Server.Address = "snakeonline.xyz:7986" + configTest4.Log.EnableJSON = true + configTest4.Target.WSS = true + + tests = append(tests, &Test{ + msg: "change address, logging and protocol", + + args: []string{ + "-address", "snakeonline.xyz:7986", + "-log-json", + "-wss", + }, + defaults: defaultConfig, + + expectConfig: configTest4, + expectErr: false, + }) + + // Test case 5 + configTest5 := defaultConfig + configTest5.Server.Address = "snakeonline.xyz:3211" + configTest5.Log.EnableJSON = true + configTest5.Bots.Limit = 541 + configTest5.Server.Debug = true + + tests = append(tests, &Test{ + msg: "change address, logging, bots limit and debug", + + args: []string{ + "-address", "snakeonline.xyz:3211", + "-log-json", + "-bots-limit", "541", + "-debug", + }, + defaults: defaultConfig, + + expectConfig: configTest5, + expectErr: false, + }) + + // Test case 6 + tests = append(tests, &Test{ + msg: "change address, and make 1 mistake", + + args: []string{ + "-address", "snakeonline.xyz:3211", + "-foobar", // unknown flag + }, + defaults: defaultConfig, + + expectConfig: defaultConfig, + expectErr: true, + }) + + // Test case 7 + tests = append(tests, &Test{ + msg: "change address, enable debug, make 2 mistakes", + + args: []string{ + "-address", "snakeonline.xyz:3211", + "-bots-limit", "error", // should be a number + "-foobar", + }, + defaults: defaultConfig, + + expectConfig: defaultConfig, + expectErr: true, + }) + + // Test case 8 + tests = append(tests, &Test{ + msg: "args is nil", + + args: nil, + defaults: defaultConfig, + + expectConfig: defaultConfig, + expectErr: false, + }) + + for n, test := range tests { + t.Log(test.msg) + + label := fmt.Sprintf("case number %d", n+1) + flagSet := flag.NewFlagSet(flagSetName, flag.ContinueOnError) + flagSet.SetOutput(ioutil.Discard) + + config, err := ParseFlags(flagSet, test.args, test.defaults) + + if test.expectErr { + require.NotNil(t, err, label) + } else { + require.Nil(t, err, label) + } + require.Equal(t, test.expectConfig, config, label) + } +} + +func Test_Config_Fields_ReturnsFieldsOfTheConfig(t *testing.T) { + require.Equal(t, map[string]interface{}{ + fieldLabelAddress: ":9999", + fieldLabelForbidCORS: true, + fieldLabelDebug: true, + + fieldLabelSnakeServer: "localhost:9210", + fieldLabelWSS: false, + + fieldLabelBotsLimit: 1337, + + fieldLabelLogEnableJSON: false, + fieldLabelLogLevel: "warning", + }, Config{ + Server: Server{ + Address: ":9999", + + ForbidCORS: true, + Debug: true, + }, + + Target: Target{ + Address: "localhost:9210", + WSS: false, + }, + + Bots: Bots{ + Limit: 1337, + }, + + Log: Log{ + EnableJSON: false, + Level: "warning", + }, + }.Fields()) +} diff --git a/internal/config/std_config.go b/internal/config/std_config.go new file mode 100644 index 0000000..3957715 --- /dev/null +++ b/internal/config/std_config.go @@ -0,0 +1,12 @@ +package config + +import ( + "flag" + "os" +) + +func StdConfig() (Config, error) { + f := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + cfg, err := ParseFlags(f, os.Args[1:], DefaultConfig()) + return cfg, err +} diff --git a/internal/connect/connection.go b/internal/connect/connection.go new file mode 100644 index 0000000..881f609 --- /dev/null +++ b/internal/connect/connection.go @@ -0,0 +1,80 @@ +package connect + +import ( + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" +) + +var _ Connection = (*connection)(nil) + +type Connection interface { + Send(data interface{}) error + Receive() ([]byte, error) + Close() error +} + +type connection struct { + conn *websocket.Conn + resp *http.Response + mux *sync.Mutex +} + +const deadlinePing = time.Second + +func (c *connection) Send(data interface{}) error { + // TODO: Ping? + //deadline := time.Now().Add(deadlinePing) + // TODO: Handle error. + //c.conn.WriteControl(websocket.PingMessage, nil, deadline) + // TODO: Need concurrency! + //c.mux.Lock() + //defer c.mux.Unlock() + err := c.conn.WriteJSON(data) + if err != nil { + return errors.Wrap(err, "sending data") + } + return nil +} + +var ErrConnectionClosed = errors.New("connection has been closed") + +var ErrWrongMsgType = errors.New("wrong message type") + +func (c *connection) Receive() ([]byte, error) { + // TODO: Need concurrency! + //c.mux.Lock() + //defer c.mux.Unlock() + messageType, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + return nil, ErrConnectionClosed + } + return nil, errors.Wrap(err, "reading message") + } + if messageType != websocket.TextMessage { + return nil, ErrWrongMsgType + } + return message, nil +} + +func (c *connection) Close() error { + // TODO: Need concurrency! + //c.mux.Lock() + //defer c.mux.Unlock() + { + message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + err := c.conn.WriteMessage(websocket.CloseMessage, message) + if err != nil { + return errors.Wrap(err, "writing close message") + } + } + err := c.conn.Close() + if err != nil { + return errors.Wrap(err, "closing websocket") + } + return nil +} diff --git a/internal/connect/connector.go b/internal/connect/connector.go new file mode 100644 index 0000000..1bd135e --- /dev/null +++ b/internal/connect/connector.go @@ -0,0 +1,84 @@ +package connect + +import ( + "context" + "fmt" + "net/http" + "net/url" + "sync" + "time" + + "github.com/ivan1993spb/snake-bot/internal/config" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" +) + +const ( + websocketInsecureScheme = "ws" + websocketSecureScheme = "wss" +) + +const gameWebsocketPathFormat = "/ws/games/%d" + +const handshakeTimeout = 45 * time.Second + +const headerClientName = "X-Snake-Client" + +type Connector struct { + address string + wss bool + + dialer *websocket.Dialer + http.Header +} + +func NewConnector(cfg config.Target, clientName string) *Connector { + c := &Connector{ + address: cfg.Address, + wss: cfg.WSS, + } + + c.dialer = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: handshakeTimeout, + } + + c.Header = http.Header{} + c.Header.Add(headerClientName, clientName) + + return c +} + +func getGameWebSocketPath(gameId int) string { + return fmt.Sprintf(gameWebsocketPathFormat, gameId) +} + +func (c *Connector) getGameWebWocketURL(gameId int) *url.URL { + u := &url.URL{ + Host: c.address, + Scheme: websocketInsecureScheme, + } + + if c.wss { + u.Scheme = websocketSecureScheme + } + + u.Path = getGameWebSocketPath(gameId) + + return u +} + +func (c *Connector) Connect(ctx context.Context, gameId int) (Connection, error) { + u := c.getGameWebWocketURL(gameId) + ws, resp, err := c.dialer.DialContext(ctx, u.String(), c.Header) + if err != nil { + return nil, errors.Wrap(err, "cannot connect") + } + conn := &connection{ + conn: ws, + resp: resp, + mux: &sync.Mutex{}, + } + return conn, nil +} diff --git a/internal/core/bot_operator.go b/internal/core/bot_operator.go new file mode 100644 index 0000000..0f84060 --- /dev/null +++ b/internal/core/bot_operator.go @@ -0,0 +1,129 @@ +package core + +import ( + "context" + "sync" + "time" + + "github.com/ivan1993spb/snake-bot/internal/bot" + "github.com/ivan1993spb/snake-bot/internal/connect" + "github.com/ivan1993spb/snake-bot/internal/parser" + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +type BotOperator struct { + gameId int + + connector Connector + + game *bot.Game + bot *bot.Bot + parser *parser.Parser + + stop chan struct{} + once sync.Once +} + +func NewBotOperator(ctx context.Context, gameId int, + connector Connector) *BotOperator { + g := bot.NewGame() + b := bot.NewDijkstrasBot(g) + p := &parser.Parser{ + Countdown: b, + Me: b, + Size: g, + Game: g, + Printer: utils.NewPrinterLog(utils.Log(ctx), gameId), + } + + return &BotOperator{ + gameId: gameId, + + connector: connector, + + game: g, + bot: b, + parser: p, + + stop: make(chan struct{}), + once: sync.Once{}, + } +} + +const botConnectRetryTimeout = time.Second * 10 + +func (bo *BotOperator) Run(ctx context.Context) { + go func() { + + retry: + conn, err := bo.connector.Connect(ctx, bo.gameId) + if err != nil { + utils.Log(ctx).WithError(err).Error("connect fail") + + select { + case <-bo.stop: + return + case <-ctx.Done(): + return + case <-time.After(botConnectRetryTimeout): + } + + goto retry + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for direction := range bo.bot.Run(ctx, bo.stop) { + message := direction.ToMessageSnakeCommand() + if err := conn.Send(message); err != nil { + utils.Log(ctx).WithError(err).Error( + "connection send fail") + } + } + }() + + go func() { + defer wg.Done() + for { + message, err := conn.Receive() + if err == connect.ErrWrongMsgType { + // With a wrong message type we just skip the + // message. + continue + } + if err != nil { + if err != connect.ErrConnectionClosed { + utils.Log(ctx).WithError(err).Error( + "connection receive fail") + } + break + } + if err := bo.parser.Parse(message); err != nil { + utils.Log(ctx).WithError(err).Error( + "parse message fail") + } + } + }() + + select { + case <-bo.stop: + case <-ctx.Done(): + } + + if err := conn.Close(); err != nil { + utils.Log(ctx).WithError(err).Error("connection close fail") + } + + wg.Wait() + bo.Stop() + }() +} + +func (bo *BotOperator) Stop() { + bo.once.Do(func() { + close(bo.stop) + }) +} diff --git a/internal/core/core.go b/internal/core/core.go new file mode 100644 index 0000000..a203bfc --- /dev/null +++ b/internal/core/core.go @@ -0,0 +1,114 @@ +package core + +import ( + "context" + "sync" + + "github.com/pkg/errors" + + "github.com/ivan1993spb/snake-bot/internal/connect" +) + +type Connector interface { + Connect(ctx context.Context, gameId int) (connect.Connection, error) +} + +type Core struct { + ctx context.Context + + connector Connector + + mux *sync.Mutex + bots map[int][]*BotOperator + + botsLimit int +} + +func NewCore(ctx context.Context, connector Connector, botsLimit int) *Core { + return &Core{ + ctx: ctx, + + connector: connector, + + mux: &sync.Mutex{}, + bots: make(map[int][]*BotOperator), + + botsLimit: botsLimit, + } +} + +func (c *Core) state() map[int]int { + c.mux.Lock() + defer c.mux.Unlock() + return c.unsafeState() +} + +func (c *Core) unsafeState() map[int]int { + state := make(map[int]int, len(c.bots)) + for gameId, bots := range c.bots { + state[gameId] = len(bots) + } + return state +} + +var errRequestedTooManyBots = errors.New("requested too many bots") + +func (c *Core) ApplyState(state map[int]int) (map[int]int, error) { + if stateBotsNumber(state) > c.botsLimit { + return nil, errRequestedTooManyBots + } + + c.mux.Lock() + defer c.mux.Unlock() + + d := diff(c.unsafeState(), state) + c.unsafeApplyDiff(d) + + return c.unsafeState(), nil +} + +func stateBotsNumber(state map[int]int) (number int) { + for _, bots := range state { + number += bots + } + return +} + +func (c *Core) unsafeApplyDiff(d map[int]int) { + for gameId, bots := range d { + if bots > 0 { + for i := 0; i < bots; i++ { + bo := NewBotOperator(c.ctx, gameId, c.connector) + bo.Run(c.ctx) + c.bots[gameId] = append(c.bots[gameId], bo) + } + } else { + for i := 0; i < -bots; i++ { + var bo *BotOperator + bo, c.bots[gameId] = c.bots[gameId][0], c.bots[gameId][1:] + bo.Stop() + } + } + } +} + +func (c *Core) SetupOne(gameId, bots int) (map[int]int, error) { + c.mux.Lock() + defer c.mux.Unlock() + + state := c.unsafeState() + state[gameId] = bots + + if stateBotsNumber(state) > c.botsLimit { + return nil, errRequestedTooManyBots + } + + d := diff(c.unsafeState(), state) + c.unsafeApplyDiff(d) + + return c.unsafeState(), nil +} + +func (c *Core) ReadState() map[int]int { + return c.state() +} diff --git a/internal/core/diff.go b/internal/core/diff.go new file mode 100644 index 0000000..b64de96 --- /dev/null +++ b/internal/core/diff.go @@ -0,0 +1,27 @@ +package core + +func diff(have, want map[int]int) map[int]int { + d := make(map[int]int) + + for k, v1 := range have { + if v2, ok := want[k]; ok { + if v := v2 - v1; v != 0 { + d[k] = v + } + } else { + d[k] = -v1 + } + } + + for k, v := range want { + if _, ok := d[k]; ok { + continue + } + if _, ok := have[k]; ok { + continue + } + d[k] = v + } + + return d +} diff --git a/internal/core/diff_test.go b/internal/core/diff_test.go new file mode 100644 index 0000000..63047e9 --- /dev/null +++ b/internal/core/diff_test.go @@ -0,0 +1,85 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDiff(t *testing.T) { + const errFormat = "test case %d" + + tests := []struct { + have, want, diff map[int]int + }{ + { + have: map[int]int{ + 1: 1, + 2: 1, + 3: 1, + 7: 1, + 9: 5, + }, + want: map[int]int{ + 2: 4, + 8: 5, + 9: 3, + }, + diff: map[int]int{ + 1: -1, + 2: 3, + 3: -1, + 7: -1, + 8: 5, + 9: -2, + }, + }, + { + have: map[int]int{ + 1: 1, + 2: 5, + 3: 1, + }, + want: map[int]int{ + 4: 4, + 5: 5, + 6: 3, + }, + diff: map[int]int{ + 1: -1, + 2: -5, + 3: -1, + 4: 4, + 5: 5, + 6: 3, + }, + }, + { + have: map[int]int{ + 1: 3, + }, + want: map[int]int{ + 1: 3, + }, + diff: map[int]int{}, + }, + { + have: map[int]int{ + 1: 10, + 2: 8, + }, + want: map[int]int{ + 1: 10, + 2: 12, + }, + diff: map[int]int{ + 2: 4, + }, + }, + } + + for i, test := range tests { + actualDiff := diff(test.have, test.want) + assert.Equal(t, test.diff, actualDiff, errFormat, i+1) + } +} diff --git a/internal/parser/parser.go b/internal/parser/parser.go new file mode 100644 index 0000000..03ce1cc --- /dev/null +++ b/internal/parser/parser.go @@ -0,0 +1,168 @@ +package parser + +import ( + "encoding/json" + + "github.com/pkg/errors" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type Countdown interface { + Countdown(sec int) +} + +type Me interface { + Me(id uint32) +} + +type Size interface { + Size(width, height uint8) +} + +type Game interface { + Create(object *types.Object) + Update(object *types.Object) + Delete(object *types.Object) +} + +type Printer interface { + Print(level, what, message string) +} + +type Parser struct { + Countdown + Me + Size + Game + Printer +} + +const errParseAnnotation = "parsing error" + +var ErrParseUnknownMessage = errors.Errorf("%s: %s", errParseAnnotation, + "unknown message") + +func (p *Parser) Parse(data []byte) error { + var message *types.Message + if err := json.Unmarshal(data, &message); err != nil { + return errors.Wrap(err, errParseAnnotation) + } + switch message.Type { + case types.MessageTypeGameEvent: + if err := p.ParseGameEvent(message.Payload); err != nil { + return errors.Wrap(err, errParseAnnotation) + } + case types.MessageTypePlayer: + if err := p.ParsePlayerEvent(message.Payload); err != nil { + return errors.Wrap(err, errParseAnnotation) + } + case types.MessageTypeBroadcast: + if err := p.ParseBroadcast(message.Payload); err != nil { + return errors.Wrap(err, errParseAnnotation) + } + default: + return errors.WithMessagef(ErrParseUnknownMessage, "unkown type %q", + message.Type) + } + return nil +} + +const errParseGameEventAnnotation = "cannot parse game event" + +func (p *Parser) ParseGameEvent(data []byte) error { + var event *types.GameEvent + if err := json.Unmarshal(data, &event); err != nil { + return errors.Wrap(err, errParseGameEventAnnotation) + } + if event.Type == types.GameEventTypeError { + var errorMessage string + if err := json.Unmarshal(event.Payload, &errorMessage); err != nil { + return errors.Wrap(err, errParseGameEventAnnotation) + } + p.Printer.Print("game", "error", errorMessage) + return nil + } + var object *types.Object + if err := json.Unmarshal(event.Payload, &object); err != nil { + return errors.Wrap(err, errParseGameEventAnnotation) + } + switch event.Type { + case types.GameEventTypeCreate: + p.Game.Create(object) + case types.GameEventTypeDelete: + p.Game.Delete(object) + case types.GameEventTypeUpdate: + p.Game.Update(object) + case types.GameEventTypeChecked: + // ignore deprecated feature. + } + return nil +} + +const errParsePlayerEventAnnotation = "cannot parse player event" + +var ErrParseUnknownPlayerEvent = errors.Errorf("%s: %s", errParsePlayerEventAnnotation, + "unknown player event") + +func (p *Parser) ParsePlayerEvent(data []byte) error { + var event *types.PlayerEvent + if err := json.Unmarshal(data, &event); err != nil { + return errors.Wrap(err, errParsePlayerEventAnnotation) + } + switch event.Type { + case types.PlayerEventTypeSize: + var size *types.Size + if err := json.Unmarshal(event.Payload, &size); err != nil { + return errors.Wrap(err, errParsePlayerEventAnnotation) + } + p.Size.Size(size.Width, size.Height) + case types.PlayerEventTypeSnake: + var me uint32 + if err := json.Unmarshal(event.Payload, &me); err != nil { + return errors.Wrap(err, errParsePlayerEventAnnotation) + } + p.Me.Me(me) + case types.PlayerEventTypeNotice: + var notice string + if err := json.Unmarshal(event.Payload, ¬ice); err != nil { + return errors.Wrap(err, errParsePlayerEventAnnotation) + } + p.Printer.Print("player", "notice", notice) + case types.PlayerEventTypeError: + var errorMessage string + if err := json.Unmarshal(event.Payload, &errorMessage); err != nil { + return errors.Wrap(err, errParsePlayerEventAnnotation) + } + p.Printer.Print("player", "error", errorMessage) + case types.PlayerEventTypeCountdown: + var countdown int + if err := json.Unmarshal(event.Payload, &countdown); err != nil { + return errors.Wrap(err, errParsePlayerEventAnnotation) + } + p.Countdown.Countdown(countdown) + case types.PlayerEventTypeObjects: + var objects []*types.Object + if err := json.Unmarshal(event.Payload, &objects); err != nil { + return errors.Wrap(err, errParsePlayerEventAnnotation) + } + for _, object := range objects { + p.Game.Create(object) + } + default: + return errors.WithMessagef(ErrParseUnknownPlayerEvent, "unkown type %q", + event.Type) + } + return nil +} + +const errParseBroadcastAnnotation = "cannot parse broadcast message" + +func (p *Parser) ParseBroadcast(data []byte) error { + var message string + if err := json.Unmarshal(data, &message); err != nil { + return errors.Wrap(err, errParseBroadcastAnnotation) + } + p.Printer.Print("broadcast", "message", message) + return nil +} diff --git a/internal/parser/parser_mocks_test.go b/internal/parser/parser_mocks_test.go new file mode 100644 index 0000000..578374b --- /dev/null +++ b/internal/parser/parser_mocks_test.go @@ -0,0 +1,55 @@ +package parser + +import ( + "github.com/stretchr/testify/mock" + + "github.com/ivan1993spb/snake-bot/internal/types" +) + +type MockCountdown struct { + mock.Mock +} + +func (m *MockCountdown) Countdown(sec int) { + m.Called(sec) +} + +type MockMe struct { + mock.Mock +} + +func (m *MockMe) Me(id uint32) { + m.Called(id) +} + +type MockSize struct { + mock.Mock +} + +func (m *MockSize) Size(width, height uint8) { + m.Called(width, height) +} + +type MockGame struct { + mock.Mock +} + +func (m *MockGame) Create(object *types.Object) { + m.Called(object) +} + +func (m *MockGame) Update(object *types.Object) { + m.Called(object) +} + +func (m *MockGame) Delete(object *types.Object) { + m.Called(object) +} + +type MockPrinter struct { + mock.Mock +} + +func (m *MockPrinter) Print(level, what, message string) { + m.Called(level, what, message) +} diff --git a/internal/parser/parser_test.go b/internal/parser/parser_test.go new file mode 100644 index 0000000..e4fd0dd --- /dev/null +++ b/internal/parser/parser_test.go @@ -0,0 +1,39 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type ParserTestSuite struct { + suite.Suite + Parser *Parser +} + +func (suite *ParserTestSuite) SetupTest() { + suite.Parser = &Parser{} +} + +var messageBroadcastPayload = []byte(`"user joined your game group"`) + +var messageBroadcastPayloadCorrupted = []byte(`user joined your game group"`) + +func (suite *ParserTestSuite) TestParseBroadcast() { + printer := new(MockPrinter) + suite.Parser.Printer = printer + printer.On("Print", "broadcast", "message", "user joined your game group").Return() + err := suite.Parser.ParseBroadcast(messageBroadcastPayload) + assert.Nil(suite.T(), err) + printer.AssertExpectations(suite.T()) +} + +func (suite *ParserTestSuite) TestParseBroadcastCorruptedPayload() { + err := suite.Parser.ParseBroadcast(messageBroadcastPayloadCorrupted) + assert.NotNil(suite.T(), err) +} + +func TestParserTestSuite(t *testing.T) { + suite.Run(t, new(ParserTestSuite)) +} diff --git a/internal/secure/secure.go b/internal/secure/secure.go new file mode 100644 index 0000000..7f0f70e --- /dev/null +++ b/internal/secure/secure.go @@ -0,0 +1,74 @@ +package secure + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "io" + "sync" + + "github.com/pkg/errors" + "lukechampine.com/noescape" +) + +type Secure struct { + mux *sync.Mutex + sum [sha256.Size]byte +} + +func NewSecure() *Secure { + return &Secure{ + mux: &sync.Mutex{}, + } +} + +const bufferSize = 48 + +var ( + messageTokenBegin = []byte("Auth token: ") + messageTokenEnd = []byte("\n") +) + +func (s *Secure) GenerateToken(w io.Writer) error { + buffer := make([]byte, bufferSize) + + // Doesn't let the buffer with a token escape to heap. + if _, err := noescape.Read(rand.Reader, buffer); err != nil { + return errors.Wrap(err, "generate token") + } + + sum := sha256.Sum256(buffer) + s.mux.Lock() + s.sum = sum + s.mux.Unlock() + + if _, err := w.Write(messageTokenBegin); err != nil { + return errors.Wrap(err, "fail to warn") + } + enc := base64.NewEncoder(base64.RawURLEncoding, w) + if _, err := noescape.Write(enc, buffer); err != nil { + enc.Close() + return errors.Wrap(err, "write auth token") + } + enc.Close() + if _, err := w.Write(messageTokenEnd); err != nil { + return errors.Wrap(err, "fail to warn") + } + + return nil +} + +func (s *Secure) VerifyToken(token string) bool { + decoded, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return false + } + + flag := false + sum := sha256.Sum256(decoded) + s.mux.Lock() + flag = sum == s.sum + s.mux.Unlock() + + return flag +} diff --git a/internal/server/handlers/apply_state_handler.go b/internal/server/handlers/apply_state_handler.go new file mode 100644 index 0000000..c72d6b9 --- /dev/null +++ b/internal/server/handlers/apply_state_handler.go @@ -0,0 +1,131 @@ +package handlers + +import ( + "encoding/json" + "mime" + "net/http" + "strconv" + + "github.com/pkg/errors" + "gopkg.in/yaml.v2" + + "github.com/ivan1993spb/snake-bot/internal/server/models" + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +type AppApplyState interface { + ApplyState(state map[int]int) (map[int]int, error) + SetupOne(gameId, botsNumber int) (map[int]int, error) +} + +func NewApplyStateHandler(app AppApplyState) http.HandlerFunc { + const ( + mediaTypeFormUrlencoded = "application/x-www-form-urlencoded" + mediaTypeJson = "application/json" + mediaTypeYaml = "text/yaml" + ) + + applyState := func(games *models.Games) (map[int]int, error, int) { + requested := games.ToMapState() + state, err := app.ApplyState(requested) + if err != nil { + return nil, errors.Wrap(err, "apply state fail"), + http.StatusInternalServerError + } + return state, nil, http.StatusCreated + } + + setupOne := func(gameId, bots int) (map[int]int, error, int) { + state, err := app.SetupOne(gameId, bots) + if err != nil { + return nil, errors.Wrap(err, "setup one fail"), + http.StatusInternalServerError + } + return state, err, http.StatusCreated + } + + handleFormUrlencoded := func(r *http.Request) (map[int]int, error, int) { + if err := r.ParseForm(); err != nil { + return nil, errors.Wrap(err, "parse form fail"), + http.StatusBadRequest + } + gameId, err := strconv.Atoi(r.PostForm.Get("game")) + if err != nil { + return nil, errors.Wrap(err, "parse game id"), + http.StatusBadRequest + } + bots, err := strconv.Atoi(r.PostForm.Get("bots")) + if err != nil { + return nil, errors.Wrap(err, "parse bots number"), + http.StatusBadRequest + } + return setupOne(gameId, bots) + } + + handleJson := func(r *http.Request) (map[int]int, error, int) { + var games *models.Games + if err := json.NewDecoder(r.Body).Decode(&games); err != nil { + return nil, errors.Wrap(err, "decode json fail"), + http.StatusBadRequest + } + return applyState(games) + } + + handleYaml := func(r *http.Request) (map[int]int, error, int) { + var games *models.Games + if err := yaml.NewDecoder(r.Body).Decode(&games); err != nil { + return nil, errors.Wrap(err, "decode yaml fail"), + http.StatusBadRequest + } + return applyState(games) + } + + handlersMapping := map[string]func(r *http.Request) (map[int]int, error, int){ + mediaTypeFormUrlencoded: handleFormUrlencoded, + mediaTypeJson: handleJson, + mediaTypeYaml: handleYaml, + } + + return func(w http.ResponseWriter, r *http.Request) { + contentType := r.Header.Get("Content-type") + acceptType := r.Header.Get("Accept") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + utils.Log(r.Context()).WithError(err).Error("parse media type") + w.WriteHeader(http.StatusBadRequest) + return + } + + handler, ok := handlersMapping[mediaType] + if !ok { + utils.Log(r.Context()).WithField( + "media_type", mediaType).Error("unknown media type") + w.WriteHeader(http.StatusBadRequest) + return + } + + state, err, statusCode := handler(r) + if err != nil { + utils.Log(r.Context()).WithError(err).Error("handle fail") + w.WriteHeader(statusCode) + return + } + + response := models.NewGames(state) + + if acceptType == mediaTypeYaml { + w.Header().Set("Content-Type", mediaTypeYaml) + w.WriteHeader(statusCode) + if err := yaml.NewEncoder(w).Encode(response); err != nil { + utils.Log(r.Context()).WithError(err).Error("write yaml response") + } + return + } + + w.Header().Set("Content-Type", mediaTypeJson) + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(response); err != nil { + utils.Log(r.Context()).WithError(err).Error("write json response") + } + } +} diff --git a/internal/server/handlers/openapi_handler.go b/internal/server/handlers/openapi_handler.go new file mode 100644 index 0000000..295dc7a --- /dev/null +++ b/internal/server/handlers/openapi_handler.go @@ -0,0 +1,17 @@ +package handlers + +import ( + "net/http" + + "github.com/ivan1993spb/snake-bot" + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +func OpenAPIHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/yaml; charset=utf-8") + w.WriteHeader(http.StatusOK) + + if _, err := w.Write([]byte(snakebot.OpenAPISpec)); err != nil { + utils.Log(r.Context()).WithError(err).Error("openapi handler fail") + } +} diff --git a/internal/server/handlers/read_state_handler.go b/internal/server/handlers/read_state_handler.go new file mode 100644 index 0000000..d2ee313 --- /dev/null +++ b/internal/server/handlers/read_state_handler.go @@ -0,0 +1,45 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "gopkg.in/yaml.v2" + + "github.com/ivan1993spb/snake-bot/internal/server/models" + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +type AppReadState interface { + ReadState() map[int]int +} + +func NewReadStateHandler(app AppReadState) http.HandlerFunc { + const ( + typeJson = "application/json" + typeYaml = "text/yaml" + ) + return func(w http.ResponseWriter, r *http.Request) { + accept := r.Header.Get("Accept") + state := app.ReadState() + response := models.NewGames(state) + + if accept == typeYaml { + w.Header().Set("Content-Type", typeYaml) + w.WriteHeader(http.StatusOK) + + err := yaml.NewEncoder(w).Encode(response) + if err != nil { + utils.Log(r.Context()).WithError(err).Error("write yaml response") + } + return + } + + w.Header().Set("Content-Type", typeJson) + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(response) + if err != nil { + utils.Log(r.Context()).WithError(err).Error("write json response") + } + } +} diff --git a/internal/server/handlers/welcome_handler.go b/internal/server/handlers/welcome_handler.go new file mode 100644 index 0000000..7026bf9 --- /dev/null +++ b/internal/server/handlers/welcome_handler.go @@ -0,0 +1,18 @@ +package handlers + +import ( + "net/http" + + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +var welcomeMessage = []byte(`Welcome to Snake-Bot!`) + +func WelcomeHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + + if _, err := w.Write(welcomeMessage); err != nil { + utils.Log(r.Context()).WithError(err).Error("welcome handler fail") + } +} diff --git a/internal/server/middlewares/logger.go b/internal/server/middlewares/logger.go new file mode 100644 index 0000000..d460c1a --- /dev/null +++ b/internal/server/middlewares/logger.go @@ -0,0 +1,68 @@ +package middlewares + +import ( + "fmt" + "net/http" + "time" + + "github.com/go-chi/chi/v5/middleware" + "github.com/sirupsen/logrus" +) + +func NewRequestLogger(logger *logrus.Logger) func(next http.Handler) http.Handler { + return middleware.RequestLogger(&RequestLogger{logger}) +} + +type RequestLogger struct { + Logger *logrus.Logger +} + +func (l *RequestLogger) NewLogEntry(r *http.Request) middleware.LogEntry { + entry := &StructuredLoggerEntry{Logger: logrus.NewEntry(l.Logger)} + logFields := logrus.Fields{} + + if reqID := middleware.GetReqID(r.Context()); reqID != "" { + logFields["req_id"] = reqID + } + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + logFields["http_scheme"] = scheme + logFields["http_proto"] = r.Proto + logFields["http_method"] = r.Method + + logFields["remote_addr"] = r.RemoteAddr + logFields["user_agent"] = r.UserAgent() + + logFields["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI) + + entry.Logger = entry.Logger.WithFields(logFields) + + entry.Logger.Infoln("request started") + + return entry +} + +type StructuredLoggerEntry struct { + Logger logrus.FieldLogger +} + +func (l *StructuredLoggerEntry) Write(status, bytes int, header http.Header, + elapsed time.Duration, extra interface{}) { + l.Logger = l.Logger.WithFields(logrus.Fields{ + "resp_status": status, + "resp_bytes_length": bytes, + "resp_elapsed_ms": float64(elapsed.Nanoseconds()) / 1000000.0, + }) + + l.Logger.Infoln("request complete") +} + +func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) { + l.Logger = l.Logger.WithFields(logrus.Fields{ + "stack": string(stack), + "panic": fmt.Sprintf("%+v", v), + }) +} diff --git a/internal/server/middlewares/token_auth.go b/internal/server/middlewares/token_auth.go new file mode 100644 index 0000000..c189823 --- /dev/null +++ b/internal/server/middlewares/token_auth.go @@ -0,0 +1,22 @@ +package middlewares + +import "net/http" + +const headerToken = "X-Snake-Bot-Token" + +type Secure interface { + VerifyToken(token string) bool +} + +func TokenAuth(sec Secure) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if token := r.Header.Get(headerToken); !sec.VerifyToken(token) { + w.WriteHeader(http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/server/models/error.go b/internal/server/models/error.go new file mode 100644 index 0000000..d782f8d --- /dev/null +++ b/internal/server/models/error.go @@ -0,0 +1,6 @@ +package models + +type Error struct { + Code int `json:"code",yaml:"code"` + Text string `json:"text",yaml:"text"` +} diff --git a/internal/server/models/game.go b/internal/server/models/game.go new file mode 100644 index 0000000..19f1a42 --- /dev/null +++ b/internal/server/models/game.go @@ -0,0 +1,6 @@ +package models + +type Game struct { + Game int `json:"game",yaml:"game"` + Bots int `json:"bots",yaml:"bots"` +} diff --git a/internal/server/models/games.go b/internal/server/models/games.go new file mode 100644 index 0000000..128265c --- /dev/null +++ b/internal/server/models/games.go @@ -0,0 +1,27 @@ +package models + +type Games struct { + Games []*Game `json:"games",yaml:"games"` +} + +func NewGames(state map[int]int) *Games { + g := &Games{} + g.Games = make([]*Game, 0, len(state)) + for game, bots := range state { + if bots > 0 { + g.Games = append(g.Games, &Game{ + Game: game, + Bots: bots, + }) + } + } + return g +} + +func (g *Games) ToMapState() map[int]int { + state := make(map[int]int, len(g.Games)) + for _, game := range g.Games { + state[game.Game] = game.Bots + } + return state +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..a7c2612 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,134 @@ +package server + +import ( + "context" + "net" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/ivan1993spb/snake-bot/internal/config" + "github.com/ivan1993spb/snake-bot/internal/server/handlers" + "github.com/ivan1993spb/snake-bot/internal/server/middlewares" + "github.com/ivan1993spb/snake-bot/internal/utils" +) + +type Core interface { + ApplyState(state map[int]int) (map[int]int, error) + SetupOne(gameId, botsNumber int) (map[int]int, error) + ReadState() map[int]int +} + +type Secure interface { + VerifyToken(token string) bool +} + +type Server struct { + ctx context.Context + + server *http.Server + + core Core + sec Secure +} + +func NewServer(ctx context.Context, cfg config.Server, appInfo string, + core Core, sec Secure) *Server { + s := &Server{ + server: &http.Server{ + Addr: cfg.Address, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + }, + + core: core, + sec: sec, + } + + s.initRoutes(ctx, cfg.Debug, cfg.ForbidCORS, appInfo) + + return s +} + +const requestPostBotsThrottleLimit = 1 + +func (s *Server) initRoutes(ctx context.Context, debug, forbidCORS bool, + appInfo string) { + r := chi.NewRouter() + + r.Use(middleware.Recoverer) + r.Use(middleware.RealIP) + r.Use(middleware.RequestID) + r.Use(middlewares.NewRequestLogger(utils.Log(ctx))) + r.Use(middleware.SetHeader("Server", appInfo)) + r.Use(middleware.GetHead) + // By default origins (domain, scheme, or port) don't matter. + if !forbidCORS { + r.Use(cors.AllowAll().Handler) + } + + r.Get("/", handlers.WelcomeHandler) + r.With(middleware.NoCache).Get("/openapi.yaml", handlers.OpenAPIHandler) + r.Route("/api/bots", func(r chi.Router) { + r.Use(middlewares.TokenAuth(s.sec)) + r.With( + middleware.AllowContentType( + "application/x-www-form-urlencoded", + "application/json", + "text/yaml", + ), + middleware.Throttle(requestPostBotsThrottleLimit), + ).Post("/", handlers.NewApplyStateHandler(s.core)) + r.Get("/", handlers.NewReadStateHandler(s.core)) + }) + + if debug { + r.Mount("/debug", middleware.Profiler()) + } + r.Handle("/metrics", promhttp.Handler()) + + s.server.Handler = r +} + +const serverShutdownTimeout = time.Second + +const fieldShutdownTimeout = "shutdown_timeout" + +func (s *Server) ListenAndServe(ctx context.Context) error { + go func() { + <-ctx.Done() + + log := utils.Log(ctx).WithField(fieldShutdownTimeout, serverShutdownTimeout) + log.Info("shutting down") + + shutdownCtx, cancel := context.WithTimeout( + context.Background(), + serverShutdownTimeout, + ) + defer cancel() + + go func() { + <-shutdownCtx.Done() + if shutdownCtx.Err() == context.DeadlineExceeded { + log.Error("graceful shutdown timed out") + log.Fatal("forcing exit") + } + }() + + if err := s.server.Shutdown(shutdownCtx); err != nil { + log.WithError(err).Error("server shutdown fail") + } + }() + + err := s.server.ListenAndServe() + if err == http.ErrServerClosed { + return nil + } + return errors.Wrap(err, "listen and serve") +} diff --git a/internal/types/types.go b/internal/types/types.go new file mode 100644 index 0000000..b175a7e --- /dev/null +++ b/internal/types/types.go @@ -0,0 +1,243 @@ +package types + +import ( + "bytes" + "encoding/json" + "fmt" + "strconv" + + "github.com/pkg/errors" +) + +type Direction string + +const ( + DirectionZero Direction = "zero" + DirectionNorth Direction = "north" + DirectionWest Direction = "west" + DirectionSouth Direction = "south" + DirectionEast Direction = "east" +) + +type MessageType string + +const ( + MessageTypeGameEvent MessageType = "game" + MessageTypePlayer MessageType = "player" + MessageTypeBroadcast MessageType = "broadcast" +) + +type Message struct { + Type MessageType `json:"type"` + Payload json.RawMessage `json:"payload"` +} + +type GameEventType string + +const ( + GameEventTypeError GameEventType = "error" + GameEventTypeCreate GameEventType = "create" + GameEventTypeDelete GameEventType = "delete" + GameEventTypeUpdate GameEventType = "update" + GameEventTypeChecked GameEventType = "checked" +) + +type GameEvent struct { + Type GameEventType `json:"type"` + Payload json.RawMessage `json:"payload"` +} + +type PlayerEventType string + +const ( + PlayerEventTypeSize PlayerEventType = "size" + PlayerEventTypeSnake PlayerEventType = "snake" + PlayerEventTypeNotice PlayerEventType = "notice" + PlayerEventTypeError PlayerEventType = "error" + PlayerEventTypeCountdown PlayerEventType = "countdown" + PlayerEventTypeObjects PlayerEventType = "objects" +) + +type PlayerEvent struct { + Type PlayerEventType `json:"type"` + Payload json.RawMessage `json:"payload"` +} + +type Dot struct { + X uint8 + Y uint8 +} + +func (d Dot) String() string { + return fmt.Sprintf("[%d, %d]", d.X, d.Y) +} + +const ( + dotFieldX = iota + dotFieldY +) + +const errDotUnmarshalJSONAnnotation = "cannot decode the dot" + +var ( + ErrDotUnmarshalJSONInsufficientInput = errors.Errorf("%s: %s", + errDotUnmarshalJSONAnnotation, "insufficient input") + ErrDotUnmarshalJSONInvalidStructure = errors.Errorf("%s: %s", + errDotUnmarshalJSONAnnotation, "invalid structure") +) + +const ( + dotTokenOpen = '[' + dotTokenClose = ']' +) + +func (d *Dot) UnmarshalJSON(b []byte) error { + if len(b) < 2 { + return ErrDotUnmarshalJSONInsufficientInput + } + var op, cl byte + op, b, cl = b[0], b[1:len(b)-1], b[len(b)-1] + if op != dotTokenOpen || cl != dotTokenClose { + return ErrDotUnmarshalJSONInvalidStructure + } + axes := bytes.SplitN(b, []byte{','}, 2) + if len(axes) != 2 { + return ErrDotUnmarshalJSONInvalidStructure + } + x, err := strconv.Atoi(string(axes[dotFieldX])) + if err != nil { + return errors.Wrap(err, errDotUnmarshalJSONAnnotation) + } + y, err := strconv.Atoi(string(axes[dotFieldY])) + if err != nil { + return errors.Wrap(err, errDotUnmarshalJSONAnnotation) + } + *d = Dot{ + X: uint8(x), + Y: uint8(y), + } + return nil +} + +type ObjectType uint8 + +const ( + ObjectTypeUnknown ObjectType = iota + ObjectTypeSnake + ObjectTypeApple + ObjectTypeCorpse + ObjectTypeMouse + ObjectTypeWatermelon + ObjectTypeWall +) + +var ( + ObjectTypeTokenSnake = []byte(`"snake"`) + ObjectTypeTokenApple = []byte(`"apple"`) + ObjectTypeTokenCorpse = []byte(`"corpse"`) + ObjectTypeTokenMouse = []byte(`"mouse"`) + ObjectTypeTokenWatermelon = []byte(`"watermelon"`) + ObjectTypeTokenWall = []byte(`"wall"`) +) + +func (t *ObjectType) UnmarshalJSON(b []byte) error { + switch { + case bytes.Equal(b, ObjectTypeTokenSnake): + *t = ObjectTypeSnake + case bytes.Equal(b, ObjectTypeTokenWatermelon): + *t = ObjectTypeWatermelon + case bytes.Equal(b, ObjectTypeTokenCorpse): + *t = ObjectTypeCorpse + case bytes.Equal(b, ObjectTypeTokenWall): + *t = ObjectTypeWall + case bytes.Equal(b, ObjectTypeTokenMouse): + *t = ObjectTypeMouse + case bytes.Equal(b, ObjectTypeTokenApple): + *t = ObjectTypeApple + default: + *t = ObjectTypeUnknown + } + return nil +} + +type Object struct { + Type ObjectType `json:"type"` + Id uint32 `json:"id"` + Dot Dot `json:"dot"` + Dots []Dot `json:"dots"` + Direction Direction `json:"direction"` +} + +func (o *Object) GetType() ObjectType { + return o.Type +} + +func (o *Object) GetDots() []Dot { + switch o.Type { + case ObjectTypeMouse, ObjectTypeApple: + return []Dot{o.Dot} + case ObjectTypeSnake, ObjectTypeCorpse, ObjectTypeWatermelon, + ObjectTypeWall: + return o.Dots + } + return []Dot{} +} + +type Size struct { + Width uint8 `json:"width"` + Height uint8 `json:"height"` +} + +type MessageSnakeCommand int8 + +const ( + MessageSnakeCommandNorth = iota + MessageSnakeCommandSouth + MessageSnakeCommandWest + MessageSnakeCommandEast +) + +const ( + messageSnakeCommandLabelNorth = "north" + messageSnakeCommandLabelSouth = "south" + messageSnakeCommandLabelWest = "west" + messageSnakeCommandLabelEast = "east" +) + +var messageSnakeCommandLabelMapping = map[MessageSnakeCommand]string{ + MessageSnakeCommandNorth: messageSnakeCommandLabelNorth, + MessageSnakeCommandSouth: messageSnakeCommandLabelSouth, + MessageSnakeCommandWest: messageSnakeCommandLabelWest, + MessageSnakeCommandEast: messageSnakeCommandLabelEast, +} + +func (m MessageSnakeCommand) String() string { + if label, ok := messageSnakeCommandLabelMapping[m]; ok { + return label + } + return "unknown" +} + +const messageSnakeCommandBuffer = 64 + +func (m MessageSnakeCommand) MarshalJSON() ([]byte, error) { + b := bytes.NewBuffer(make([]byte, 0, messageSnakeCommandBuffer)) + b.Write([]byte(`{"type":"snake","payload":"`)) + b.WriteString(m.String()) + b.Write([]byte(`"}`)) + return b.Bytes(), nil +} + +var directionMessageSnakeCommandMapping = map[Direction]MessageSnakeCommand{ + DirectionNorth: MessageSnakeCommandNorth, + DirectionSouth: MessageSnakeCommandSouth, + DirectionEast: MessageSnakeCommandEast, + DirectionWest: MessageSnakeCommandWest, +} + +func (d Direction) ToMessageSnakeCommand() MessageSnakeCommand { + if command, ok := directionMessageSnakeCommandMapping[d]; ok { + return command + } + return -1 +} diff --git a/internal/types/types_test.go b/internal/types/types_test.go new file mode 100644 index 0000000..680f5a6 --- /dev/null +++ b/internal/types/types_test.go @@ -0,0 +1,165 @@ +package types + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Dot_UnmarshalJSON(t *testing.T) { + tests := []struct { + input []byte + expect Dot + err bool + }{ + { + input: []byte("1"), + expect: Dot{}, + err: true, + }, + { + input: []byte("11"), + expect: Dot{}, + err: true, + }, + { + input: []byte("[2,3]"), + expect: Dot{ + X: 2, + Y: 3, + }, + err: false, + }, + { + input: []byte("[255,255]"), + expect: Dot{ + X: 255, + Y: 255, + }, + err: false, + }, + { + input: nil, + expect: Dot{}, + err: true, + }, + { + input: []byte("sadsaasdsasad"), + expect: Dot{}, + err: true, + }, + { + input: []byte("[23]"), + expect: Dot{}, + err: true, + }, + { + input: []byte("[23,32,54]"), + expect: Dot{}, + err: true, + }, + { + input: []byte("[23, x]"), + expect: Dot{}, + err: true, + }, + { + input: []byte("[y, x]"), + expect: Dot{}, + err: true, + }, + { + input: []byte(`[123,"y"]`), + expect: Dot{}, + err: true, + }, + { + input: []byte(`["y","x"]`), + expect: Dot{}, + err: true, + }, + { + input: []byte(`"[y,x]"`), + expect: Dot{}, + err: true, + }, + { + input: []byte("[y,x]"), + expect: Dot{}, + err: true, + }, + } + + for i, test := range tests { + var d Dot + err := json.Unmarshal(test.input, &d) + if test.err { + require.NotNil(t, err, "Failed test number %d", i+1) + } else { + require.Nil(t, err, "Failed test number %d", i+1) + } + require.Equal(t, test.expect, d, "Failed test number %d", i+1) + } +} + +func Test_ObjectType_UnmarshalJSON(t *testing.T) { + const msgFormat = "Failed test number %d" + + tests := []struct { + input []byte + expected ObjectType + err bool + }{ + { + input: []byte(`"snake"`), + expected: ObjectTypeSnake, + err: false, + }, + { + input: []byte(`"watermelon"`), + expected: ObjectTypeWatermelon, + err: false, + }, + { + input: []byte(`"corpse"`), + expected: ObjectTypeCorpse, + err: false, + }, + { + input: []byte(`"wall"`), + expected: ObjectTypeWall, + err: false, + }, + { + input: []byte(`"mouse"`), + expected: ObjectTypeMouse, + err: false, + }, + { + input: []byte(`"apple"`), + expected: ObjectTypeApple, + err: false, + }, + { + input: []byte(`invalid value`), + err: true, + }, + { + input: []byte(`"unknown"`), + expected: ObjectTypeUnknown, + err: false, + }, + } + + for i, test := range tests { + var actualObjectType ObjectType + err := json.Unmarshal(test.input, &actualObjectType) + if test.err { + require.NotNil(t, err, msgFormat, i+1) + } else { + require.Nil(t, err, msgFormat, i+1) + } + require.Equal(t, test.expected, actualObjectType, msgFormat, i+1) + } +} diff --git a/internal/utils/client_name.go b/internal/utils/client_name.go new file mode 100644 index 0000000..af7ee72 --- /dev/null +++ b/internal/utils/client_name.go @@ -0,0 +1,7 @@ +package utils + +import "fmt" + +func FormatAppInfoHeader(name, version, build string) string { + return fmt.Sprintf("%s/%s (build %s)", name, version, build) +} diff --git a/internal/utils/context.go b/internal/utils/context.go new file mode 100644 index 0000000..747db82 --- /dev/null +++ b/internal/utils/context.go @@ -0,0 +1,24 @@ +package utils + +import ( + "context" + + "github.com/sirupsen/logrus" +) + +type contextKey struct{} + +var loggerKey = contextKey{} + +func LogContext(ctx context.Context, logger *logrus.Logger) context.Context { + return context.WithValue(ctx, loggerKey, logger) +} + +var defaultLog = logrus.New() + +func Log(ctx context.Context) *logrus.Logger { + if logger, ok := ctx.Value(loggerKey).(*logrus.Logger); ok { + return logger + } + return defaultLog +} diff --git a/internal/utils/logger.go b/internal/utils/logger.go new file mode 100644 index 0000000..df2d95a --- /dev/null +++ b/internal/utils/logger.go @@ -0,0 +1,34 @@ +package utils + +import ( + "time" + + "github.com/sirupsen/logrus" + + "github.com/ivan1993spb/snake-bot/internal/config" +) + +const ( + loggerDisableColors = true + loggerTimestampFormat = time.RFC1123 +) + +func NewLogger(cfg config.Log) *logrus.Logger { + logger := logrus.New() + if cfg.EnableJSON { + logger.Formatter = &logrus.JSONFormatter{ + TimestampFormat: loggerTimestampFormat, + } + } else { + logger.Formatter = &logrus.TextFormatter{ + DisableColors: loggerDisableColors, + TimestampFormat: loggerTimestampFormat, + } + } + if level, err := logrus.ParseLevel(cfg.Level); err != nil { + logger.SetLevel(logrus.InfoLevel) + } else { + logger.SetLevel(level) + } + return logger +} diff --git a/internal/utils/printer.go b/internal/utils/printer.go new file mode 100644 index 0000000..2a96ac3 --- /dev/null +++ b/internal/utils/printer.go @@ -0,0 +1,28 @@ +package utils + +import "github.com/sirupsen/logrus" + +type PrinterLog struct { + log logrus.FieldLogger +} + +const printerGameKey = "print_game" + +func NewPrinterLog(log logrus.FieldLogger, gameId int) *PrinterLog { + return &PrinterLog{ + log: log.WithField(printerGameKey, gameId), + } +} + +func (p *PrinterLog) Print(level, what, message string) { + p.log.WithFields(logrus.Fields{ + "level": level, + "what": what, + }).Debug(message) +} + +type PrinterDevNull struct{} + +func (p PrinterDevNull) Print(string, string, string) { + // Do nothing +}