From f6fcb5782b9fb74819bd9fc8f7ea8af4b740d284 Mon Sep 17 00:00:00 2001 From: joil Date: Thu, 21 Sep 2023 17:09:22 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E5=9F=BA?= =?UTF-8?q?=E4=BA=8E=20Redis=20=E7=9A=84=E6=BB=91=E5=8A=A8=E7=AA=97?= =?UTF-8?q?=E5=8F=A3=E9=99=90=E6=B5=81=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + Makefile | 4 + go.mod | 41 +++ go.sum | 98 +++++++ internal/integration/ratelimit_test.go | 99 +++++++ internal/ratelimit/mocks/ratelimit.mock.go | 50 ++++ internal/ratelimit/redis_slide_window.go | 38 +++ internal/ratelimit/redis_slide_window_test.go | 64 +++++ internal/ratelimit/slide_window.lua | 22 ++ internal/ratelimit/types.go | 10 + middlewares/ratelimit/builder.go | 55 ++++ middlewares/ratelimit/builder_test.go | 242 ++++++++++++++++++ 12 files changed, 725 insertions(+) create mode 100644 Makefile create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/integration/ratelimit_test.go create mode 100644 internal/ratelimit/mocks/ratelimit.mock.go create mode 100644 internal/ratelimit/redis_slide_window.go create mode 100644 internal/ratelimit/redis_slide_window_test.go create mode 100644 internal/ratelimit/slide_window.lua create mode 100644 internal/ratelimit/types.go create mode 100644 middlewares/ratelimit/builder.go create mode 100644 middlewares/ratelimit/builder_test.go diff --git a/.gitignore b/.gitignore index 3b735ec..82c4ca1 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,5 @@ # Go workspace file go.work + +.idea \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c3c37d9 --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ +.PHONY: mock +mock: + @mockgen -source=internal/ratelimit/types.go -package=limitmocks -destination=internal/ratelimit/mocks/ratelimit.mock.go + @go mod tidy \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..379257d --- /dev/null +++ b/go.mod @@ -0,0 +1,41 @@ +module github.com/ecodeclub/ginx + +go 1.21.1 + +require ( + github.com/gin-gonic/gin v1.9.1 + github.com/redis/go-redis/v9 v9.1.0 + github.com/stretchr/testify v1.8.3 + go.uber.org/mock v0.3.0 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/crypto v0.9.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6376821 --- /dev/null +++ b/go.sum @@ -0,0 +1,98 @@ +github.com/bsm/ginkgo/v2 v2.9.5 h1:rtVBYPs3+TC5iLUVOis1B9tjLTup7Cj5IfzosKtvTJ0= +github.com/bsm/ginkgo/v2 v2.9.5/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= +github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.1.0 h1:137FnGdk+EQdCbye1FW+qOEcY5S+SpY9T0NiuqvtfMY= +github.com/redis/go-redis/v9 v9.1.0/go.mod h1:urWj3He21Dj5k4TK1y59xH8Uj6ATueP8AH1cY3lZl4c= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/integration/ratelimit_test.go b/internal/integration/ratelimit_test.go new file mode 100644 index 0000000..c53f646 --- /dev/null +++ b/internal/integration/ratelimit_test.go @@ -0,0 +1,99 @@ +//go:build e2e + +package integration + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + + "github.com/ecodeclub/ginx/internal/ratelimit" + limit "github.com/ecodeclub/ginx/middlewares/ratelimit" +) + +func TestBuilder_e2e_RateLimit(t *testing.T) { + const ( + ip = "127.0.0.1" + limitURL = "/limit" + ) + rdb := initRedis() + server := initWebServer(rdb) + RegisterRoutes(server) + + tests := []struct { + // 名字 + name string + // 要提前准备数据 + before func(t *testing.T) + // 验证并且删除数据 + after func(t *testing.T) + + // 预期响应 + wantCode int + }{ + { + name: "不限流", + before: func(t *testing.T) {}, + after: func(t *testing.T) { + rdb.Del(context.Background(), fmt.Sprintf("ip-limiter:%s", ip)) + }, + wantCode: http.StatusOK, + }, + { + name: "限流", + before: func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, limitURL, nil) + req.RemoteAddr = ip + ":80" + assert.NoError(t, err) + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, req) + }, + after: func(t *testing.T) { + rdb.Del(context.Background(), fmt.Sprintf("ip-limiter:%s", ip)) + }, + wantCode: http.StatusTooManyRequests, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer tt.after(t) + tt.before(t) + req, err := http.NewRequest(http.MethodGet, limitURL, nil) + req.RemoteAddr = ip + ":80" + assert.NoError(t, err) + + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, req) + + code := recorder.Code + assert.Equal(t, tt.wantCode, code) + }) + } +} + +func RegisterRoutes(server *gin.Engine) { + server.GET("/limit", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) +} + +func initRedis() redis.Cmdable { + redisClient := redis.NewClient(&redis.Options{ + Addr: "localhost:16379", + }) + return redisClient +} + +func initWebServer(cmd redis.Cmdable) *gin.Engine { + server := gin.Default() + limiter := ratelimit.NewRedisSlidingWindowLimiter(cmd, 500*time.Millisecond, 1) + server.Use(limit.NewBuilder(limiter).Build()) + return server +} diff --git a/internal/ratelimit/mocks/ratelimit.mock.go b/internal/ratelimit/mocks/ratelimit.mock.go new file mode 100644 index 0000000..5a34ce8 --- /dev/null +++ b/internal/ratelimit/mocks/ratelimit.mock.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: internal/ratelimit/types.go + +// Package limitmocks is a generated GoMock package. +package limitmocks + +import ( + "context" + "reflect" + + "go.uber.org/mock/gomock" +) + +// MockLimiter is a mock of Limiter interface. +type MockLimiter struct { + ctrl *gomock.Controller + recorder *MockLimiterMockRecorder +} + +// MockLimiterMockRecorder is the mock recorder for MockLimiter. +type MockLimiterMockRecorder struct { + mock *MockLimiter +} + +// NewMockLimiter creates a new mock instance. +func NewMockLimiter(ctrl *gomock.Controller) *MockLimiter { + mock := &MockLimiter{ctrl: ctrl} + mock.recorder = &MockLimiterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder { + return m.recorder +} + +// Limit mocks base method. +func (m *MockLimiter) Limit(ctx context.Context, key string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Limit", ctx, key) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Limit indicates an expected call of Limit. +func (mr *MockLimiterMockRecorder) Limit(ctx, key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Limit", reflect.TypeOf((*MockLimiter)(nil).Limit), ctx, key) +} diff --git a/internal/ratelimit/redis_slide_window.go b/internal/ratelimit/redis_slide_window.go new file mode 100644 index 0000000..ff2396e --- /dev/null +++ b/internal/ratelimit/redis_slide_window.go @@ -0,0 +1,38 @@ +package ratelimit + +import ( + "context" + _ "embed" + "time" + + "github.com/redis/go-redis/v9" +) + +//go:embed slide_window.lua +var luaSlideWindow string + +// RedisSlidingWindowLimiter Redis 上的滑动窗口算法限流器实现 +type RedisSlidingWindowLimiter struct { + cmd redis.Cmdable + + // 窗口大小 + interval time.Duration + // 阈值 + rate int + // interval 内允许 rate 个请求 + // 1s 内允许 3000 个请求 +} + +func NewRedisSlidingWindowLimiter(cmd redis.Cmdable, + interval time.Duration, rate int) Limiter { + return &RedisSlidingWindowLimiter{ + cmd: cmd, + interval: interval, + rate: rate, + } +} + +func (r *RedisSlidingWindowLimiter) Limit(ctx context.Context, key string) (bool, error) { + return r.cmd.Eval(ctx, luaSlideWindow, []string{key}, + r.interval.Milliseconds(), r.rate, time.Now().UnixMilli()).Bool() +} diff --git a/internal/ratelimit/redis_slide_window_test.go b/internal/ratelimit/redis_slide_window_test.go new file mode 100644 index 0000000..2b4afde --- /dev/null +++ b/internal/ratelimit/redis_slide_window_test.go @@ -0,0 +1,64 @@ +package ratelimit + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +func TestRedisSlidingWindowLimiter_Limit(t *testing.T) { + r := NewRedisSlidingWindowLimiter(initRedis(), 500*time.Millisecond, 1) + tests := []struct { + name string + ctx context.Context + key string + interval time.Duration + want bool + wantErr error + }{ + { + name: "正常通过", + ctx: context.Background(), + key: "foo", + want: false, + }, + { + name: "另外一个key正常通过", + ctx: context.Background(), + key: "bar", + want: false, + }, + { + name: "限流", + ctx: context.Background(), + key: "foo", + interval: 300 * time.Millisecond, + want: true, + }, + { + name: "窗口有空余正常通过", + ctx: context.Background(), + key: "foo", + interval: 510 * time.Millisecond, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + <-time.After(tt.interval) + got, err := r.Limit(tt.ctx, tt.key) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func initRedis() redis.Cmdable { + redisClient := redis.NewClient(&redis.Options{ + Addr: "localhost:16379", + }) + return redisClient +} diff --git a/internal/ratelimit/slide_window.lua b/internal/ratelimit/slide_window.lua new file mode 100644 index 0000000..522f1be --- /dev/null +++ b/internal/ratelimit/slide_window.lua @@ -0,0 +1,22 @@ +-- 限流对象 +local key = KEYS[1] +-- 窗口大小 +local window = tonumber(ARGV[1]) +-- 阈值 +local threshold = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +-- 窗口的起始时间 +local min = now - window + +redis.call('ZREMRANGEBYSCORE', key, '-inf', min) +local cnt = redis.call('ZCOUNT', key, '-inf', '+inf') +-- local cnt = redis.call('ZCOUNT', key, min, '+inf') +if cnt >= threshold then + -- 执行限流 + return "true" +else + -- 把 score 和 member 都设置成 now + redis.call('ZADD', key, now, now) + redis.call('PEXPIRE', key, window) + return "false" +end \ No newline at end of file diff --git a/internal/ratelimit/types.go b/internal/ratelimit/types.go new file mode 100644 index 0000000..8a43df5 --- /dev/null +++ b/internal/ratelimit/types.go @@ -0,0 +1,10 @@ +package ratelimit + +import "context" + +type Limiter interface { + // Limit 有没有触发限流。key 就是限流对象 + // bool 代表是否限流,true 就是要限流 + // err 限流器本身有咩有错误 + Limit(ctx context.Context, key string) (bool, error) +} diff --git a/middlewares/ratelimit/builder.go b/middlewares/ratelimit/builder.go new file mode 100644 index 0000000..39f515c --- /dev/null +++ b/middlewares/ratelimit/builder.go @@ -0,0 +1,55 @@ +package ratelimit + +import ( + "log" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + + "github.com/ecodeclub/ginx/internal/ratelimit" +) + +type Builder struct { + limiter ratelimit.Limiter + genKeyFn func(*gin.Context) string +} + +func NewBuilder(limiter ratelimit.Limiter) *Builder { + return &Builder{ + limiter: limiter, + genKeyFn: func(ctx *gin.Context) string { + var b strings.Builder + b.WriteString("ip-limiter") + b.WriteString(":") + b.WriteString(ctx.ClientIP()) + return b.String() + }, + } +} + +func (b *Builder) SetKey(fn func(*gin.Context) string) *Builder { + b.genKeyFn = fn + return b +} + +func (b *Builder) Build() gin.HandlerFunc { + return func(ctx *gin.Context) { + limited, err := b.limit(ctx) + if err != nil { + log.Println(err) + ctx.AbortWithStatus(http.StatusInternalServerError) + return + } + if limited { + log.Println(err) + ctx.AbortWithStatus(http.StatusTooManyRequests) + return + } + ctx.Next() + } +} + +func (b *Builder) limit(ctx *gin.Context) (bool, error) { + return b.limiter.Limit(ctx, b.genKeyFn(ctx)) +} diff --git a/middlewares/ratelimit/builder_test.go b/middlewares/ratelimit/builder_test.go new file mode 100644 index 0000000..03a0fb8 --- /dev/null +++ b/middlewares/ratelimit/builder_test.go @@ -0,0 +1,242 @@ +package ratelimit + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/ecodeclub/ginx/internal/ratelimit" + "github.com/ecodeclub/ginx/internal/ratelimit/mocks" +) + +func TestBuilder_SetKey(t *testing.T) { + tests := []struct { + name string + reqBuilder func(t *testing.T) *http.Request + fn func(*gin.Context) string + want string + }{ + { + name: "设置key成功", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + req.RemoteAddr = "127.0.0.1:80" + return req + }, + fn: func(ctx *gin.Context) string { + return "test" + }, + want: "test", + }, + { + name: "默认key", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + req.RemoteAddr = "127.0.0.1:80" + return req + }, + want: "ip-limiter:127.0.0.1", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := NewBuilder(nil) + if tt.fn != nil { + b.SetKey(tt.fn) + } + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := tt.reqBuilder(t) + ctx.Request = req + + assert.Equal(t, tt.want, b.genKeyFn(ctx)) + }) + } +} + +func TestBuilder_Build(t *testing.T) { + const limitURL = "/limit" + tests := []struct { + name string + + mock func(ctrl *gomock.Controller) ratelimit.Limiter + reqBuilder func(t *testing.T) *http.Request + + // 预期响应 + wantCode int + }{ + { + name: "不限流", + mock: func(ctrl *gomock.Controller) ratelimit.Limiter { + limiter := limitmocks.NewMockLimiter(ctrl) + limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). + Return(false, nil) + return limiter + }, + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, limitURL, nil) + if err != nil { + t.Fatal(err) + } + return req + }, + wantCode: http.StatusOK, + }, + { + name: "限流", + mock: func(ctrl *gomock.Controller) ratelimit.Limiter { + limiter := limitmocks.NewMockLimiter(ctrl) + limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). + Return(true, nil) + return limiter + }, + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, limitURL, nil) + if err != nil { + t.Fatal(err) + } + return req + }, + wantCode: http.StatusTooManyRequests, + }, + { + name: "系统错误", + mock: func(ctrl *gomock.Controller) ratelimit.Limiter { + limiter := limitmocks.NewMockLimiter(ctrl) + limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). + Return(false, errors.New("模拟系统错误")) + return limiter + }, + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, limitURL, nil) + if err != nil { + t.Fatal(err) + } + return req + }, + wantCode: http.StatusInternalServerError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + svc := NewBuilder(tt.mock(ctrl)) + + server := gin.Default() + server.Use(svc.Build()) + svc.RegisterRoutes(server) + + req := tt.reqBuilder(t) + recorder := httptest.NewRecorder() + + server.ServeHTTP(recorder, req) + + assert.Equal(t, tt.wantCode, recorder.Code) + }) + } +} + +func TestBuilder_limit(t *testing.T) { + tests := []struct { + name string + + mock func(ctrl *gomock.Controller) ratelimit.Limiter + reqBuilder func(t *testing.T) *http.Request + + // 预期响应 + want bool + wantErr error + }{ + { + name: "不限流", + mock: func(ctrl *gomock.Controller) ratelimit.Limiter { + limiter := limitmocks.NewMockLimiter(ctrl) + limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). + Return(false, nil) + return limiter + }, + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + req.RemoteAddr = "127.0.0.1:80" + return req + }, + want: false, + }, + { + name: "限流", + mock: func(ctrl *gomock.Controller) ratelimit.Limiter { + limiter := limitmocks.NewMockLimiter(ctrl) + limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). + Return(true, nil) + return limiter + }, + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + req.RemoteAddr = "127.0.0.1:80" + return req + }, + want: true, + }, + { + name: "限流代码出错", + mock: func(ctrl *gomock.Controller) ratelimit.Limiter { + limiter := limitmocks.NewMockLimiter(ctrl) + limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). + Return(false, errors.New("模拟系统错误")) + return limiter + }, + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + req.RemoteAddr = "127.0.0.1:80" + return req + }, + want: false, + wantErr: errors.New("模拟系统错误"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + limiter := tt.mock(ctrl) + b := NewBuilder(limiter) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := tt.reqBuilder(t) + ctx.Request = req + + got, err := b.limit(ctx) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func (b *Builder) RegisterRoutes(server *gin.Engine) { + server.GET("/limit", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) +} From e8ef0fe10124fcad1fdfa302f9840eda8727ab6a Mon Sep 17 00:00:00 2001 From: joil Date: Thu, 21 Sep 2023 22:09:18 +0800 Subject: [PATCH 2/2] =?UTF-8?q?feat:=20=E6=9A=B4=E9=9C=B2=20Limiter=20?= =?UTF-8?q?=E7=9A=84=E5=88=9D=E5=A7=8B=E5=8C=96=E5=87=BD=E6=95=B0=EF=BC=8C?= =?UTF-8?q?=E6=8F=90=E4=BE=9B=E8=87=AA=E5=AE=9A=E4=B9=89=E6=97=A5=E5=BF=97?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/integration/ratelimit_test.go | 3 +-- internal/ratelimit/redis_slide_window.go | 21 +++++----------- internal/ratelimit/redis_slide_window_test.go | 6 ++++- internal/ratelimit/types.go | 2 +- middlewares/ratelimit/builder.go | 22 +++++++++++++---- middlewares/ratelimit/builder_test.go | 4 ++-- middlewares/ratelimit/redis_slide_window.go | 24 +++++++++++++++++++ 7 files changed, 57 insertions(+), 25 deletions(-) create mode 100644 middlewares/ratelimit/redis_slide_window.go diff --git a/internal/integration/ratelimit_test.go b/internal/integration/ratelimit_test.go index c53f646..a68f286 100644 --- a/internal/integration/ratelimit_test.go +++ b/internal/integration/ratelimit_test.go @@ -14,7 +14,6 @@ import ( "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" - "github.com/ecodeclub/ginx/internal/ratelimit" limit "github.com/ecodeclub/ginx/middlewares/ratelimit" ) @@ -93,7 +92,7 @@ func initRedis() redis.Cmdable { func initWebServer(cmd redis.Cmdable) *gin.Engine { server := gin.Default() - limiter := ratelimit.NewRedisSlidingWindowLimiter(cmd, 500*time.Millisecond, 1) + limiter := limit.NewRedisSlidingWindowLimiter(cmd, 500*time.Millisecond, 1) server.Use(limit.NewBuilder(limiter).Build()) return server } diff --git a/internal/ratelimit/redis_slide_window.go b/internal/ratelimit/redis_slide_window.go index ff2396e..cd00441 100644 --- a/internal/ratelimit/redis_slide_window.go +++ b/internal/ratelimit/redis_slide_window.go @@ -13,26 +13,17 @@ var luaSlideWindow string // RedisSlidingWindowLimiter Redis 上的滑动窗口算法限流器实现 type RedisSlidingWindowLimiter struct { - cmd redis.Cmdable + Cmd redis.Cmdable // 窗口大小 - interval time.Duration + Interval time.Duration // 阈值 - rate int - // interval 内允许 rate 个请求 + Rate int + // Interval 内允许 Rate 个请求 // 1s 内允许 3000 个请求 } -func NewRedisSlidingWindowLimiter(cmd redis.Cmdable, - interval time.Duration, rate int) Limiter { - return &RedisSlidingWindowLimiter{ - cmd: cmd, - interval: interval, - rate: rate, - } -} - func (r *RedisSlidingWindowLimiter) Limit(ctx context.Context, key string) (bool, error) { - return r.cmd.Eval(ctx, luaSlideWindow, []string{key}, - r.interval.Milliseconds(), r.rate, time.Now().UnixMilli()).Bool() + return r.Cmd.Eval(ctx, luaSlideWindow, []string{key}, + r.Interval.Milliseconds(), r.Rate, time.Now().UnixMilli()).Bool() } diff --git a/internal/ratelimit/redis_slide_window_test.go b/internal/ratelimit/redis_slide_window_test.go index 2b4afde..1c4fb98 100644 --- a/internal/ratelimit/redis_slide_window_test.go +++ b/internal/ratelimit/redis_slide_window_test.go @@ -10,7 +10,11 @@ import ( ) func TestRedisSlidingWindowLimiter_Limit(t *testing.T) { - r := NewRedisSlidingWindowLimiter(initRedis(), 500*time.Millisecond, 1) + r := &RedisSlidingWindowLimiter{ + Cmd: initRedis(), + Interval: 500 * time.Millisecond, + Rate: 1, + } tests := []struct { name string ctx context.Context diff --git a/internal/ratelimit/types.go b/internal/ratelimit/types.go index 8a43df5..8713bba 100644 --- a/internal/ratelimit/types.go +++ b/internal/ratelimit/types.go @@ -5,6 +5,6 @@ import "context" type Limiter interface { // Limit 有没有触发限流。key 就是限流对象 // bool 代表是否限流,true 就是要限流 - // err 限流器本身有咩有错误 + // err 限流器本身有没有错误 Limit(ctx context.Context, key string) (bool, error) } diff --git a/middlewares/ratelimit/builder.go b/middlewares/ratelimit/builder.go index 39f515c..3707016 100644 --- a/middlewares/ratelimit/builder.go +++ b/middlewares/ratelimit/builder.go @@ -12,9 +12,13 @@ import ( type Builder struct { limiter ratelimit.Limiter - genKeyFn func(*gin.Context) string + genKeyFn func(ctx *gin.Context) string + logFn func(msg any, args ...any) } +// NewBuilder +// genKeyFn: 默认使用 IP 限流. +// logFn: 默认使用 log.Println(). func NewBuilder(limiter ratelimit.Limiter) *Builder { return &Builder{ limiter: limiter, @@ -25,24 +29,34 @@ func NewBuilder(limiter ratelimit.Limiter) *Builder { b.WriteString(ctx.ClientIP()) return b.String() }, + logFn: func(msg any, args ...any) { + v := make([]any, 0, len(args)+1) + v = append(v, msg) + v = append(v, args...) + log.Println(v...) + }, } } -func (b *Builder) SetKey(fn func(*gin.Context) string) *Builder { +func (b *Builder) SetKeyGenFunc(fn func(*gin.Context) string) *Builder { b.genKeyFn = fn return b } +func (b *Builder) SetLogFunc(fn func(msg any, args ...any)) *Builder { + b.logFn = fn + return b +} + func (b *Builder) Build() gin.HandlerFunc { return func(ctx *gin.Context) { limited, err := b.limit(ctx) if err != nil { - log.Println(err) + b.logFn(err) ctx.AbortWithStatus(http.StatusInternalServerError) return } if limited { - log.Println(err) ctx.AbortWithStatus(http.StatusTooManyRequests) return } diff --git a/middlewares/ratelimit/builder_test.go b/middlewares/ratelimit/builder_test.go index 03a0fb8..471494f 100644 --- a/middlewares/ratelimit/builder_test.go +++ b/middlewares/ratelimit/builder_test.go @@ -14,7 +14,7 @@ import ( "github.com/ecodeclub/ginx/internal/ratelimit/mocks" ) -func TestBuilder_SetKey(t *testing.T) { +func TestBuilder_SetKeyGenFunc(t *testing.T) { tests := []struct { name string reqBuilder func(t *testing.T) *http.Request @@ -53,7 +53,7 @@ func TestBuilder_SetKey(t *testing.T) { t.Run(tt.name, func(t *testing.T) { b := NewBuilder(nil) if tt.fn != nil { - b.SetKey(tt.fn) + b.SetKeyGenFunc(tt.fn) } recorder := httptest.NewRecorder() diff --git a/middlewares/ratelimit/redis_slide_window.go b/middlewares/ratelimit/redis_slide_window.go new file mode 100644 index 0000000..b90f488 --- /dev/null +++ b/middlewares/ratelimit/redis_slide_window.go @@ -0,0 +1,24 @@ +package ratelimit + +import ( + "time" + + "github.com/redis/go-redis/v9" + + "github.com/ecodeclub/ginx/internal/ratelimit" +) + +// NewRedisSlidingWindowLimiter 创建一个基于 redis 的滑动窗口限流器. +// cmd: 可传入 redis 的客户端 +// interval: 窗口大小 +// rate: 阈值 +// 表示: 在 interval 内允许 rate 个请求 +// 示例: 1s 内允许 3000 个请求 +func NewRedisSlidingWindowLimiter(cmd redis.Cmdable, + interval time.Duration, rate int) ratelimit.Limiter { + return &ratelimit.RedisSlidingWindowLimiter{ + Cmd: cmd, + Interval: interval, + Rate: rate, + } +}