Skip to content

Commit

Permalink
Merge pull request #816 from YaoZengzeng/port-forward
Browse files Browse the repository at this point in the history
feature: implement portforward for stream server
  • Loading branch information
allencloud authored Mar 7, 2018
2 parents 8b6bcff + 30c4ea9 commit 43c0cb3
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 15 deletions.
6 changes: 6 additions & 0 deletions cri/stream/constant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,10 @@ const (
StreamTypeError = "error"
// StreamTypeResize is the value for streamType header for terminal resize stream
StreamTypeResize = "resize"

// PortHeader is the name of header that specifies the port being forwarded.
PortHeader = "port"
// PortForwardRequestIDHeader is the name of header that specifies a request ID
// used to associate the error and data streams for a single forwarded connection.
PortForwardRequestIDHeader = "requestID"
)
1 change: 0 additions & 1 deletion cri/stream/httpstream/spdy/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
w.Header().Add(httpstream.HeaderUpgrade, HeaderSpdy31)
w.WriteHeader(http.StatusSwitchingProtocols)

// 获取底层的连接
conn, bufrw, err := hijacker.Hijack()
if err != nil {
logrus.Errorf("unable to upgrade: error hijacking response: %v", err)
Expand Down
253 changes: 253 additions & 0 deletions cri/stream/portforward/httpstream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package portforward

import (
"fmt"
"net/http"
"strconv"
"sync"
"time"

"github.com/alibaba/pouch/cri/stream/constant"
"github.com/alibaba/pouch/cri/stream/httpstream"
"github.com/alibaba/pouch/cri/stream/httpstream/spdy"
"github.com/alibaba/pouch/pkg/collect"

"github.com/sirupsen/logrus"
)

// httpStreamReceived is the httpstream.NewStreamHandler for port
// forward streams. It checks each stream's port and stream type headers,
// rejecting any streams that with missing or invalid values. Each valid
// stream is sent to the streams channel.
func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
// Make sure it has a valid port header.
portString := stream.Headers().Get(constant.PortHeader)
if len(portString) == 0 {
return fmt.Errorf("%q header is required", constant.PortHeader)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("port %q must be > 0", portString)
}

// Make sure it has a valid stream type header.
streamType := stream.Headers().Get(constant.StreamType)
if len(streamType) == 0 {
return fmt.Errorf("%q header is required", constant.StreamType)
}
if streamType != constant.StreamTypeError && streamType != constant.StreamTypeData {
return fmt.Errorf("invalid stream type %q", streamType)
}

streams <- stream
return nil
}
}

func handleHTTPStreams(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, idleTimeout, streamCreationTimeout time.Duration, supportedPortForwardProtocols []string) error {
_, err := httpstream.Handshake(w, req, supportedPortForwardProtocols)
// Negotiated protocol isn't currently used server side, but could be in the future.
if err != nil {
// Handshake writes the error to the client
return err
}
streamChan := make(chan httpstream.Stream, 1)

logrus.Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan))
if conn == nil {
return fmt.Errorf("unable to upgrade connection")
}
defer conn.Close()

logrus.Infof("(conn=%p) setting forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)

h := &httpStreamHandler{
conn: conn,
streamChan: streamChan,
streamPairs: collect.NewSafeMap(),
streamCreationTimeout: streamCreationTimeout,
pod: podName,
forwarder: portForwarder,
}
h.run()

return nil
}

// httpStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type httpStreamHandler struct {
conn httpstream.Connection
streamChan chan httpstream.Stream
streamPairs *collect.SafeMap
streamCreationTimeout time.Duration
pod string
forwarder PortForwarder
}

// getStreamPair returns a httpStreamPair for requestID. This creates a
// new pair if one does not yet exist for the requestID. The returned bool is
// true if the pair was created.
func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) {
p, ok := h.streamPairs.Get(requestID).Result()
if ok {
logrus.Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
return p.(*httpStreamPair), false
}

logrus.Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)

pair := newPortForwardPair(requestID)
h.streamPairs.Put(requestID, pair)

return pair, true
}

// hasStreamPair returns a bool indicating if a stream pair for requestID exists.
func (h *httpStreamHandler) hasStreamPair(requestID string) bool {
_, ok := h.streamPairs.Get(requestID).Result()

return ok
}

// removeStreamPair removes the stream pair identified by requestID from streamPairs.
func (h *httpStreamHandler) removeStreamPair(requestID string) {
h.streamPairs.Remove(requestID)
}

// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
select {
case <-timeout:
msg := fmt.Sprintf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
p.printError(msg)
case <-p.complete:
logrus.Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
}
h.removeStreamPair(p.requestID)
}

// requestID returns the request id for stream.
func (h *httpStreamHandler) requestID(stream httpstream.Stream) string {
requestID := stream.Headers().Get(constant.PortForwardRequestIDHeader)
if len(requestID) == 0 {
// TODO: support the connection come from the older client
// that isn't generating the request id header.
}

return requestID
}

// run is the main loop for the httpStreamHandler. It process new streams,
// invoking portForward for each complete stream pair. The loop exits
// when the httpstream.Connection is closed.
func (h *httpStreamHandler) run() {
logrus.Infof("(conn=%p) waiting for port forward streams", h.conn)

for {
select {
case <-h.conn.CloseChan():
logrus.Infof("(conn=%p) upgraded connection closed", h.conn)
return
case stream := <-h.streamChan:
requestID := h.requestID(stream)
streamType := stream.Headers().Get(constant.StreamType)
logrus.Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)

p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
p.printError(msg)
} else if complete {
go h.portForward(p)
}
}
}
}

