Skip to content

Commit

Permalink
add forward from domain (#2323)
Browse files Browse the repository at this point in the history
* add forward from domain

* add balancer forward

* add unittest and readme

* add short description new feature

* add short description on signature

* golangci-lint fix

---------

Co-authored-by: René Werner <rene@gofiber.io>
  • Loading branch information
ryanbekhen and ReneWerner87 authored Feb 3, 2023
1 parent 028d821 commit 61a3336
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 74 deletions.
165 changes: 91 additions & 74 deletions middleware/proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@ Proxy middleware for [Fiber](https://github.com/gofiber/fiber) that allows you t
### Signatures

```go
// Balancer create a load balancer among multiple upstrem servers.
func Balancer(config Config) fiber.Handler
// Forward performs the given http request and fills the given http response.
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler
// Do performs the given http request and fills the given http response.
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error
// DomainForward the given http request based on the given domain and fills the given http response
func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler
// BalancerForward performs the given http request based round robin balancer and fills the given http response
func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler
```

### Examples
Expand All @@ -23,8 +30,8 @@ Import the middleware package that is part of the Fiber web framework

```go
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/proxy"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/proxy"
)
```

Expand All @@ -39,54 +46,64 @@ proxy.WithTlsConfig(&tls.Config{

// if you need to use global self-custom client, you should use proxy.WithClient.
proxy.WithClient(&fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
})

// Forward to url
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif"))

// If you want to forward with a specific domain. You have to use proxy.DomainForward.
app.Get("/payments", proxy.DomainForward("docs.gofiber.io", "http://localhost:8000"))

// Forward to url with local custom client
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}))

// Make request within handler
app.Get("/:id", func(c *fiber.Ctx) error {
url := "https://i.imgur.com/"+c.Params("id")+".gif"
if err := proxy.Do(c, url); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
url := "https://i.imgur.com/"+c.Params("id")+".gif"
if err := proxy.Do(c, url); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})

// Minimal round robin balancer
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
},
Servers: []string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
},
}))

// Or extend your balancer for customization
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
},
ModifyRequest: func(c *fiber.Ctx) error {
c.Request().Header.Add("X-Real-IP", c.IP())
return nil
},
ModifyResponse: func(c *fiber.Ctx) error {
c.Response().Header.Del(fiber.HeaderServer)
return nil
},
Servers: []string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
},
ModifyRequest: func(c *fiber.Ctx) error {
c.Request().Header.Add("X-Real-IP", c.IP())
return nil
},
ModifyResponse: func(c *fiber.Ctx) error {
c.Response().Header.Del(fiber.HeaderServer)
return nil
},
}))

// Or this way if the balancer is using https and the destination server is only using http.
app.Use(proxy.BalancerForward([]string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
}))
```

Expand All @@ -95,50 +112,50 @@ app.Use(proxy.Balancer(proxy.Config{
```go
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool

// Servers defines a list of <scheme>://<host> HTTP servers,
//
// which are used in a round-robin manner.
// i.e.: "https://foobar.com, http://www.foobar.com"
//
// Required
Servers []string

// ModifyRequest allows you to alter the request
//
// Optional. Default: nil
ModifyRequest fiber.Handler

// ModifyResponse allows you to alter the response
//
// Optional. Default: nil
ModifyResponse fiber.Handler
// Timeout is the request timeout used when calling the proxy client
//
// Optional. Default: 1 second
Timeout time.Duration

// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
ReadBufferSize int
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool

// Servers defines a list of <scheme>://<host> HTTP servers,
//
// which are used in a round-robin manner.
// i.e.: "https://foobar.com, http://www.foobar.com"
//
// Required
Servers []string

// ModifyRequest allows you to alter the request
//
// Optional. Default: nil
ModifyRequest fiber.Handler

// ModifyResponse allows you to alter the response
//
// Optional. Default: nil
ModifyResponse fiber.Handler
// Timeout is the request timeout used when calling the proxy client
//
// Optional. Default: 1 second
Timeout time.Duration

// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
ReadBufferSize int

// Per-connection buffer size for responses' writing.
WriteBufferSize int

// tls config for the http client.
TlsConfig *tls.Config
// Client is custom client when client config is complex.
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set.
Client *fasthttp.LBClient
// Per-connection buffer size for responses' writing.
WriteBufferSize int

// tls config for the http client.
TlsConfig *tls.Config
// Client is custom client when client config is complex.
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set.
Client *fasthttp.LBClient
}
```

Expand Down
50 changes: 50 additions & 0 deletions middleware/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,53 @@ func getScheme(uri []byte) []byte {
}
return uri[:i-1]
}

// DomainForward performs an http request based on the given domain and populates the given http response.
// This method will return an fiber.Handler
func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler {
return func(c *fiber.Ctx) error {
host := string(c.Request().Host())
if host == hostname {
return Do(c, addr+c.OriginalURL(), clients...)
}
return nil
}
}

type roundrobin struct {
sync.Mutex

current int
pool []string
}

// this method will return a string of addr server from list server.
func (r *roundrobin) get() string {
r.Lock()
defer r.Unlock()

if r.current >= len(r.pool) {
r.current %= len(r.pool)
}

result := r.pool[r.current]
r.current++
return result
}

// BalancerForward Forward performs the given http request with round robin algorithm to server and fills the given http response.
// This method will return an fiber.Handler
func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler {
r := &roundrobin{
current: 0,
pool: servers,
}
return func(c *fiber.Ctx) error {
server := r.get()
if !strings.HasPrefix(server, "http") {
server = "http://" + server
}
c.Request().Header.Add("X-Real-IP", c.IP())
return Do(c, server+c.OriginalURL(), clients...)
}
}
57 changes: 57 additions & 0 deletions middleware/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,60 @@ func Test_ProxyBalancer_Custom_Client(t *testing.T) {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}

// go test -run Test_Proxy_Domain_Forward_Local
func Test_Proxy_Domain_Forward_Local(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
app := fiber.New(fiber.Config{DisableStartupMessage: true})

// target server
ln1, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
app1 := fiber.New(fiber.Config{DisableStartupMessage: true})

app1.Get("/test", func(c *fiber.Ctx) error {
return c.SendString("test_local_client:" + c.Query("query_test"))
})

proxyAddr := ln.Addr().String()
targetAddr := ln1.Addr().String()
localDomain := strings.Replace(proxyAddr, "127.0.0.1", "localhost", 1)
app.Use(DomainForward(localDomain, "http://"+targetAddr, &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,

Dial: fasthttp.Dial,
}))

go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
go func() { utils.AssertEqual(t, nil, app1.Listener(ln1)) }()

code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_local_client:true", body)
}

// go test -run Test_Proxy_Balancer_Forward_Local
func Test_Proxy_Balancer_Forward_Local(t *testing.T) {
t.Parallel()

app := fiber.New()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("forwarded")
})

app.Use(BalancerForward([]string{addr}))

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)

b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)

utils.AssertEqual(t, string(b), "forwarded")
}

0 comments on commit 61a3336

Please sign in to comment.