diff --git a/README.md b/README.md index 068f453..85b1a7d 100644 --- a/README.md +++ b/README.md @@ -34,19 +34,19 @@ WSP server configuration --- host : 127.0.0.1 # Address to bind the HTTP server port : 8080 # Port to bind the HTTP server -timeout : 1000 # Time to wait before acquiring a WS connection to forward the request (milliseconds) -idletimeout : 60000 # Time to wait before closing idle connection when there is enough idle connections (milliseconds) +timeout : 1s # Time to wait before acquiring a WS connection to forward the request (milliseconds) +idletimeout : 60s # Time to wait before closing idle connection when there is enough idle connections (milliseconds) #blacklist : # Forbidden destination ( deny nothing if empty ) # - method : ".*" # Applied in order before whitelist # url : "^http(s)?://google.*" # None must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # +# X-CUSTOM-HEADER : "^value$" # #whitelist : # Allowed destinations ( allow all if empty ) # - method : "^GET$" # Applied in order after blacklist # url : "^http(s)?://.*$" # One must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # -# secretkey : ThisIsASecret # secret key that must be set in clients configuration +# X-CUSTOM-HEADER : "^value$" # +#secretkey : ThisIsASecret # shared secret key that must match the value set in clients configuration ``` ``` @@ -83,23 +83,20 @@ targets : # Endpoints to connect to - ws://127.0.0.1:8080/register # poolidlesize : 10 # Default number of concurrent open (TCP) connections to keep idle per WSP server poolmaxsize : 100 # Maximum number of concurrent open (TCP) connections per WSP server +#insecureSkipVerify : true # Disable the http client certificate chain and hostname verification #blacklist : # Forbidden destination ( deny nothing if empty ) # - method : ".*" # Applied in order before whitelist # url : ".*forbidden.*" # None must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # +# X-CUSTOM-HEADER : "^value$" # #whitelist : # Allowed destinations ( allow all if empty ) # - method : "^GET$" # Applied in order after blacklist # url : "http(s)?://.*$" # One must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # -# secretkey : ThisIsASecret # secret key that must match the value set in servers configuration +# X-CUSTOM-HEADER : "^value$" # +# secretkey : ThisIsASecret # shared secret key that must match the value set in servers configuration ``` - - poolMinSize is the default number of opened TCP/HTTP/WS connections - to open per WSP server. If there is a burst of simpultaneous requests - the number of open connection will rise and then decrease back to this - number. - poolMinIdleSize is the number of connection to keep idle, meaning that if there is more than this number of simultaneous requests the WSP client will try to open more connections to keep idle connection. diff --git a/clean.sh b/clean.sh new file mode 100755 index 0000000..23b2139 --- /dev/null +++ b/clean.sh @@ -0,0 +1 @@ +for i in $(ps faux | grep wsp_ | grep -v grep | awk '{ print $2 }') ; do kill -9 $i ; done diff --git a/client/client.go b/client/client.go index 9f5d0f6..5b0b9e7 100644 --- a/client/client.go +++ b/client/client.go @@ -1,43 +1,66 @@ package client import ( + "crypto/tls" "net/http" "github.com/gorilla/websocket" + "github.com/root-gg/wsp/common" + "log" ) -// Client connects to one or more Server using HTTP websockets +// Client connects to one or more Server using HTTP WebSocket // The Server can then send HTTP requests to execute type Client struct { - Config *Config + Config *Config + validator *common.RequestValidator - client *http.Client - dialer *websocket.Dialer - pools map[string]*Pool + httpClient *http.Client + dialer *websocket.Dialer + pools map[string]*Pool } // NewClient creates a new Proxy -func NewClient(config *Config) (c *Client) { - c = new(Client) - c.Config = config - c.client = &http.Client{} - c.dialer = &websocket.Dialer{} - c.pools = make(map[string]*Pool) +func NewClient(config *Config) (client *Client) { + client = new(Client) + client.Config = config + + client.validator = &common.RequestValidator{ + Whitelist: config.Whitelist, + Blacklist: config.Blacklist, + } + err := client.validator.Initialize() + if err != nil { + log.Fatalf("Unable to initialize the request validator : %s", err) + } + + // WebSocket tcp dialer to connect to the remote WSP servers + client.dialer = &websocket.Dialer{} + + // HTTP client to execute HTTP requests received by the WebSocket tunnels + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: client.Config.InsecureSkipVerify}, + } + client.httpClient = &http.Client{Transport: tr} + + client.pools = make(map[string]*Pool) + return } // Start the Proxy -func (c *Client) Start() { - for _, target := range c.Config.Targets { - pool := NewPool(c, target, c.Config.SecretKey) - c.pools[target] = pool - go pool.Start() +func (client *Client) Start() { + for _, target := range client.Config.Targets { + pool := NewPool(client, target) + client.pools[target] = pool + pool.start() } } // Shutdown the Proxy -func (c *Client) Shutdown() { - for _, pool := range c.pools { - pool.Shutdown() +func (client *Client) Shutdown() { + for _, pool := range client.pools { + pool.close() } + } diff --git a/client/config.go b/client/config.go index 92c62da..6081d7d 100644 --- a/client/config.go +++ b/client/config.go @@ -2,6 +2,7 @@ package client import ( "io/ioutil" + "os" "github.com/nu7hatch/gouuid" "gopkg.in/yaml.v2" @@ -11,13 +12,15 @@ import ( // Config configures an Proxy type Config struct { - ID string - Targets []string - PoolIdleSize int - PoolMaxSize int - Whitelist []*common.Rule - Blacklist []*common.Rule - SecretKey string + Name string + ID string `json:"-"` + Targets []string + PoolIdleSize int + PoolMaxSize int + Whitelist []*common.Rule + Blacklist []*common.Rule + SecretKey string + InsecureSkipVerify bool } // NewConfig creates a new ProxyConfig @@ -30,13 +33,16 @@ func NewConfig() (config *Config) { } config.ID = id.String() + hostname, err := os.Hostname() + if err != nil { + panic(err) + } + config.Name = hostname + config.Targets = []string{"ws://127.0.0.1:8080/register"} config.PoolIdleSize = 10 config.PoolMaxSize = 100 - config.Whitelist = make([]*common.Rule, 0) - config.Blacklist = make([]*common.Rule, 0) - return } @@ -54,19 +60,5 @@ func LoadConfiguration(path string) (config *Config, err error) { return } - // Compile the rules - - for _, rule := range config.Whitelist { - if err = rule.Compile(); err != nil { - return - } - } - - for _, rule := range config.Blacklist { - if err = rule.Compile(); err != nil { - return - } - } - return } diff --git a/client/connection.go b/client/connection.go index 49d70ed..79c8763 100644 --- a/client/connection.go +++ b/client/connection.go @@ -8,6 +8,8 @@ import ( "io/ioutil" "log" "net/http" + "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -15,149 +17,170 @@ import ( "github.com/root-gg/wsp/common" ) -// Status of a Connection +var id uint64 = 0 + +func getNextId() uint64 { + return atomic.AddUint64(&id, uint64(1)) +} + +// ConnectionStatus of a Connection +type ConnectionStatus int + const ( - CONNECTING = iota + CONNECTING ConnectionStatus = iota IDLE RUNNING + CLOSED ) // Connection handle a single websocket (HTTP/TCP) connection to an Server type Connection struct { - pool *Pool - ws *websocket.Conn - status int + clientSettings *common.ClientSettings + + ws *websocket.Conn + + status ConnectionStatus + connectionStatusListner *ConnectionStatusListner + + lock sync.RWMutex + done chan struct{} } -// NewConnection create a Connection object -func NewConnection(pool *Pool) (conn *Connection) { +// newConnection create a Connection object +func newConnection(clientSettings *common.ClientSettings, connectionStatusListner *ConnectionStatusListner) (conn *Connection) { conn = new(Connection) - conn.pool = pool + conn.clientSettings = clientSettings + conn.clientSettings.ConnectionId = getNextId() + conn.connectionStatusListner = connectionStatusListner conn.status = CONNECTING + conn.done = make(chan struct{}) return } -// Connect to the IsolatorServer using a HTTP websocket -func (connection *Connection) Connect() (err error) { - log.Printf("Connecting to %s", connection.pool.target) +// Set the status of the connection in a concurrently safe way +func (conn *Connection) setStatus(status ConnectionStatus) { + conn.lock.Lock() + defer conn.lock.Unlock() - // Create a new TCP(/TLS) connection ( no use of net.http ) - connection.ws, _, err = connection.pool.client.dialer.Dial(connection.pool.target, http.Header{"X-SECRET-KEY": {connection.pool.secretKey}}) + // Trigger a pool refresh to open new connections if needed + defer conn.connectionStatusListner.onConnectionStatusChanged() + conn.status = status +} + +// Get the status of the connection in a concurrently safe way +func (conn *Connection) getStatus() ConnectionStatus { + conn.lock.RLock() + defer conn.lock.RUnlock() + return conn.status +} +// Open a connection to the remote WSP Server +func (conn *Connection) connect(dialer *websocket.Dialer, target string, secretKey string) (err error) { + conn.lock.Lock() + defer conn.lock.Unlock() + + log.Printf("Connecting to %s", target) + + // Create a new TCP(/TLS) conn ( no use of net.http ) + conn.ws, _, err = dialer.Dial(target, http.Header{"X-SECRET-KEY": {secretKey}}) if err != nil { - return err + return fmt.Errorf("dialer error : %s", err) } - log.Printf("Connected to %s", connection.pool.target) + log.Printf("Connected to %s", target) - // Send the greeting message with proxy id and wanted pool size. - greeting := fmt.Sprintf("%s_%d", connection.pool.client.Config.ID, connection.pool.client.Config.PoolIdleSize) - err = connection.ws.WriteMessage(websocket.TextMessage, []byte(greeting)) + return nil +} + +// Send the client configuration to the remote WSP Server +func (conn *Connection) initialize() (err error) { + message, err := conn.clientSettings.ToJson() if err != nil { - log.Println("greeting error :", err) - connection.Close() - return + return fmt.Errorf("connection initlization error, unable to serialize client settings : %s", err) } - go connection.serve() + // Send the greeting message with proxy id and wanted pool size. + err = conn.ws.WriteMessage(websocket.TextMessage, message) + if err != nil { + return fmt.Errorf("connection initlization error : %s", err) + } - return + return nil } -// the main loop it : +// the main loop : // - wait to receive HTTP requests from the Server // - execute HTTP requests // - send HTTP response back to the Server // // As in the server code there is no buffering of HTTP request/response body // As is the server if any error occurs the connection is closed/throwed -func (connection *Connection) serve() { - defer connection.Close() - - // Keep connection alive +func (conn *Connection) serve(httpClient *http.Client, validator *common.RequestValidator) { + // Keep conn alive go func() { + defer conn.close() + for { - time.Sleep(30 * time.Second) - err := connection.ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second)) + if conn.isClosed() { + break + } + time.Sleep(5 * time.Second) + err := conn.ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second)) if err != nil { - connection.Close() + //log.Printf("ping fail : %s", err) + break } } }() for { + conn.setStatus(IDLE) + // Read request - connection.status = IDLE - _, jsonRequest, err := connection.ws.ReadMessage() + _, jsonRequest, err := conn.ws.ReadMessage() if err != nil { - log.Println("Unable to read request", err) + if !conn.isClosed() { + log.Printf("Unable to read request :%s", err) + } break } - connection.status = RUNNING - - // Trigger a pool refresh to open new connections if needed - go connection.pool.connector() + conn.setStatus(RUNNING) // Deserialize request httpRequest := new(common.HTTPRequest) err = json.Unmarshal(jsonRequest, httpRequest) if err != nil { - connection.error(fmt.Sprintf("Unable to deserialize json http request : %s\n", err)) + conn.error(fmt.Sprintf("Unable to deserialize json http request : %s\n", err)) break } - req, err := common.UnserializeHTTPRequest(httpRequest) + // Get an executable net/http.Request + req, err := httpRequest.ToStdLibHTTPRequest() if err != nil { - connection.error(fmt.Sprintf("Unable to deserialize http request : %v\n", err)) + conn.error(fmt.Sprintf("Unable to deserialize http request : %v\n", err)) break } - log.Printf("[%s] %s", req.Method, req.URL.String()) - - // Apply blacklist - if len(connection.pool.client.Config.Blacklist) > 0 { - for _, rule := range connection.pool.client.Config.Blacklist { - if rule.Match(req) { - // Discard request body - err = connection.discard() - if err != nil { - break - } - err = connection.error("Destination is forbidden") - if err != nil { - break - } - continue - } - } - } - - // Apply whitelist - if len(connection.pool.client.Config.Whitelist) > 0 { - allowed := false - for _, rule := range connection.pool.client.Config.Whitelist { - if rule.Match(req) { - allowed = true - break - } + err = validator.Validate(req) + if err != nil { + // Discard the request body + err2 := conn.discard() + if err2 != nil { + conn.error(err2.Error()) + break } - if !allowed { - // Discard request body - err = connection.discard() - if err != nil { - break - } - err = connection.error("Destination is not allowed\n") - if err != nil { - break - } - continue + err3 := conn.error(err.Error()) + if err3 != nil { + break } + continue } + log.Printf("[%s] %s", req.Method, req.URL.String()) + // Pipe request body - _, bodyReader, err := connection.ws.NextReader() + _, bodyReader, err := conn.ws.NextReader() if err != nil { log.Printf("Unable to get response body reader : %v", err) break @@ -165,9 +188,9 @@ func (connection *Connection) serve() { req.Body = ioutil.NopCloser(bodyReader) // Execute request - resp, err := connection.pool.client.client.Do(req) + resp, err := httpClient.Do(req) if err != nil { - err = connection.error(fmt.Sprintf("Unable to execute request : %v\n", err)) + err = conn.error(fmt.Sprintf("Unable to execute request : %v\n", err)) if err != nil { break } @@ -177,7 +200,7 @@ func (connection *Connection) serve() { // Serialize response jsonResponse, err := json.Marshal(common.SerializeHTTPResponse(resp)) if err != nil { - err = connection.error(fmt.Sprintf("Unable to serialize response : %v\n", err)) + err = conn.error(fmt.Sprintf("Unable to serialize response : %v\n", err)) if err != nil { break } @@ -185,14 +208,14 @@ func (connection *Connection) serve() { } // Write response - err = connection.ws.WriteMessage(websocket.TextMessage, jsonResponse) + err = conn.ws.WriteMessage(websocket.TextMessage, jsonResponse) if err != nil { log.Printf("Unable to write response : %v", err) break } // Pipe response body - bodyWriter, err := connection.ws.NextWriter(websocket.BinaryMessage) + bodyWriter, err := conn.ws.NextWriter(websocket.BinaryMessage) if err != nil { log.Printf("Unable to get response body writer : %v", err) break @@ -206,7 +229,8 @@ func (connection *Connection) serve() { } } -func (connection *Connection) error(msg string) (err error) { +// Craft an error response to forward back to the WSP Server +func (conn *Connection) error(msg string) (err error) { resp := common.NewHTTPResponse() resp.StatusCode = 527 @@ -222,14 +246,14 @@ func (connection *Connection) error(msg string) (err error) { } // Write response - err = connection.ws.WriteMessage(websocket.TextMessage, jsonResponse) + err = conn.ws.WriteMessage(websocket.TextMessage, jsonResponse) if err != nil { log.Printf("Unable to write response : %v", err) return } // Write response body - err = connection.ws.WriteMessage(websocket.BinaryMessage, []byte(msg)) + err = conn.ws.WriteMessage(websocket.BinaryMessage, []byte(msg)) if err != nil { log.Printf("Unable to write response body : %v", err) return @@ -238,23 +262,40 @@ func (connection *Connection) error(msg string) (err error) { return } -// Discard request body -func (connection *Connection) discard() (err error) { - mt, _, err := connection.ws.NextReader() +// Discard the next message +func (conn *Connection) discard() (err error) { + mt, _, err := conn.ws.NextReader() if err != nil { - return nil + return err } if mt != websocket.BinaryMessage { return errors.New("Invalid body message type") } - return + return nil +} + +// IsClosed return true if the connection has been closed +func (conn *Connection) isClosed() bool { + select { + case <-conn.done: + return true + default: + return false + } } -// Close close the ws/tcp connection and remove it from the pool -func (connection *Connection) Close() { - connection.pool.lock.Lock() - defer connection.pool.lock.Unlock() +// Close the ws/tcp connection and remove it from the pool +func (conn *Connection) close() { + conn.lock.Lock() + defer conn.lock.Unlock() + if conn.isClosed() { + return + } + + conn.status = CLOSED + close(conn.done) - connection.pool.remove(connection) - connection.ws.Close() + if conn.ws != nil { + conn.ws.Close() + } } diff --git a/client/pool.go b/client/pool.go index ccf268a..5dc87df 100644 --- a/client/pool.go +++ b/client/pool.go @@ -2,6 +2,7 @@ package client import ( "fmt" + "github.com/root-gg/wsp/common" "log" "sync" "time" @@ -9,31 +10,40 @@ import ( // Pool manage a pool of connection to a remote Server type Pool struct { - client *Client - target string - secretKey string + client *Client + target string connections []*Connection lock sync.RWMutex - done chan struct{} + deadline *time.Time + + connectionStatusListner *ConnectionStatusListner + done chan struct{} } // NewPool creates a new Pool -func NewPool(client *Client, target string, secretKey string) (pool *Pool) { +func NewPool(client *Client, target string) (pool *Pool) { pool = new(Pool) pool.client = client pool.target = target pool.connections = make([]*Connection, 0) - pool.secretKey = secretKey + pool.connectionStatusListner = NewConnectionStatusListner() pool.done = make(chan struct{}) return } // Start connect to the remote Server -func (pool *Pool) Start() { - pool.connector() +func (pool *Pool) start() { + + // Try to open new connections to reach the desired pool size as fast as possible + // Normally the pool is filled right away by the go conn.pool.connector() + // triggered when a connection is about to be used but te ticker is here to speed things up if needed + go func() { + // Bootstrap + pool.connector() + ticker := time.Tick(time.Second) for { select { @@ -41,77 +51,121 @@ func (pool *Pool) Start() { break case <-ticker: pool.connector() + case <-pool.connectionStatusListner.connectionStatusChanged(): + pool.connector() } } }() } -// The garbage collector +// The connector is responsible to open the connections to the remote WSP Server +// It tries to keep Config.PoolIdleSize connecting/idle connections besides the running +// ones and will take care to never exceed Config.PoolMaxSize open connections func (pool *Pool) connector() { pool.lock.Lock() defer pool.lock.Unlock() - poolSize := pool.Size() + if pool.isClosed() { + return + } + + // Remove closed connections + pool.clean() + + poolSize := pool.size() + log.Printf("%s pool size : %v", pool.target, poolSize) - //log.Printf("%s pool size : %v", pool.target, poolSize) + // Number of missing idle connection to reach the ideal pool size + missing := pool.client.Config.PoolIdleSize - poolSize.idle - // Create enough connection to fill the pool - toCreate := pool.client.Config.PoolIdleSize - poolSize.idle + // If the pool is empty only try to create one single connection + isEmpty := poolSize.idle+poolSize.running == 0 + if isEmpty { + missing = 1 - // Create only one connection if the pool is empty - if poolSize.total == 0 { - toCreate = 1 + //// ratelimit connection + //if pool.deadline != nil { + // time.Sleep(pool.deadline.Sub(time.Now())) + //} + //deadline := time.Now().Add(1000 * time.Millisecond) + //pool.deadline = &deadline } // Ensure to open at most PoolMaxSize connections - if poolSize.total+toCreate > pool.client.Config.PoolMaxSize { - toCreate = pool.client.Config.PoolMaxSize - poolSize.total + if poolSize.total+missing > pool.client.Config.PoolMaxSize { + missing = pool.client.Config.PoolMaxSize - poolSize.total } - //log.Printf("%v",toCreate) + // Remove already in-flight connections + toCreate := missing - poolSize.connecting // Try to reach ideal pool size for i := 0; i < toCreate; i++ { - conn := NewConnection(pool) + clientSettings := &common.ClientSettings{ + ID: pool.client.Config.ID, + Name: pool.client.Config.Name, + PoolSize: pool.client.Config.PoolIdleSize, + } + conn := newConnection(clientSettings, pool.connectionStatusListner) + + // Append connection to the pool before trying to connect + // so in-flight connection can appear in poolSize + // Anyway nobody will ever get a connection from this pool, + // this is the only way to add a connection to the pool and + // the only way to remove a connection from the pool is + // pool.clean() which is called the at the beginning of this method. pool.connections = append(pool.connections, conn) go func() { - err := conn.Connect() + defer conn.close() + + err := conn.connect(pool.client.dialer, pool.target, pool.client.Config.SecretKey) if err != nil { - log.Printf("Unable to connect to %s : %s", pool.target, err) + log.Printf("Unable to establish connection %d to %s : %s", conn.clientSettings.ConnectionId, pool.target, err) + return + } - pool.lock.Lock() - defer pool.lock.Unlock() - pool.remove(conn) + err = conn.initialize() + if err != nil { + log.Printf("Unable to connection %d to %s: %s", conn.clientSettings.ConnectionId, pool.target, err) + return } + + // This call blocks + conn.serve(pool.client.httpClient, pool.client.validator) }() } } -// Add a connection to the pool -func (pool *Pool) add(conn *Connection) { - pool.connections = append(pool.connections, conn) -} - -// Remove a connection from the pool -func (pool *Pool) remove(conn *Connection) { - // This trick uses the fact that a slice shares the same backing array and capacity as the original, - // so the storage is reused for the filtered slice. Of course, the original contents are modified. - - var filtered []*Connection // == nil - for _, c := range pool.connections { - if conn != c { - filtered = append(filtered, c) +// Remove closed connections from the pool +func (pool *Pool) clean() { + var filtered []*Connection + for _, conn := range pool.connections { + if conn.getStatus() != CLOSED { + filtered = append(filtered, conn) } } pool.connections = filtered } -// Shutdown close all connection in the pool -func (pool *Pool) Shutdown() { +// Check if the pool has been closed +func (pool *Pool) isClosed() bool { + select { + case <-pool.done: + return true + default: + return false + } +} + +// Close all connection in the pool and be sure we don't use it again +func (pool *Pool) close() { + pool.lock.Lock() + defer pool.lock.Unlock() + close(pool.done) for _, conn := range pool.connections { - conn.Close() + conn.close() } } @@ -120,25 +174,28 @@ type PoolSize struct { connecting int idle int running int + closed int total int } func (poolSize *PoolSize) String() string { - return fmt.Sprintf("Connecting %d, idle %d, running %d, total %d", poolSize.connecting, poolSize.idle, poolSize.running, poolSize.total) + return fmt.Sprintf("Connecting %d, idle %d, running %d, closed %d, total %d", poolSize.connecting, poolSize.idle, poolSize.running, poolSize.closed, poolSize.total) } // Size return the current state of the pool -func (pool *Pool) Size() (poolSize *PoolSize) { +func (pool *Pool) size() (poolSize *PoolSize) { poolSize = new(PoolSize) poolSize.total = len(pool.connections) for _, connection := range pool.connections { - switch connection.status { + switch connection.getStatus() { case CONNECTING: poolSize.connecting++ case IDLE: poolSize.idle++ case RUNNING: poolSize.running++ + case CLOSED: + poolSize.closed++ } } diff --git a/client/status_listener.go b/client/status_listener.go new file mode 100644 index 0000000..ed1cc2e --- /dev/null +++ b/client/status_listener.go @@ -0,0 +1,29 @@ +package client + +// This listener is used to trigger the connector when a connection status has changed to +// 1 -> Remove the +// 2 -> Open new connections if needed +// -> the connectionStatusChanged returns a channel that can be used to wait for a connection status change +type ConnectionStatusListner struct { + c chan struct{} +} + +func NewConnectionStatusListner() (listener *ConnectionStatusListner) { + listener = new(ConnectionStatusListner) + listener.c = make(chan struct{}, 1) + return listener +} + +// onConnectionStatusChanged has to be called when a connection status change +func (listener *ConnectionStatusListner) onConnectionStatusChanged() { + select { + case listener.c <- struct{}{}: + default: + } +} + +// onConnectionStatusChanged return the channel that can be used *BY A SINGLE GOROUTINE* +// to wait for a connection status change +func (listener *ConnectionStatusListner) connectionStatusChanged() chan struct{} { + return listener.c +} diff --git a/common/request.go b/common/request.go index 029023c..9897509 100644 --- a/common/request.go +++ b/common/request.go @@ -5,7 +5,7 @@ import ( "net/url" ) -// HTTPRequest is a serializable version of http.Request ( with only usefull fields ) +// HTTPRequest is a serializable version of net/http.Request ( with only usefull fields ) type HTTPRequest struct { Method string URL string @@ -13,8 +13,8 @@ type HTTPRequest struct { ContentLength int64 } -// SerializeHTTPRequest create a new HTTPRequest from a http.Request -func SerializeHTTPRequest(req *http.Request) (r *HTTPRequest) { +// NewHTTPRequest creates a new HTTPRequest from a net/http.Request instance +func NewHTTPRequest(req *http.Request) (r *HTTPRequest) { r = new(HTTPRequest) r.URL = req.URL.String() r.Method = req.Method @@ -23,8 +23,8 @@ func SerializeHTTPRequest(req *http.Request) (r *HTTPRequest) { return } -// UnserializeHTTPRequest create a new http.Request from a HTTPRequest -func UnserializeHTTPRequest(req *HTTPRequest) (r *http.Request, err error) { +// ToStdLibHTTPRequest creates a new net/http.Request from this HTTPRequest instance +func (req *HTTPRequest) ToStdLibHTTPRequest() (r *http.Request, err error) { r = new(http.Request) r.Method = req.Method r.URL, err = url.Parse(req.URL) diff --git a/common/settings.go b/common/settings.go new file mode 100644 index 0000000..c84237e --- /dev/null +++ b/common/settings.go @@ -0,0 +1,36 @@ +package common + +import ( + "encoding/json" + "fmt" +) + +// ClientSettings are sent to the server by the WSP Client to WSP Server +// The poolSize ( number of idle connection ) is ensured by the server which is the +// only one allowed to close connections +type ClientSettings struct { + ID string // Instance ID + Name string // Hostname ( can be override in the config ) + PoolSize int // Number of idle connection to maintain + ConnectionId uint64 // ID of this specific connection ( should be transmitted in a ConnectionSetting object ? ) +} + +// Unserialize JSON to a new ClientSettings instance +func ClientSettingsFromJson(bytes []byte) (settings *ClientSettings, err error) { + // Deserialize request + settings = new(ClientSettings) + err = json.Unmarshal(bytes, settings) + if err != nil { + return nil, err + } + return settings, nil +} + +// Serialize the ClientSettings to JSON +func (settings *ClientSettings) ToJson() (bytes []byte, err error) { + bytes, err = json.Marshal(settings) + if err != nil { + return nil, fmt.Errorf("Unable to serialize request : %s", err) + } + return bytes, nil +} diff --git a/common/rules.go b/common/validator.go similarity index 57% rename from common/rules.go rename to common/validator.go index 3d90cf8..e3456e5 100644 --- a/common/rules.go +++ b/common/validator.go @@ -1,6 +1,7 @@ package common import ( + "errors" "fmt" "net/http" "regexp" @@ -78,3 +79,55 @@ func (rule *Rule) Match(req *http.Request) bool { func (rule *Rule) String() string { return fmt.Sprintf("%s %s %v", rule.Method, rule.URL, rule.Headers) } + +// Validate a net/http.Request against a Whitelist and a Blacklist +// The blacklist is applied first. If non empty any match in this list will block the request +// Then the whitelist is applied. If non empty, the request must match at least one rule of the whitelist +type RequestValidator struct { + Blacklist []*Rule + Whitelist []*Rule +} + +func (validator *RequestValidator) Initialize() (err error) { + // Compile the rules + for _, rule := range validator.Whitelist { + if err = rule.Compile(); err != nil { + return err + } + } + + for _, rule := range validator.Blacklist { + if err = rule.Compile(); err != nil { + return err + } + } + return nil +} + +// Validate apply the Whitelist and the Blacklist rules to the net/http.Request +func (validator *RequestValidator) Validate(req *http.Request) (err error) { + // Apply blacklist + if len(validator.Blacklist) > 0 { + for _, rule := range validator.Blacklist { + if rule.Match(req) { + return errors.New("Destination is forbidden") + } + } + } + + // Apply whitelist + if len(validator.Whitelist) > 0 { + allowed := false + for _, rule := range validator.Whitelist { + if rule.Match(req) { + allowed = true + break + } + } + if !allowed { + return errors.New("Destination is not allowed") + } + } + + return nil +} diff --git a/server/config.go b/server/config.go index 2ec416e..93fbf72 100644 --- a/server/config.go +++ b/server/config.go @@ -2,6 +2,7 @@ package server import ( "io/ioutil" + "time" "gopkg.in/yaml.v2" @@ -12,8 +13,8 @@ import ( type Config struct { Host string Port int - Timeout int - IdleTimeout int + Timeout time.Duration + IdleTimeout time.Duration Whitelist []*common.Rule Blacklist []*common.Rule SecretKey string @@ -24,10 +25,8 @@ func NewConfig() (config *Config) { config = new(Config) config.Host = "127.0.0.1" config.Port = 8080 - config.Timeout = 1000 - config.IdleTimeout = 60000 - config.Whitelist = make([]*common.Rule, 0) - config.Blacklist = make([]*common.Rule, 0) + config.Timeout = time.Second + config.IdleTimeout = 60 * time.Second return } diff --git a/server/connection.go b/server/connection.go index af58062..8f23eb6 100644 --- a/server/connection.go +++ b/server/connection.go @@ -12,109 +12,146 @@ import ( "github.com/gorilla/websocket" + "errors" "github.com/root-gg/wsp/common" ) // Status of a Connection +type ConnectionStatus int +type WSHandler func(reader io.Reader) error + const ( - IDLE = iota + IDLE ConnectionStatus = iota BUSY CLOSED ) -// Connection manage a single websocket connection from +// Connection manage a single WebSocket connection from a WSP client type Connection struct { - pool *Pool - ws *websocket.Conn - status int - idleSince time.Time - lock sync.Mutex - nextResponse chan chan io.Reader -} + id uint64 + ws *websocket.Conn + + status ConnectionStatus + lock sync.RWMutex -// NewConnection return a new Connection -func NewConnection(pool *Pool, ws *websocket.Conn) (connection *Connection) { - connection = new(Connection) - connection.pool = pool - connection.ws = ws - connection.nextResponse = make(chan chan io.Reader) + nextReader chan func(io.Reader) + + releaser func(conn *Connection) + idleSince time.Time + + done chan struct{} +} - connection.Release() +// newConnection return a new Connection +func newConnection(id uint64, ws *websocket.Conn, releaser func(conn *Connection)) (conn *Connection) { + conn = new(Connection) + conn.id = id + conn.ws = ws + conn.releaser = releaser + conn.nextReader = make(chan func(io.Reader), 1) + conn.done = make(chan struct{}) - go connection.read() + go conn.read() return } -// read the incoming message of the connection -func (connection *Connection) read() { +// Get the status of the connection in a concurrently safe way +func (conn *Connection) getStatus() ConnectionStatus { + conn.lock.RLock() + defer conn.lock.RUnlock() + return conn.status +} + +// Handle next message pass a function to process the next WebSocket message +// to the read goroutine. Only one message can be handled at a time. +// This method blocks until the handler has returned. +func (conn *Connection) handleNextMessage(h WSHandler) error { + done := make(chan error) + h2 := func(reader io.Reader) { + done <- h(reader) + } + + select { + case conn.nextReader <- h2: + case <-conn.done: + return errors.New("connection closed") + default: + return errors.New("already reading") + } + + select { + case err := <-done: + return err + case <-conn.done: + return errors.New("connection closed") + } +} + +// read the incoming message of the WebSocket connection +func (conn *Connection) read() { defer func() { if r := recover(); r != nil { log.Printf("Websocket crash recovered : %s", r) } - connection.Close() + conn.close() }() for { - if connection.status == CLOSED { - break - } - // https://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages // // We need to ensure : // - no concurrent calls to ws.NextReader() / ws.ReadMessage() // - only one reader exists at a time // - wait for reader to be consumed before requesting the next one - // - always be reading on the socket to be able to process control messages ( ping / pong / close ) + // - always be reading on the socket to be able to process control messages ( ping / pong / closeNoLock ) // We will block here until a message is received or the ws is closed - _, reader, err := connection.ws.NextReader() + _, ioReader, err := conn.ws.NextReader() if err != nil { + if !conn.isClosed() { + log.Printf("WebSocket error : %s", err) + } break } - if connection.status != BUSY { + if conn.getStatus() != BUSY { // We received a wild unexpected message + log.Printf("Unexpected wild message received") break } - // We received a message from the proxy - // It is expected to be either a HttpResponse or a HttpResponseBody - // We wait for proxyRequest to send a channel to get the message - c := <-connection.nextResponse - if c == nil { - // We have been unlocked by Close() + select { + case f := <-conn.nextReader: + f(ioReader) + // Ensure we have consumed the all the ioReader + _, err = ioutil.ReadAll(ioReader) + if err != nil { + log.Printf("Unable to clean io reader") + break + } + case <-conn.done: break } - - // Send the reader back to proxyRequest - c <- reader - - // Wait for proxyRequest to close the channel - // this notify that it is done with the reader - <-c } } -// Proxy a HTTP request through the Proxy over the websocket connection -func (connection *Connection) proxyRequest(w http.ResponseWriter, r *http.Request) (err error) { - log.Printf("proxy request to %s", connection.pool.id) - +// Proxy a HTTP request through the Proxy over the WebSocket connection +func (conn *Connection) proxyRequest(w http.ResponseWriter, r *http.Request) (err error) { // Serialize HTTP request - jsonReq, err := json.Marshal(common.SerializeHTTPRequest(r)) + jsonReq, err := json.Marshal(common.NewHTTPRequest(r)) if err != nil { return fmt.Errorf("Unable to serialize request : %s", err) } // Send the serialized HTTP request to the remote Proxy - err = connection.ws.WriteMessage(websocket.TextMessage, jsonReq) + err = conn.ws.WriteMessage(websocket.TextMessage, jsonReq) if err != nil { return fmt.Errorf("Unable to write request : %s", err) } // Pipe the HTTP request body to the remote Proxy - bodyWriter, err := connection.ws.NextWriter(websocket.BinaryMessage) + bodyWriter, err := conn.ws.NextWriter(websocket.BinaryMessage) if err != nil { return fmt.Errorf("Unable to get request body writer : %s", err) } @@ -124,131 +161,109 @@ func (connection *Connection) proxyRequest(w http.ResponseWriter, r *http.Reques } err = bodyWriter.Close() if err != nil { - return fmt.Errorf("Unable to pipe request body (close) : %s", err) - } - - // Get the serialized HTTP Response from the remote Proxy - // To do so send a new channel to the read() goroutine - // to get the next message reader - responseChannel := make(chan (io.Reader)) - connection.nextResponse <- responseChannel - responseReader, more := <-responseChannel - if responseReader == nil { - if more { - // If more is false the channel is already closed - close(responseChannel) - } - return fmt.Errorf("Unable to get http response reader : %s", err) + return fmt.Errorf("Unable to pipe request body (closeNoLock) : %s", err) } - // Read the HTTP Response - jsonResponse, err := ioutil.ReadAll(responseReader) - if err != nil { - close(responseChannel) - return fmt.Errorf("Unable to read http response : %s", err) - } + err = conn.handleNextMessage(func(reader io.Reader) (err error) { + + // Deserialize the HTTP Response + httpResponse := new(common.HTTPResponse) + err = json.NewDecoder(reader).Decode(httpResponse) + if err != nil { + return fmt.Errorf("Unable to unserialize http response : %s", err) + } - // Notify the read() goroutine that we are done reading the response - close(responseChannel) + // Write response headers back to the client + for header, values := range httpResponse.Header { + for _, value := range values { + w.Header().Add(header, value) + } + } - // Deserialize the HTTP Response - httpResponse := new(common.HTTPResponse) - err = json.Unmarshal(jsonResponse, httpResponse) + w.WriteHeader(httpResponse.StatusCode) + + return nil + }) if err != nil { - return fmt.Errorf("Unable to unserialize http response : %s", err) + return fmt.Errorf("Unable to handle request : %s", err) } - // Write response headers back to the client - for header, values := range httpResponse.Header { - for _, value := range values { - w.Header().Add(header, value) - } - } - w.WriteHeader(httpResponse.StatusCode) - - // Get the HTTP Response body from the remote Proxy - // To do so send a new channel to the read() goroutine - // to get the next message reader - responseBodyChannel := make(chan (io.Reader)) - connection.nextResponse <- responseBodyChannel - responseBodyReader, more := <-responseBodyChannel - if responseBodyReader == nil { - if more { - // If more is false the channel is already closed - close(responseChannel) + err = conn.handleNextMessage(func(reader io.Reader) (err error) { + // Pipe the HTTP response body right from the remote Proxy to the client + _, err = io.Copy(w, reader) + if err != nil { + return fmt.Errorf("Unable to pipe response body : %s", err) } - return fmt.Errorf("Unable to get http response body reader : %s", err) - } - // Pipe the HTTP response body right from the remote Proxy to the client - _, err = io.Copy(w, responseBodyReader) + return nil + }) if err != nil { - close(responseBodyChannel) - return fmt.Errorf("Unable to pipe response body : %s", err) + return fmt.Errorf("Unable to handle request body : %s", err) } - // Notify read() that we are done reading the response body - close(responseBodyChannel) + // Put the connection back in the pool + conn.release() - connection.Release() - - return + return nil } // Take notifies that this connection is going to be used -func (connection *Connection) Take() bool { - connection.lock.Lock() - defer connection.lock.Unlock() +func (conn *Connection) take() bool { + conn.lock.Lock() + defer conn.lock.Unlock() - if connection.status == CLOSED { + if conn.isClosed() { return false } - if connection.status == BUSY { + if conn.status != IDLE { return false } - connection.status = BUSY + conn.status = BUSY + return true } // Release notifies that this connection is ready to use again -func (connection *Connection) Release() { - connection.lock.Lock() - defer connection.lock.Unlock() +func (conn *Connection) release() { + conn.lock.Lock() + defer conn.lock.Unlock() - if connection.status == CLOSED { + if conn.isClosed() { return } - connection.idleSince = time.Now() - connection.status = IDLE + conn.idleSince = time.Now() + conn.status = IDLE - go connection.pool.Offer(connection) + // Add the connection back to the pool + conn.releaser(conn) } -// Close the connection -func (connection *Connection) Close() { - connection.lock.Lock() - defer connection.lock.Unlock() - - connection.close() +// IsClosed return true if the connection has been closed +func (conn *Connection) isClosed() bool { + select { + case <-conn.done: + return true + default: + return false + } } -// Close the connection ( without lock ) -func (connection *Connection) close() { - if connection.status == CLOSED { +// Close the connection +func (conn *Connection) close() { + conn.lock.Lock() + defer conn.lock.Unlock() + + if conn.isClosed() { return } - log.Printf("Closing connection from %s", connection.pool.id) - - // This one will be executed *before* lock.Unlock() - defer func() { connection.status = CLOSED }() + conn.status = CLOSED - // Unlock a possible read() wild message - close(connection.nextResponse) + close(conn.done) - // Close the underlying TCP connection - connection.ws.Close() + // Close the underlying TCP conn + conn.ws.Close() } diff --git a/server/pool.go b/server/pool.go index c348d63..09fed47 100644 --- a/server/pool.go +++ b/server/pool.go @@ -6,102 +6,157 @@ import ( "time" "github.com/gorilla/websocket" + + "github.com/root-gg/wsp/common" ) // Pool handle all connections from a remote Proxy type Pool struct { server *Server - id string - size int + clientSettings *common.ClientSettings + + // This channel provides idle connection to the server + // The server must then call Take() to make sure it is + // still open and make it ready to use + idle chan *Connection - connections []*Connection - idle chan *Connection + // This map is only here to provide a way to display statistics + // about the connections in the pool + connections map[*Connection]struct{} + connectionsLock sync.Mutex - done bool + done chan struct{} lock sync.RWMutex } // NewPool creates a new Pool -func NewPool(server *Server, id string) (pool *Pool) { +func NewPool(server *Server, clientSettings *common.ClientSettings) (pool *Pool) { pool = new(Pool) pool.server = server - pool.id = id + pool.clientSettings = clientSettings pool.idle = make(chan *Connection) + pool.done = make(chan struct{}) + pool.connections = make(map[*Connection]struct{}) + return } // Register creates a new Connection and adds it to the pool -func (pool *Pool) Register(ws *websocket.Conn) { +func (pool *Pool) register(id uint64, ws *websocket.Conn) { pool.lock.Lock() defer pool.lock.Unlock() // Ensure we never add a connection to a pool we have garbage collected - if pool.done { + if pool.isClosed() { return } - log.Printf("Registering new connection from %s", pool.id) - connection := NewConnection(pool, ws) - pool.connections = append(pool.connections, connection) + log.Printf("Registering new connection %d from %s (%s)", id, pool.clientSettings.Name, pool.clientSettings.ID) + + connection := newConnection(id, ws, pool.offer) + + // Keep track of the connection to be able to display statistics + pool.connectionsLock.Lock() + pool.connections[connection] = struct{}{} + pool.connectionsLock.Unlock() + + // Automatically remove connection from the map on close + go func() { + <-connection.done + log.Printf("Connection %d from %s (%s) has been closed", id, pool.clientSettings.Name, pool.clientSettings.ID) + pool.connectionsLock.Lock() + delete(pool.connections, connection) + pool.connectionsLock.Unlock() + }() + + pool.offer(connection) return } // Offer an idle connection to the server -func (pool *Pool) Offer(connection *Connection) { - go func() { pool.idle <- connection }() +func (pool *Pool) offer(connection *Connection) { + go func() { + select { + case pool.idle <- connection: + case <-connection.done: + case <-pool.done: + } + }() } -// Clean removes dead connection from the pool -// Look for dead connection in the pool -// This MUST be surrounded by pool.lock.Lock() -func (pool *Pool) Clean() { - idle := 0 +// Clean tries to keep at most poolSize idle connection in the pool. +// Connections are left open for Config.IdleTimeout before being closed. +// Only the server is allowed to close connection to avoid the client +// to close a connection about to be used to proxy a request. +func (pool *Pool) clean() { + pool.lock.Lock() + defer pool.lock.Unlock() + var connections []*Connection - for _, connection := range pool.connections { - // We need to be sur we'll never close a BUSY or soon to be BUSY connection - connection.lock.Lock() - if connection.status == IDLE { - idle++ - if idle > pool.size { - // We have enough idle connections in the pool. - // Terminate the connection if it is idle since more that IdleTimeout - if int(time.Now().Sub(connection.idleSince).Seconds())*1000 > pool.server.Config.IdleTimeout { - connection.close() +LOOP: + for { + select { + case conn := <-pool.idle: + if conn.getStatus() != IDLE { + continue + } + if len(connections) < pool.clientSettings.PoolSize { + connections = append(connections, conn) + } else { + if time.Now().Sub(conn.idleSince) > pool.server.Config.IdleTimeout { + conn.close() + } else { + connections = append(connections, conn) } } + default: + break LOOP } - connection.lock.Unlock() - if connection.status == CLOSED { - continue - } - connections = append(connections, connection) } - pool.connections = connections -} -// IsEmpty clean the pool and return true if the pool is empty -func (pool *Pool) IsEmpty() bool { - pool.lock.Lock() - defer pool.lock.Unlock() + for _, conn := range connections { + pool.offer(conn) + } - pool.Clean() - return len(pool.connections) == 0 + return +} + +// isClosed returns true if the pool had been closed +func (pool *Pool) isClosed() bool { + select { + case <-pool.done: + return true + default: + return false + } } -// Shutdown closes every connections in the pool and cleans it -func (pool *Pool) Shutdown() { +// Close every connections in the pool and clean it +func (pool *Pool) close() { pool.lock.Lock() defer pool.lock.Unlock() - pool.done = true + if pool.isClosed() { + log.Println("pool alreadey closed") + return + } + + close(pool.done) - for _, connection := range pool.connections { - connection.Close() +LOOP: + for { + //log.Println("empty idle chan") + select { + case conn := <-pool.idle: + conn.close() + default: + break LOOP + } } - pool.Clean() + log.Println("pool closed") } // PoolSize is the number of connection in each state in the pool @@ -109,22 +164,25 @@ type PoolSize struct { Idle int Busy int Closed int + Total int } // Size return the number of connection in each state in the pool func (pool *Pool) Size() (ps *PoolSize) { - pool.lock.Lock() - defer pool.lock.Unlock() + pool.connectionsLock.Lock() + defer pool.connectionsLock.Unlock() ps = new(PoolSize) - for _, connection := range pool.connections { - if connection.status == IDLE { + for connection := range pool.connections { + status := connection.getStatus() + if status == IDLE { ps.Idle++ - } else if connection.status == BUSY { + } else if status == BUSY { ps.Busy++ - } else if connection.status == CLOSED { + } else if status == CLOSED { ps.Closed++ } + ps.Total++ } return diff --git a/server/server.go b/server/server.go index 9f3916b..c9d0f01 100644 --- a/server/server.go +++ b/server/server.go @@ -1,17 +1,18 @@ package server import ( + "fmt" "log" "math/rand" "net/http" "net/url" "reflect" "strconv" - "strings" "sync" "time" "github.com/gorilla/websocket" + "github.com/root-gg/wsp/common" ) @@ -21,31 +22,14 @@ import ( type Server struct { Config *Config - upgrader websocket.Upgrader + validator *common.RequestValidator + upgrader websocket.Upgrader + httpServer *http.Server pools []*Pool - lock sync.RWMutex - done chan struct{} - - dispatcher chan *ConnectionRequest - - server *http.Server -} -// ConnectionRequest is used to request a proxy connection from the dispatcher -type ConnectionRequest struct { - connection chan *Connection - timeout <-chan time.Time -} - -// NewConnectionRequest creates a new connection request -func NewConnectionRequest(timeout time.Duration) (cr *ConnectionRequest) { - cr = new(ConnectionRequest) - cr.connection = make(chan *Connection) - if timeout > 0 { - cr.timeout = time.After(timeout) - } - return + lock sync.RWMutex + done chan struct{} } // NewServer return a new Server instance @@ -54,21 +38,31 @@ func NewServer(config *Config) (server *Server) { server = new(Server) server.Config = config + + server.validator = &common.RequestValidator{ + Whitelist: config.Whitelist, + Blacklist: config.Blacklist, + } + err := server.validator.Initialize() + if err != nil { + log.Fatalf("Unable to initialize the request validator : %s", err) + } + server.upgrader = websocket.Upgrader{} server.done = make(chan struct{}) - server.dispatcher = make(chan *ConnectionRequest) return } // Start Server HTTP server func (server *Server) Start() { go func() { + ticker := time.NewTicker(5 * time.Second) for { select { case <-server.done: break - case <-time.After(5 * time.Second): + case <-ticker.C: server.clean() } } @@ -79,10 +73,8 @@ func (server *Server) Start() { r.HandleFunc("/register", server.register) r.HandleFunc("/status", server.status) - go server.dispatchConnections() - - server.server = &http.Server{Addr: server.Config.Host + ":" + strconv.Itoa(server.Config.Port), Handler: r} - go func() { log.Fatal(server.server.ListenAndServe()) }() + server.httpServer = &http.Server{Addr: server.Config.Host + ":" + strconv.Itoa(server.Config.Port), Handler: r} + go func() { log.Fatal(server.httpServer.ListenAndServe()) }() } // clean remove empty Pools @@ -99,16 +91,17 @@ func (server *Server) clean() { var pools []*Pool for _, pool := range server.pools { - if pool.IsEmpty() { - log.Printf("Removing empty connection pool : %s", pool.id) - pool.Shutdown() + pool.clean() + poolSize := pool.Size() + if poolSize.Total == 0 { + log.Printf("Removing empty connection pool : %s (%s)", pool.clientSettings.Name, pool.clientSettings.ID) + pool.close() } else { pools = append(pools, pool) } - ps := pool.Size() - idle += ps.Idle - busy += ps.Busy + idle += poolSize.Idle + busy += poolSize.Busy } log.Printf("%d pools, %d idle, %d busy", len(pools), idle, busy) @@ -116,137 +109,119 @@ func (server *Server) clean() { server.pools = pools } -// Dispatch connection from available pools to clients requests -func (server *Server) dispatchConnections() { - for { - // A client requests a connection - request, ok := <-server.dispatcher - if !ok { - // Shutdown - break - } - - for { - server.lock.RLock() - - if len(server.pools) == 0 { - // No connection pool available - server.lock.RUnlock() - break - } - - // Build a select statement dynamically - cases := make([]reflect.SelectCase, len(server.pools)+1) +// Get a timeout timer to get a connection +func (server *Server) getTimeout() <-chan time.Time { + if server.Config.Timeout > 0 { + return time.After(server.Config.Timeout) + } + return nil +} - // Add all pools idle connection channel - for i, ch := range server.pools { - cases[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(ch.idle)} - } +// Get a ws connection from one of the available pools +func (server *Server) getConnection() *Connection { + timeout := server.getTimeout() + for { + server.lock.RLock() + poolCount := len(server.pools) + server.lock.RUnlock() - // Add timeout channel - if request.timeout != nil { - cases[len(cases)-1] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(request.timeout)} - } else { - cases[len(cases)-1] = reflect.SelectCase{ - Dir: reflect.SelectDefault} + if poolCount == 0 { + // No connection pool available + select { + case <-timeout: + // a timeout occurred + return nil + default: + time.Sleep(10 * time.Millisecond) + continue } + } - server.lock.RUnlock() + // Build a select statement dynamically + // This allows to wait on multiple connection pools for the next idle connection + var cases []reflect.SelectCase + + // Add all pools idle connection channel + // range makes a copy so no need to lock + server.lock.RLock() + for _, ch := range server.pools { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ch.idle)}) + } + server.lock.RUnlock() - // This call blocks - chosen, value, ok := reflect.Select(cases) + // Add a timeout channel or default case + if timeout != nil { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(timeout)}) + } else { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectDefault}) + } - if chosen == len(cases)-1 { - // a timeout occured - break - } - if !ok { - // a proxy pool has been removed, try again - continue - } - connection, _ := value.Interface().(*Connection) + chosen, value, ok := reflect.Select(cases) - // Verify that we can use this connection - if connection.Take() { - request.connection <- connection - break - } + if chosen == len(cases)-1 { + // a timeout occurred + return nil } + if !ok { + // a proxy pool has been removed, try again + continue + } + connection, _ := value.Interface().(*Connection) - close(request.connection) + // Verify that we can use this connection + if connection.take() { + return connection + } } + + return nil } // This is the way for clients to execute HTTP requests through an Proxy -func (server *Server) request(w http.ResponseWriter, r *http.Request) { +func (server *Server) request(resp http.ResponseWriter, req *http.Request) { // Parse destination URL - dstURL := r.Header.Get("X-PROXY-DESTINATION") + dstURL := req.Header.Get("X-PROXY-DESTINATION") if dstURL == "" { - common.ProxyErrorf(w, "Missing X-PROXY-DESTINATION header") + common.ProxyErrorf(resp, "Missing X-PROXY-DESTINATION header") return } URL, err := url.Parse(dstURL) if err != nil { - common.ProxyErrorf(w, "Unable to parse X-PROXY-DESTINATION header") + common.ProxyErrorf(resp, "Unable to parse X-PROXY-DESTINATION header") return } - r.URL = URL + req.URL = URL - log.Printf("[%s] %s", r.Method, r.URL.String()) - - // Apply blacklist - if len(server.Config.Blacklist) > 0 { - for _, rule := range server.Config.Blacklist { - if rule.Match(r) { - common.ProxyErrorf(w, "Destination is forbidden") - return - } - } - } + log.Printf("[%s] %s", req.Method, req.URL.String()) - // Apply whitelist - if len(server.Config.Whitelist) > 0 { - allowed := false - for _, rule := range server.Config.Whitelist { - if rule.Match(r) { - allowed = true - break - } - } - if !allowed { - common.ProxyErrorf(w, "Destination is not allowed") - return - } - } - - if len(server.pools) == 0 { - common.ProxyErrorf(w, "No proxy available") + err = server.validator.Validate(req) + if err != nil { + common.ProxyErrorf(resp, fmt.Sprintf("Invalid request : %s", err.Error())) return } // Get a proxy connection - request := NewConnectionRequest(time.Duration(server.Config.Timeout) * time.Millisecond) - server.dispatcher <- request - connection := <-request.connection + connection := server.getConnection() if connection == nil { - common.ProxyErrorf(w, "Unable to get a proxy connection") + common.ProxyErrorf(resp, "Unable to get a proxy connection") return } // Send the request to the proxy - err = connection.proxyRequest(w, r) + err = connection.proxyRequest(resp, req) if err != nil { // An error occurred throw the connection away log.Println(err) - connection.Close() + connection.close() // Try to return an error to the client // This might fail if response headers have already been sent - common.ProxyError(w, err) + common.ProxyError(resp, err) } } @@ -267,17 +242,15 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) { // The first message should contains the remote Proxy name and size _, greeting, err := ws.ReadMessage() if err != nil { - common.ProxyErrorf(w, "Unable to read greeting message : %s", err) + common.ProxyErrorf(w, "Unable to read client settings : %s", err) ws.Close() return } // Parse the greeting message - split := strings.Split(string(greeting), "_") - id := split[0] - size, err := strconv.Atoi(split[1]) + clientSettings, err := common.ClientSettingsFromJson(greeting) if err != nil { - common.ProxyErrorf(w, "Unable to parse greeting message : %s", err) + common.ProxyErrorf(w, "Unable to parse client settings : %s", err) ws.Close() return } @@ -288,33 +261,44 @@ func (server *Server) register(w http.ResponseWriter, r *http.Request) { // Get that client's Pool var pool *Pool for _, p := range server.pools { - if p.id == id { + if p.clientSettings.ID == clientSettings.ID { pool = p break } } if pool == nil { - pool = NewPool(server, id) + pool = NewPool(server, clientSettings) server.pools = append(server.pools, pool) } - // update pool size - pool.size = size - // Add the ws to the pool - pool.Register(ws) + pool.register(clientSettings.ConnectionId, ws) } func (server *Server) status(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) } -// Shutdown stop the Server +func (server *Server) IsClosed() bool { + select { + case <-server.done: + return true + default: + return false + } +} + +// Close stop the WSP Server func (server *Server) Shutdown() { + if server.IsClosed() { + return + } close(server.done) - close(server.dispatcher) + + server.lock.RLock() for _, pool := range server.pools { - pool.Shutdown() + pool.close() } + server.lock.RUnlock() server.clean() } diff --git a/wsp_client/wsp_client.cfg b/wsp_client/wsp_client.cfg index 6875500..eb3848a 100644 --- a/wsp_client/wsp_client.cfg +++ b/wsp_client/wsp_client.cfg @@ -3,14 +3,15 @@ targets : # Endpoints to connect to - ws://127.0.0.1:8080/register # poolidlesize : 10 # Default number of concurrent open (TCP) connections to keep idle per WSP server poolmaxsize : 100 # Maximum number of concurrent open (TCP) connections per WSP server +#insecureSkipVerify : true # Disable the http client certificate chain and hostname verification #blacklist : # Forbidden destination ( deny nothing if empty ) # - method : ".*" # Applied in order before whitelist # url : ".*forbidden.*" # None must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # +# X-CUSTOM-HEADER : "^value$" # #whitelist : # Allowed destinations ( allow all if empty ) # - method : "^GET$" # Applied in order after blacklist # url : "http(s)?://.*$" # One must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # +# X-CUSTOM-HEADER : "^value$" # # secretkey : ThisIsASecret # secret key that must match the value set in servers configuration diff --git a/wsp_client/wsp_client.go b/wsp_client/wsp_client.go index 0586e80..1edf42a 100644 --- a/wsp_client/wsp_client.go +++ b/wsp_client/wsp_client.go @@ -7,7 +7,6 @@ import ( "os/signal" "github.com/root-gg/utils" - "github.com/root-gg/wsp/client" ) diff --git a/wsp_server/wsp_server.cfg b/wsp_server/wsp_server.cfg index 6e50a41..1ad8d07 100644 --- a/wsp_server/wsp_server.cfg +++ b/wsp_server/wsp_server.cfg @@ -1,16 +1,16 @@ --- host : 127.0.0.1 # Address to bind the HTTP server port : 8080 # Port to bind the HTTP server -timeout : 1000 # Time to wait before acquiring a WS connection to forward the request (milliseconds) -idletimeout : 60000 # Time to wait before closing idle connection when there is enough idle connections (milliseconds) +timeout : 1s # Time to wait before acquiring a WS connection to forward the request (milliseconds) +idletimeout : 60s # Time to wait before closing idle connection when there is enough idle connections (milliseconds) #blacklist : # Forbidden destination ( deny nothing if empty ) # - method : ".*" # Applied in order before whitelist # url : "^http(s)?://google.*" # None must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # +# X-CUSTOM-HEADER : "^value$" # #whitelist : # Allowed destinations ( allow all if empty ) # - method : "^GET$" # Applied in order after blacklist # url : "^http(s)?://.*$" # One must match # headers : # Optinal header check -# X-CUSTOM-HEADER : "^value$" # -# secretkey : ThisIsASecret # secret key that must be set in clients configuration +# X-CUSTOM-HEADER : "^value$" # +#secretkey : ThisIsASecret # secret key that must be set in clients configuration diff --git a/wsp_server/wsp_server.go b/wsp_server/wsp_server.go index 3a8a95c..64bc7a4 100644 --- a/wsp_server/wsp_server.go +++ b/wsp_server/wsp_server.go @@ -7,11 +7,11 @@ import ( "os/signal" "github.com/root-gg/utils" - "github.com/root-gg/wsp/server" ) func main() { + configFile := flag.String("config", "wsp_server.cfg", "config file path") flag.Parse()