diff --git a/client.go b/client.go index d4d5a537..160b1c87 100644 --- a/client.go +++ b/client.go @@ -33,6 +33,9 @@ var ( // HTTPClient. The timeout includes connection time, any redirects, // and reading the response body. HTTPClientTimeout = 60 * time.Second + + InternalPoolSize = 10 + PingInverval = 20 * time.Second ) // Client represents a connection with the APNs @@ -53,7 +56,7 @@ type Client struct { // // If your use case involves multiple long-lived connections, consider using // the ClientManager, which manages clients for you. -func NewClient(certificate tls.Certificate) *Client { +func NewClient(certificate tls.Certificate, environment string) *Client { tlsConfig := &tls.Config{ Certificates: []tls.Certificate{certificate}, } @@ -66,6 +69,14 @@ func NewClient(certificate tls.Certificate) *Client { return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg) }, } + poolMan, err := newPoolManager(transport, environment) + if err != nil { + fmt.Println(err) + } else { + for i := 0; i < InternalPoolSize; i++ { + poolMan.addNewConn() + } + } return &Client{ HTTPClient: &http.Client{ Transport: transport, diff --git a/client_manager.go b/client_manager.go index bb4bdf90..f8ffd352 100644 --- a/client_manager.go +++ b/client_manager.go @@ -29,7 +29,7 @@ type ClientManager struct { // Factory is the function which constructs clients if not found in the // manager. - Factory func(certificate tls.Certificate) *Client + Factory func(certificate tls.Certificate, environment string) *Client cache map[[sha1.Size]byte]*list.Element ll *list.List @@ -87,7 +87,7 @@ func (m *ClientManager) Add(client *Client) { // or if a Client has remained in the manager longer than MaxAge, Get will call // the ClientManager's Factory function, store the result in the manager if // non-nil, and return it. -func (m *ClientManager) Get(certificate tls.Certificate) *Client { +func (m *ClientManager) Get(certificate tls.Certificate, environment string) *Client { m.initInternals() m.mu.Lock() defer m.mu.Unlock() @@ -97,7 +97,7 @@ func (m *ClientManager) Get(certificate tls.Certificate) *Client { if ele, hit := m.cache[key]; hit { item := ele.Value.(*managerItem) if m.MaxAge != 0 && item.lastUsed.Before(now.Add(-m.MaxAge)) { - c := m.Factory(certificate) + c := m.Factory(certificate, environment) if c == nil { return nil } @@ -108,7 +108,7 @@ func (m *ClientManager) Get(certificate tls.Certificate) *Client { return item.client } - c := m.Factory(certificate) + c := m.Factory(certificate, environment) if c == nil { return nil } diff --git a/pool_manager.go b/pool_manager.go new file mode 100644 index 00000000..4246d208 --- /dev/null +++ b/pool_manager.go @@ -0,0 +1,104 @@ +package apns2 + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "reflect" + "strings" + "sync" + "time" + "unsafe" + + "golang.org/x/net/http2" + "golang.org/x/net/idna" +) + +type poolManager struct { + connPool http2.ClientConnPool + ctx context.Context + poolMu *sync.Mutex + u *url.URL +} + +// directly copied from source +func authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + return net.JoinHostPort(host, port) +} + +func newPoolManager(transport *http2.Transport, environment string) (*poolManager, error) { + transport.CloseIdleConnections() + rf := reflect.Indirect(reflect.ValueOf(transport)).FieldByName("connPoolOrDef") + connPool := *(*http2.ClientConnPool)(unsafe.Pointer(rf.UnsafeAddr())) + rf = reflect.Indirect(reflect.ValueOf(connPool)).FieldByName("mu") + poolMu := (*sync.Mutex)(unsafe.Pointer(rf.UnsafeAddr())) + u, err := url.Parse(environment) + if err != nil { + return nil, err + } + return &poolManager{ + connPool: connPool, + u: u, + poolMu: poolMu, + ctx: context.Background(), + }, nil +} + +func (pm *poolManager) addNewConn() error { + rff := reflect.Indirect(reflect.ValueOf(pm.connPool)).FieldByName("conns") + pm.poolMu.Lock() + internalConns := *(*map[string][]*http2.ClientConn)(unsafe.Pointer(rff.UnsafeAddr())) + for _, conns := range internalConns { + for _, conn := range conns { + rv := reflect.Indirect(reflect.ValueOf(conn)) + rf := rv.FieldByName("mu") + mu := (*sync.Mutex)(unsafe.Pointer(rf.UnsafeAddr())) + rf = rv.FieldByName("closed") + mu.Lock() + rf = reflect.Indirect(reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr()))) + rf.SetBool(true) + mu.Unlock() + defer func() { + mu.Lock() + rf.SetBool(false) + mu.Unlock() + }() + } + } + pm.poolMu.Unlock() + cc, err := pm.connPool.GetClientConn(&http.Request{Close: false}, authorityAddr(pm.u.Scheme, pm.u.Host)) + if err != nil { + return err + } + go pm.pingConn(cc) + return nil +} + +func (pm *poolManager) pingConn(cc *http2.ClientConn) { + for { + err := cc.Ping(pm.ctx) + if err != nil { + if err = pm.addNewConn(); err != nil { + fmt.Println(err) + } + return + } + time.Sleep(PingInverval) + } +}