Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extract endlessServer to support TCP #10

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 136 additions & 35 deletions endless.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/tls"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -60,8 +61,18 @@ func init() {
DefaultHammerTime = 60 * time.Second
}

type Server interface {
Serve(l net.Listener) error
}

type CloseFunc func() error

func (fn CloseFunc) Close() error { return fn() }

type endlessServer struct {
http.Server
Server
io.Closer
Addr string
EndlessListener net.Listener
SignalHooks map[int]map[os.Signal][]func()
tlsInnerListener *endlessListener
Expand All @@ -71,11 +82,7 @@ type endlessServer struct {
state uint8
}

/*
NewServer returns an intialized endlessServer Object. Calling Serve on it will
actually "start" the server.
*/
func NewServer(addr string, handler http.Handler) (srv *endlessServer) {
func newEndlessServer(addr string, server Server, closer io.Closer) (srv *endlessServer) {
runningServerReg.Lock()
defer runningServerReg.Unlock()
if !flag.Parsed() {
Expand All @@ -90,6 +97,9 @@ func NewServer(addr string, handler http.Handler) (srv *endlessServer) {
}

srv = &endlessServer{
Server: server,
Closer: closer,
Addr: addr,
wg: sync.WaitGroup{},
sigChan: make(chan os.Signal),
isChild: isChild,
Expand All @@ -114,18 +124,33 @@ func NewServer(addr string, handler http.Handler) (srv *endlessServer) {
state: STATE_INIT,
}

srv.Server.Addr = addr
srv.Server.ReadTimeout = DefaultReadTimeOut
srv.Server.WriteTimeout = DefaultWriteTimeOut
srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
srv.Server.Handler = handler

runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv

return
}

/*
NewServer returns an intialized endlessServer Object. Calling Serve on it will
actually "start" the server.
*/
func NewServer(addr string, handler http.Handler) *endlessServer {
srv := &http.Server{
Addr: addr,
ReadTimeout: DefaultReadTimeOut,
WriteTimeout: DefaultWriteTimeOut,
MaxHeaderBytes: DefaultMaxHeaderBytes,
Handler: handler,
}

return newEndlessServer(addr, srv, CloseFunc(func() error {
// disable keep-alives on existing connections
srv.SetKeepAlivesEnabled(false)

return nil
}))
}

/*
ListenAndServe listens on the TCP network address addr and then calls Serve
with handler to handle requests on incoming connections. Handler is typically
Expand All @@ -148,6 +173,79 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler htt
return server.ListenAndServeTLS(certFile, keyFile)
}

type Handler interface {
Serve(net.Conn)
}

type HandleFunc func(net.Conn)

func (fn HandleFunc) Serve(conn net.Conn) { fn(conn) }

func ListenAndServeTCP(addr string, handler Handler) error {
server := NewTcpServer(addr, handler)
return server.ListenAndServe()
}

type tcpServer struct {
handler Handler
}

func NewTcpServer(addr string, handler Handler) *endlessServer {
return newEndlessServer(addr, &tcpServer{handler}, CloseFunc(func() error { return nil }))
}

func (srv *tcpServer) Serve(l net.Listener) error {
defer l.Close()
var tempDelay time.Duration // how long to sleep on accept failure
for {
rw, e := l.Accept()
if e != nil {
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
log.Printf("%d %s tcp: Accept error: %v; retrying in %v", os.Getpid(), l.(*endlessListener).Addr(), e, tempDelay)
time.Sleep(tempDelay)
continue
}
return e
}
tempDelay = 0

go srv.handler.Serve(rw)
}
}

type tlsServer struct {
tcpServer

TLSConfig *tls.Config
}

func NewTlsServer(addr string, handler Handler, config *tls.Config) *endlessServer {
return newEndlessServer(addr, &tlsServer{tcpServer{handler}, config}, CloseFunc(func() error { return nil }))
}

func (srv *tlsServer) Serve(l net.Listener) error {
return srv.tcpServer.Serve(tls.NewListener(l, srv.TLSConfig))
}

func (srv *endlessServer) TLSConfig() *tls.Config {
switch t := srv.Server.(type) {
case *http.Server:
return t.TLSConfig
case *tlsServer:
return t.TLSConfig
}

return nil
}

/*
Serve accepts incoming HTTP connections on the listener l, creating a new
service goroutine for each. The service goroutines read requests and then call
Expand All @@ -159,18 +257,18 @@ sync.Waitgroup so that all outstanding connections can be served before shutting
down the server.
*/
func (srv *endlessServer) Serve() (err error) {
defer log.Println(syscall.Getpid(), "Serve() returning...")
defer log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Serve() returning...")
srv.state = STATE_RUNNING
err = srv.Server.Serve(srv.EndlessListener)
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Waiting for connections to finish...")
srv.wg.Wait()
srv.state = STATE_TERMINATE
return
}

/*
ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
to handle requests on incoming connections. If srv.Addr is blank, ":http" is
ListenAndServe listens on the TCP network address srv.EndlessListener.Addr() and then calls Serve
to handle requests on incoming connections. If srv.EndlessListener.Addr() is blank, ":http" is
used.
*/
func (srv *endlessServer) ListenAndServe() (err error) {
Expand All @@ -193,20 +291,20 @@ func (srv *endlessServer) ListenAndServe() (err error) {
syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
}

log.Println(syscall.Getpid(), srv.Addr)
log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "ListenAndServe")
return srv.Serve()
}

/*
ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
ListenAndServeTLS listens on the TCP network address srv.EndlessListener.Addr() and then calls
Serve to handle requests on incoming TLS connections.

Filenames containing a certificate and matching private key for the server must
be provided. If the certificate is signed by a certificate authority, the
certFile should be the concatenation of the server's certificate followed by the
CA's certificate.

If srv.Addr is blank, ":https" is used.
If srv.EndlessListener.Addr() is blank, ":https" is used.
*/
func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
addr := srv.Addr
Expand All @@ -215,9 +313,11 @@ func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error
}

config := &tls.Config{}
if srv.TLSConfig != nil {
*config = *srv.TLSConfig

if srv.TLSConfig() != nil {
*config = *srv.TLSConfig()
}

if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1"}
}
Expand All @@ -243,7 +343,7 @@ func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error
syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
}

log.Println(syscall.Getpid(), srv.Addr)
log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "ListenAndServeTLS")
return srv.Serve()
}

