Skip to content

Commit

Permalink
Merge pull request #360 from ngrok/nikolay/async-session
Browse files Browse the repository at this point in the history
support for async session creation
  • Loading branch information
nikolay-ngrok authored Apr 17, 2024
2 parents 2ada7fe + 494ab82 commit b45e7d1
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 17 deletions.
18 changes: 11 additions & 7 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"flag"
"fmt"
"net/http"
"net/url"
"os"
"strings"
Expand All @@ -37,7 +38,6 @@ import (
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/healthz"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
"sigs.k8s.io/controller-runtime/pkg/manager"
"sigs.k8s.io/controller-runtime/pkg/metrics/server"
Expand Down Expand Up @@ -215,8 +215,8 @@ func runController(ctx context.Context, opts managerOpts) error {
}
}

td, err := tunneldriver.New(
ctrl.Log.WithName("drivers").WithName("tunnel"), tunneldriver.TunnelDriverOpts{
td, err := tunneldriver.New(ctx, ctrl.Log.WithName("drivers").WithName("tunnel"),
tunneldriver.TunnelDriverOpts{
ServerAddr: opts.serverAddr,
Region: opts.region,
},
Expand Down Expand Up @@ -312,12 +312,16 @@ func runController(ctx context.Context, opts managerOpts) error {
}
//+kubebuilder:scaffold:builder

if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil {
return fmt.Errorf("error setting up health check: %w", err)
}
if err := mgr.AddReadyzCheck("readyz", healthz.Ping); err != nil {
if err := mgr.AddReadyzCheck("readyz", func(req *http.Request) error {
return td.Ready()
}); err != nil {
return fmt.Errorf("error setting up readyz check: %w", err)
}
if err := mgr.AddHealthzCheck("healthz", func(req *http.Request) error {
return td.Healthy()
}); err != nil {
return fmt.Errorf("error setting up health check: %w", err)
}

setupLog.Info("starting manager")
if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
Expand Down
110 changes: 100 additions & 10 deletions pkg/tunneldriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import (
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync/atomic"

"github.com/go-logr/logr"
ingressv1alpha1 "github.com/ngrok/kubernetes-ingress-controller/api/ingress/v1alpha1"
Expand Down Expand Up @@ -43,7 +46,7 @@ const (

// TunnelDriver is a driver for creating and deleting ngrok tunnels
type TunnelDriver struct {
session ngrok.Session
session atomic.Pointer[sessionState]
tunnels map[string]ngrok.Tunnel
}

Expand All @@ -57,8 +60,14 @@ type TunnelDriverComments struct {
Gateway string `json:"gateway,omitempty"`
}

type sessionState struct {
session ngrok.Session
readyErr error
healthErr error
}

// New creates and initializes a new TunnelDriver
func New(logger logr.Logger, opts TunnelDriverOpts, tunnelComment *TunnelDriverComments) (*TunnelDriver, error) {
func New(ctx context.Context, logger logr.Logger, opts TunnelDriverOpts, tunnelComment *TunnelDriverComments) (*TunnelDriver, error) {
comments := []string{}

if tunnelComment != nil {
Expand Down Expand Up @@ -97,14 +106,90 @@ func New(logger logr.Logger, opts TunnelDriverOpts, tunnelComment *TunnelDriverC
connOpts = append(connOpts, ngrok.WithCA(caCerts))
}

session, err := ngrok.Connect(context.Background(), connOpts...)
if err != nil {
return nil, err
}
return &TunnelDriver{
session: session,
td := &TunnelDriver{
tunnels: make(map[string]ngrok.Tunnel),
}, nil
}

td.session.Store(&sessionState{
readyErr: fmt.Errorf("attempting to connect"),
})
connOpts = append(connOpts,
ngrok.WithConnectHandler(func(ctx context.Context, sess ngrok.Session) {
td.session.Store(&sessionState{
session: sess,
})
}),
ngrok.WithDisconnectHandler(func(ctx context.Context, sess ngrok.Session, err error) {
state := td.session.Load()

if state.session != nil {
// we have established session in the past, so record err only when it is going away
if err == nil {
td.session.Store(&sessionState{
healthErr: fmt.Errorf("session closed"),
})
}
return
}

if err == nil {
// session is disconnecting, do not override error
if state.healthErr == nil {
td.session.Store(&sessionState{
healthErr: fmt.Errorf("session closed"),
})
}
return
}

if state.healthErr != nil {
// we are already at a terminal error, just keep the first one
return
}

// we didn't have a session and we are seeing disconnect error
userErr := strings.HasPrefix(err.Error(), "authentication failed") && !strings.Contains(err.Error(), "internal server error")
if userErr {
// its a user error (e.g. authentication failure), so stop further
td.session.Store(&sessionState{
healthErr: err,
})
sess.Close()
} else {
// mark this as connecting error to return from readyz
td.session.Store(&sessionState{
readyErr: err,
})
}
}),
)
go ngrok.Connect(ctx, connOpts...)

return td, nil
}

func (td *TunnelDriver) Ready() error {
state := td.session.Load()
return state.readyErr
}

func (td *TunnelDriver) Healthy() error {
state := td.session.Load()
return state.healthErr
}

func (td *TunnelDriver) getSession() (ngrok.Session, error) {
state := td.session.Load()
switch {
case state.session != nil:
return state.session, nil
case state.healthErr != nil:
return nil, state.healthErr
case state.readyErr != nil:
return nil, state.readyErr
default:
return nil, fmt.Errorf("unexpected state")
}
}

// caCerts combines the system ca certs with a directory of custom ca certs
Expand Down Expand Up @@ -144,6 +229,11 @@ func caCerts() (*x509.CertPool, error) {
// CreateTunnel creates and starts a new tunnel in a goroutine. If a tunnel with the same name already exists,
// it will be stopped and replaced with a new tunnel unless the labels match.
func (td *TunnelDriver) CreateTunnel(ctx context.Context, name string, spec ingressv1alpha1.TunnelSpec) error {
session, err := td.getSession()
if err != nil {
return err
}

log := log.FromContext(ctx)

if tun, ok := td.tunnels[name]; ok {
Expand All @@ -156,7 +246,7 @@ func (td *TunnelDriver) CreateTunnel(ctx context.Context, name string, spec ingr
defer td.stopTunnel(context.Background(), tun)
}

tun, err := td.session.Listen(ctx, td.buildTunnelConfig(spec.Labels, spec.ForwardsTo, spec.AppProtocol))
tun, err := session.Listen(ctx, td.buildTunnelConfig(spec.Labels, spec.ForwardsTo, spec.AppProtocol))
if err != nil {
return err
}
Expand Down

0 comments on commit b45e7d1

Please sign in to comment.