diff --git a/app/server.go b/app/server.go index dd448df..f77ec27 100644 --- a/app/server.go +++ b/app/server.go @@ -13,52 +13,54 @@ import ( ) type Server struct { - apps IServices + services IServices defaultDB *sql.DB + host string + port string + openConns chan<- net.Conn } -func NewServer() *Server { - // Connect to the default db - defaultDB, err := sql.Open("sqlite3", config.DEFAULT_DB) - if err != nil { - log.Panicln("Could not connect DB:", err) - } - +func NewServer(connCh chan<- net.Conn, defaultDB *sql.DB) *Server { // initialize and start running Services - apps := InitServices(defaultDB) - apps.ServeServices() + services := InitServices(defaultDB) return &Server{ - apps: apps, + services: services, defaultDB: defaultDB, + openConns: connCh, } } func (s *Server) process(conn net.Conn) { // Get the servers - addrs := s.apps.GetServiceServers(conn.RemoteAddr().String()) + addrs := s.services.GetServiceServers(conn.RemoteAddr().String()) + // log.Println(conn.RemoteAddr().String()) if len(addrs) <= 0 { // If the remote address is unknown, redirect to the welcome server - addrs = s.apps.GetServiceServers("[::1]:80") + addrs = s.services.GetServiceServers(fmt.Sprintf("%s:%s", s.host, s.port)) } - // Get next server from load balancer + // TODO: + // The idea for the load balancer isn't fully formed yet. + // For now, it's always going to select the first server for every connection lb := tools.NewLoadBalancer(addrs) addr := lb.GetNextServer() // Connect to the available server - localConn, err := s.apps.ConnectToServer(addr) + localConn, err := s.services.ConnectToServer(addr) if err != nil { log.Println(err) return } + // Add local conn to open connections channel + s.openConns <- localConn + // Establish a point-to-point connection between conoid server and app's local server go func() { for { _, err = io.Copy(localConn, conn) if err != nil { - log.Println("Failed to read from remote connection:", err) break } } @@ -68,7 +70,6 @@ func (s *Server) process(conn net.Conn) { for { _, err = io.Copy(conn, localConn) if err != nil { - log.Println("Failed to write to remote connection:", err) break } } @@ -77,21 +78,42 @@ func (s *Server) process(conn net.Conn) { func (s *Server) Serve() { // Start the server and wait for connections - listener, err := net.Listen("tcp", fmt.Sprintf("[::]:%d", config.TCP_PORT)) + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", config.TCP_PORT)) + if err != nil { + log.Println(err) + return + } + host, port, err := net.SplitHostPort(listener.Addr().String()) if err != nil { log.Println(err) return } - log.Printf("Conoid started and listening on port %d\n", config.TCP_PORT) + s.host = host + s.port = port + log.Printf("Conoid listening on host: %s, port %s\n", host, port) + + // Start running services + s.services.ServeServices(host, port, s.openConns) + + // Record connections to ensure it doesn't exceed the max size + connsCh := make(chan int, config.MAX_CONN_COUNT) for { + // Block if connections count it full + connsCh <- 1 + // Establish a point-to-point connection between the client and server conn, err := listener.Accept() if err != nil { log.Println("Connection failed:", err) + // Remove record + <-connsCh continue } + // Add to open connections + s.openConns <- conn + // Handle connection in a new goroutine go s.process(conn) } diff --git a/app/services.go b/app/services.go index a69cf45..d63926f 100644 --- a/app/services.go +++ b/app/services.go @@ -7,7 +7,7 @@ import ( "net" "net/http" - // "github.com/DeeStarks/conoid/app/tools" + "github.com/DeeStarks/conoid/app/tools" port "github.com/DeeStarks/conoid/domain/ports" ) @@ -23,11 +23,11 @@ type ( } IServices interface { - ServeServices() // Retrieve all Services and serve - GetRunningServices() RunningServices // Get all running Services - GetServiceServers(string) []string // Get all servers' address that a service runs on - ConnectToServer(string) (net.Conn, error) // Connect to a service running locally - ServeStatic(string) int // Serve static Services, and return their port numbers + ServeServices(string, string, chan<- net.Conn) // Retrieve all Services and serve + GetRunningServices() RunningServices // Get all running Services + GetServiceServers(string) []string // Get all servers' address that a service runs on + ConnectToServer(string) (net.Conn, error) // Connect to a service running locally + ServeStatic(string) (string, string) // Serve static Services, and return their port numbers } ) @@ -40,28 +40,28 @@ func InitServices(defaultDB *sql.DB) IServices { } // Retrieve all Services and serve -func (s *Services) ServeServices() { +func (s *Services) ServeServices(conoidHost, conoidPort string, connCh chan<- net.Conn) { // Serve the welcome page - welcomePort := s.ServeStatic("./assets/welcome/") - // The welcome page will be served by default on port 80 - s.running["[::1]:80"] = []string{fmt.Sprintf("[::1]:%d", welcomePort)} - s.running["localhost:80"] = []string{fmt.Sprintf("[::1]:%d", welcomePort)} - s.running["localhost"] = []string{fmt.Sprintf("[::1]:%d", welcomePort)} - s.running["127.0.0.1:80"] = []string{fmt.Sprintf("[::1]:%d", welcomePort)} - s.running["127.0.0.1"] = []string{fmt.Sprintf("[::1]:%d", welcomePort)} - - // Serve + welcomeHost, welcomePort := s.ServeStatic("./assets/welcome/") + // Set the welcome page as the werver's default page + s.running[fmt.Sprintf("%s:%s", conoidHost, conoidPort)] = []string{fmt.Sprintf("%s:%s", welcomeHost, welcomePort)} + + // Serve registered running services dbPort := port.NewDomainPort(s.defaultDB) services, err := dbPort.ServiceProcesses().RetrieveRunning() if err != nil { - log.Println("Could not serve apps:", err) + log.Println("Could not serve:", err) + return } for _, service := range services { + // Addresses the service is running on + var serverAddrs []string + if service.Type == "static" { // Serve static - portNo := s.ServeStatic(service.RootDirectory) - addr := fmt.Sprintf("127.0.0.1:%d", portNo) + host, port := s.ServeStatic(service.RootDirectory) + addr := fmt.Sprintf("%s:%s", host, port) _, err := dbPort.ServiceProcesses().Update(service.Name, map[string]interface{}{ "listeners": addr, }) @@ -69,23 +69,49 @@ func (s *Services) ServeServices() { log.Println("Could not update service state:", err) } s.running[service.RemoteServer] = []string{addr} - } else if service.Type == "server" { + serverAddrs = []string{addr} + } else { servers := []string{} // Connect to all listening servers for _, addr := range service.Listeners { _, err := s.ConnectToServer(addr) if err != nil { - log.Printf("Could not connect to server address: %s; Error: %v\n", addr, err) + log.Printf("Could not connect to: %s; Stopping...\n", addr) + // Update service state + dbPort.ServiceProcesses().Update(service.Name, map[string]interface{}{ + "status": 0, + }) continue } // Append servers to listening servers servers = append(servers, addr) } s.running[service.RemoteServer] = servers - + serverAddrs = servers } // Tunnelling + if service.Tunnelled { + tunnel := tools.NewTunnel(service.Name, connCh) + host, err := tunnel.AllocateHost() + if err != nil { + log.Println("Error allocating tunnel remote host. Ensure your device is connected to the internet") + continue + } + + for i := 0; i < host.MaxConnectionCount(); i++ { + go host.OpenTunnel(fmt.Sprintf("%s:%s", conoidHost, conoidPort), serverAddrs) + } + + // Update service's remote_server + _, err = dbPort.ServiceProcesses().Update(service.Name, map[string]interface{}{ + "remote_server": host.FullURL(), + }) + if err != nil { + log.Println("Error updating tunnel state:", err) + continue + } + } } } @@ -109,23 +135,25 @@ func (s *Services) ConnectToServer(addr string) (net.Conn, error) { } // Serve static Services, and return their port numbers -func (s *Services) ServeStatic(dir string) int { +func (s *Services) ServeStatic(dir string) (string, string) { fs := http.FileServer(http.Dir(dir)) mux := http.NewServeMux() mux.Handle("/", fs) // Get and listen on the next port number - portNo := s.nextPN + host := "0.0.0.0" + port := s.nextPN for { // Dial the port number to see if it's available - _, err := net.Dial("tcp", fmt.Sprintf("[::]:%d", portNo)) + _, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) if err != nil { - go http.ListenAndServe(fmt.Sprintf(":%d", portNo), mux) + // If it's not in use, serve + go http.ListenAndServe(fmt.Sprintf("%s:%d", host, port), mux) break } // If it's already in use, try the next port - portNo++ + port++ } - s.nextPN = portNo - return portNo + s.nextPN = port + return host, fmt.Sprintf("%d", port) } diff --git a/app/tools/tunnel.go b/app/tools/tunnel.go new file mode 100644 index 0000000..c46865a --- /dev/null +++ b/app/tools/tunnel.go @@ -0,0 +1,142 @@ +package tools + +import ( + "encoding/json" + "fmt" + "io" + + "io/ioutil" + "log" + "net" + "net/http" + "net/url" +) + +type ( + ITunnel interface { + AllocateHost() (ITunnelHost, error) + } + + tunnel struct { + name string + tunnelServer string + openConns chan<- net.Conn + } + + ITunnelHost interface { + OpenTunnel(string, []string) + SubDomain() string + FullURL() string + MaxConnectionCount() int + PortNumber() int + } + + allocatedHost struct { + Id string `json:"id"` + Port int `json:"port"` + MaxConnCount int `json:"max_conn_count"` + Url string `json:"url"` + openConns chan<- net.Conn + } +) + +// Create a new tunnel +func NewTunnel(name string, connCh chan<- net.Conn) ITunnel { + return &tunnel{ + name: name, + tunnelServer: "http://localtunnel.me/", + openConns: connCh, + } +} + +// Creates a remote host for the service to be tunnelled +func (t *tunnel) AllocateHost() (ITunnelHost, error) { + var host allocatedHost + host.openConns = t.openConns + + // Names are required to be at least 4 in length + subdomain := t.name + if len(t.name) < 4 { + suffix := []byte("1234") + subdomain += string(suffix[:4-len(t.name)]) + } + + res, err := http.Get(t.tunnelServer + subdomain) + if err != nil { + return nil, err + } + defer res.Body.Close() + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, err + } + + if err = json.Unmarshal(body, &host); err != nil { + return nil, err + } + return &host, nil +} + +func (h *allocatedHost) OpenTunnel(conoidServer string, serviceServers []string) { + // Connect to server + localConn, err := net.Dial("tcp", conoidServer) + if err != nil { + log.Println("Error occured while tunneling:", err) + return + } + // Add to open connections + h.openConns <- localConn + + // Connect to remote host + // Parse url + pUrl, err := url.Parse(h.Url) + if err != nil { + log.Println(err) + return + } + remoteConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", pUrl.Host, h.Port)) + if err != nil { + log.Println("Error occured while", err) + return + } + // Add to open connections + h.openConns <- remoteConn + + // log.Println(localConn.LocalAddr().String(), "===>", serviceServers) + + // Establish a point-to-point connection between remote server and the local server + go func() { + for { + _, err = io.Copy(localConn, remoteConn) + if err != nil { + break + } + } + }() + + go func() { + for { + _, err = io.Copy(remoteConn, localConn) + if err != nil { + break + } + } + }() +} + +func (h *allocatedHost) SubDomain() string { + return h.Id +} + +func (h *allocatedHost) FullURL() string { + return h.Url +} + +func (h *allocatedHost) MaxConnectionCount() int { + return h.MaxConnCount +} + +func (h *allocatedHost) PortNumber() int { + return h.Port +} diff --git a/app/tools/tunnel_test.go b/app/tools/tunnel_test.go new file mode 100644 index 0000000..d270641 --- /dev/null +++ b/app/tools/tunnel_test.go @@ -0,0 +1,65 @@ +package tools_test + +import ( + "log" + "net" + "testing" + "time" + + "github.com/DeeStarks/conoid/app/tools" +) + +func TestTunnel(t *testing.T) { + tests := []struct { + name, svr string + }{ + {name: "abc", svr: ":30001"}, + {name: "cd", svr: ":30002"}, + {name: "agha", svr: ":30003"}, + {name: "abasdjjc", svr: ":30004"}, + {name: "a", svr: ":30005"}, + } + openConns := make(chan net.Conn, len(tests)*2) + + // Create main server + mSrv, err := net.Listen("tcp", "30000") + if err != nil { + log.Println(err) + return + } + + for _, tc := range tests { + tunnel := tools.NewTunnel(tc.name, openConns) + h, err := tunnel.AllocateHost() + if err != nil { + t.Error("Error allocating host:", err) + } + + // Open tunnel + svr, err := net.Listen("tcp", tc.svr) + if err != nil { + log.Println(err) + break + } + h.OpenTunnel(mSrv.Addr().String(), []string{svr.Addr().String()}) + } + + a := true + if a { + t.Error("All conns connected") + } + +L: + for { + select { + case conn := <-openConns: + err := conn.Close() + if err != nil { + t.Error("Error closing connection", err) + } + log.Printf("%s stopping", conn.LocalAddr().String()) + case <-time.After(time.Second * 5): + break L + } + } +} diff --git a/config/config.go b/config/config.go index e4f8386..4192afe 100644 --- a/config/config.go +++ b/config/config.go @@ -10,7 +10,8 @@ const ( VERSION = "0.0.1" // Network - TCP_PORT = 80 + TCP_PORT = 80 + MAX_CONN_COUNT = 100 ) var ( diff --git a/domain/schemas/default.go b/domain/schemas/default.go index 40a6201..91d31c0 100644 --- a/domain/schemas/default.go +++ b/domain/schemas/default.go @@ -12,4 +12,4 @@ var DefaultScript = ` tunnelled INTEGER, created_at NUMERIC NOT NULL ) -` \ No newline at end of file +`