// portForward invokes the httpStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *httpStreamHandler) portForward(p *httpStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()

portString := p.dataStream.Headers().Get(constant.PortHeader)
port, _ := strconv.ParseInt(portString, 10, 32)

logrus.Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
err := h.forwarder.PortForward(h.pod, int32(port), p.dataStream)
logrus.Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)

if err != nil {
msg := fmt.Sprintf("error forwarding port %d to pod %s: %v", port, h.pod, err)
p.printError(msg)
}
}

// httpStreamPair represents the error and data streams for a port
// forwarding request.
type httpStreamPair struct {
lock sync.RWMutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
complete chan struct{}
}

// newPortForwardPair creates a new httpStreamPair.
func newPortForwardPair(requestID string) *httpStreamPair {
return &httpStreamPair{
requestID: requestID,
complete: make(chan struct{}),
}
}

// add adds the stream to the httpStreamPair. If the pair already
// contains a stream for the new stream's type, an error is returned. add
// returns true if both the data and error streams for this pair have been
// received.
func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
p.lock.Lock()
defer p.lock.Unlock()

switch stream.Headers().Get(constant.StreamType) {
case constant.StreamTypeError:
if p.errorStream != nil {
return false, fmt.Errorf("error stream already assigned")
}
p.errorStream = stream
case constant.StreamTypeData:
if p.dataStream != nil {
return false, fmt.Errorf("data stream already assigned")
}
p.dataStream = stream
}

complete := p.errorStream != nil && p.dataStream != nil
if complete {
close(p.complete)
}

return complete, nil
}

// printError writes s to p.errorStream if p.errorStream has been set.
func (p *httpStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
}
}
32 changes: 32 additions & 0 deletions cri/stream/portforward/portforward.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package portforward

import (
"io"
"net/http"
"time"

"github.com/sirupsen/logrus"
)

// PortForwarder knows how to forward content from a data stream to/from a port
// in a pod.
type PortForwarder interface {
// PortForwarder copies data between a data stream and a port in a pod.
PortForward(name string, port int32, stream io.ReadWriteCloser) error
}

// ServePortForward handles a port forwarding request. A single request is
// kept alive as long as the client is still alive and the connection has not
// been timed out due to idleness. This function handles multiple forwarded
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) {
// TODO: support web socket stream.
err := handleHTTPStreams(w, req, portForwarder, podName, idleTimeout, streamCreationTimeout, supportedProtocols)
if err != nil {
logrus.Errorf("failed to serve port forward: %v", err)
return
}

return
}
1 change: 0 additions & 1 deletion cri/stream/request_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ func (c *requestCache) Insert(req request) (token string, err error) {
if err != nil {
return "", err
}
// 将cache entry加入list
ele := c.ll.PushFront(&cacheEntry{token, req, time.Now().Add(CacheTTL)})

c.tokens[token] = ele
Expand Down
19 changes: 15 additions & 4 deletions cri/stream/server.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package stream

import (
"io"
"net/http"
"net/url"
"path"
"time"

"github.com/alibaba/pouch/cri/stream/constant"
"github.com/alibaba/pouch/cri/stream/portforward"
"github.com/alibaba/pouch/cri/stream/remotecommand"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -57,7 +59,7 @@ type Runtime interface {
Attach() error

// PortForward forward port to pod.
PortForward() error
PortForward(name string, port int32, stream io.ReadWriteCloser) error
}

// Config defines the options used for running the stream server.
Expand Down Expand Up @@ -115,7 +117,7 @@ func NewServer(config Config, runtime Runtime) (Server, error) {
}{
{"/exec/{token}", s.serveExec},
{"/attach/{token}", s.serveAttach},
{"/portforward{token}", s.servePortForward},
{"/portforward/{token}", s.servePortForward},
}

r := mux.NewRouter()
Expand Down Expand Up @@ -209,12 +211,21 @@ func (s *server) servePortForward(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
return
}
_, ok = cachedRequest.(*runtimeapi.PortForwardRequest)
pf, ok := cachedRequest.(*runtimeapi.PortForwardRequest)
if !ok {
http.NotFound(w, r)
return
}
WriteError(grpc.Errorf(codes.NotFound, "servePortForward Has Not Been Completed Yet"), w)

portforward.ServePortForward(
w,
r,
s.runtime,
pf.PodSandboxId,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedPortForwardProtocols,
)
}

func (s *server) buildURL(method string, token string) string {
Expand Down
2 changes: 1 addition & 1 deletion daemon/mgr/cri.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type CriManager struct {

// NewCriManager creates a brand new cri manager.
func NewCriManager(config *config.Config, ctrMgr ContainerMgr, imgMgr ImageMgr) (*CriManager, error) {
streamServer, err := newStreamServer(streamServerAddress, streamServerPort)
streamServer, err := newStreamServer(ctrMgr, streamServerAddress, streamServerPort)
if err != nil {
return nil, fmt.Errorf("failed to create stream server for cri manager: %v", err)
}
Expand Down
Loading

0 comments on commit 43c0cb3

Please sign in to comment.