Skip to content

Commit

Permalink
Use middleware for session handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Toby Padilla committed Jul 30, 2021
1 parent 9b1ba59 commit cf5a2ed
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 32 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func main() {
if err != nil {
panic(err)
}
s, err := NewServer(cfg.Port, cfg.KeyPath, tui.SessionHandler)
s, err := NewServer(cfg.Port, cfg.KeyPath, LoggingMiddleware(), BubbleTeaMiddleware(tui.SessionHandler))
if err != nil {
panic(err)
}
Expand Down
65 changes: 37 additions & 28 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,44 @@ import (
gossh "golang.org/x/crypto/ssh"
)

type SessionHandler func(ssh.Session) (tea.Model, error)
type Middleware func(ssh.Handler) ssh.Handler

type Server struct {
server *ssh.Server
key gossh.PublicKey
handler SessionHandler
func LoggingMiddleware() Middleware {
return func(sh ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
hpk := s.PublicKey() != nil
log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
sh(s)
log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
}
}
}

func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error) {
s := &Server{
server: &ssh.Server{},
handler: handler,
func BubbleTeaMiddleware(bth func(ssh.Session) tea.Model) Middleware {
return func(sh ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
m := bth(s)
if m != nil {
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s))
err := p.Start()
if err != nil {
log.Printf("%s error %v: %s\n", s.RemoteAddr().String(), s.Command(), err)
}
}
sh(s)
}
}
}

type Server struct {
server *ssh.Server
key gossh.PublicKey
}

func NewServer(port int, keyPath string, mw ...Middleware) (*Server, error) {
s := &Server{server: &ssh.Server{}}
s.server.Version = "OpenSSH_7.6p1"
s.server.Addr = fmt.Sprintf(":%d", port)
s.server.Handler = s.sessionHandler
s.server.PasswordHandler = s.passHandler
s.server.PublicKeyHandler = s.authHandler
kps := strings.Split(keyPath, string(filepath.Separator))
Expand All @@ -42,28 +64,15 @@ func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error
if err != nil {
return nil, err
}
h := func(s ssh.Session) {}
for _, m := range mw {
h = m(h)
}
s.server.Handler = h
return s, nil
}

func (srv *Server) sessionHandler(s ssh.Session) {
hpk := s.PublicKey() != nil
log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
m, err := srv.handler(s)
if err != nil {
log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err)
s.Exit(1)
return
}
if m != nil {
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s))
err = p.Start()
if err != nil {
log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err)
s.Exit(1)
return
}
}
log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command())
}

func (srv *Server) authHandler(ctx ssh.Context, key ssh.PublicKey) bool {
Expand Down
6 changes: 3 additions & 3 deletions tui/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ func (e errMsg) Error() string {
return e.err.Error()
}

func SessionHandler(s ssh.Session) (tea.Model, error) {
func SessionHandler(s ssh.Session) tea.Model {
pty, changes, active := s.Pty()
if !active {
return nil, fmt.Errorf("you need to do this from a terminal with PTY support")
return nil
}
return NewModel(pty.Window.Width, pty.Window.Height, changes), nil
return NewModel(pty.Window.Width, pty.Window.Height, changes)
}

type Model struct {
Expand Down

0 comments on commit cf5a2ed

Please sign in to comment.