Expand Down Expand Up @@ -298,26 +398,26 @@ func (srv *endlessServer) handleSignals() {
srv.signalHooks(PRE_SIGNAL, sig)
switch sig {
case syscall.SIGHUP:
log.Println(pid, "Received SIGHUP. forking.")
log.Println(pid, srv.EndlessListener.Addr(), "Received SIGHUP. forking.")
err := srv.fork()
if err != nil {
log.Println("Fork err:", err)
}
case syscall.SIGUSR1:
log.Println(pid, "Received SIGUSR1.")
log.Println(pid, srv.EndlessListener.Addr(), "Received SIGUSR1.")
case syscall.SIGUSR2:
log.Println(pid, "Received SIGUSR2.")
log.Println(pid, srv.EndlessListener.Addr(), "Received SIGUSR2.")
srv.hammerTime(0 * time.Second)
case syscall.SIGINT:
log.Println(pid, "Received SIGINT.")
log.Println(pid, srv.EndlessListener.Addr(), "Received SIGINT.")
srv.shutdown()
case syscall.SIGTERM:
log.Println(pid, "Received SIGTERM.")
log.Println(pid, srv.EndlessListener.Addr(), "Received SIGTERM.")
srv.shutdown()
case syscall.SIGTSTP:
log.Println(pid, "Received SIGTSTP.")
log.Println(pid, srv.EndlessListener.Addr(), "Received SIGTSTP.")
default:
log.Printf("Received %v: nothing i care about...\n", sig)
log.Printf("%d %s Received %v: nothing i care about...\n", pid, srv.EndlessListener.Addr(), sig)
}
srv.signalHooks(POST_SIGNAL, sig)
}
Expand Down Expand Up @@ -347,11 +447,12 @@ func (srv *endlessServer) shutdown() {
if DefaultHammerTime >= 0 {
go srv.hammerTime(DefaultHammerTime)
}
// disable keep-alives on existing connections
srv.SetKeepAlivesEnabled(false)

srv.Close()

err := srv.EndlessListener.Close()
if err != nil {
log.Println(syscall.Getpid(), "Listener.Close() error:", err)
log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Listener.Close() error:", err)
} else {
log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Listener closed.")
}
Expand Down Expand Up @@ -405,12 +506,12 @@ func (srv *endlessServer) fork() (err error) {
switch srvPtr.EndlessListener.(type) {
case *endlessListener:
// normal listener
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.EndlessListener.(*endlessListener).File()
files[socketPtrOffsetMap[srvPtr.Addr]] = srvPtr.EndlessListener.(*endlessListener).File()
default:
// tls listener
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
files[socketPtrOffsetMap[srvPtr.Addr]] = srvPtr.tlsInnerListener.File()
}
orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
orderArgs[socketPtrOffsetMap[srvPtr.Addr]] = srvPtr.Addr
}

// log.Println(files)
Expand Down
67 changes: 67 additions & 0 deletions examples/echo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package main

import (
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"sync"

"github.com/flier/endless"
"github.com/gorilla/mux"
)

func handler(w http.ResponseWriter, r *http.Request) {
buf, _ := ioutil.ReadAll(r.Body)

w.Write(buf)
}

func main() {
var wg sync.WaitGroup

wg.Add(2)

go func() {
defer wg.Done()

endless.ListenAndServeTCP("localhost:8007", endless.HandleFunc(func(conn net.Conn) {
defer conn.Close()

var buf [4096]byte

for {
if n, err := conn.Read(buf[:]); err != nil {
if err != io.EOF {
log.Printf("error, %s", err)
}

break
} else if _, err := conn.Write(buf[:n]); err != nil {
log.Printf("error, %s", err)

break
}
}
}))
}()

go func() {
defer wg.Done()

mux1 := mux.NewRouter()
mux1.HandleFunc("/", handler).Methods("POST")

if err := endless.ListenAndServe("localhost:8008", mux1); err != nil {
log.Println(err)
} else {
log.Println("Server on 8007 stopped")
}
}()

wg.Wait()

os.Exit(0)
}