diff --git a/components/rpc/invoker/mosn/channel/httpchannel.go b/components/rpc/invoker/mosn/channel/httpchannel.go index 4c01de2f33..c3310688bf 100644 --- a/components/rpc/invoker/mosn/channel/httpchannel.go +++ b/components/rpc/invoker/mosn/channel/httpchannel.go @@ -23,18 +23,38 @@ import ( "net/http" "time" + "mosn.io/pkg/buffer" + "github.com/valyala/fasthttp" + "mosn.io/layotto/components/pkg/common" "mosn.io/layotto/components/rpc" - common "mosn.io/layotto/components/pkg/common" _ "mosn.io/mosn/pkg/stream/http" ) - // init is regist http channel func init() { RegistChannel("http", newHttpChannel) } +type hstate struct { + reader net.Conn + writer net.Conn +} + +func (h *hstate) onData(b buffer.IoBuffer) error { + data := b.Bytes() + if _, err := h.writer.Write(data); err != nil { + return err + } + b.Drain(len(data)) + return nil +} + +func (h *hstate) close() { + h.reader.Close() + h.writer.Close() +} + // httpChannel is Channel implement type httpChannel struct { pool *connPool @@ -42,20 +62,26 @@ type httpChannel struct { // newHttpChannel is create rpc.Channel by ChannelConfig func newHttpChannel(config ChannelConfig) (rpc.Channel, error) { - return &httpChannel{ - pool: newConnPool( - config.Size, - func() (net.Conn, error) { - local, remote := net.Pipe() - localTcpConn := &fakeTcpConn{c: local} - remoteTcpConn := &fakeTcpConn{c: remote} - if err := acceptFunc(remoteTcpConn, config.Listener); err != nil { - return nil, err - } - return localTcpConn, nil - }, nil, nil, nil, - ), - }, nil + hc := &httpChannel{} + hc.pool = newConnPool( + config.Size, + func() (net.Conn, error) { + local, remote := net.Pipe() + localTcpConn := &fakeTcpConn{c: local} + remoteTcpConn := &fakeTcpConn{c: remote} + if err := acceptFunc(remoteTcpConn, config.Listener); err != nil { + return nil, err + } + return localTcpConn, nil + }, func() interface{} { + s := &hstate{} + s.reader, s.writer = net.Pipe() + return s + }, + hc.onData, + hc.cleanup, + ) + return hc, nil } // Do is handle RPCRequest to RPCResponse @@ -69,8 +95,10 @@ func (h *httpChannel) Do(req *rpc.RPCRequest) (*rpc.RPCResponse, error) { return nil, err } + hstate := conn.state.(*hstate) deadline, _ := ctx.Deadline() - if err = conn.SetDeadline(deadline); err != nil { + if err = conn.SetWriteDeadline(deadline); err != nil { + hstate.close() h.pool.Put(conn, true) return nil, common.Error(common.UnavailebleCode, err.Error()) } @@ -79,27 +107,28 @@ func (h *httpChannel) Do(req *rpc.RPCRequest) (*rpc.RPCResponse, error) { defer fasthttp.ReleaseRequest(httpReq) if _, err = httpReq.WriteTo(conn); err != nil { + hstate.close() h.pool.Put(conn, true) return nil, common.Error(common.UnavailebleCode, err.Error()) } - httpResp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(httpResp) - if err = httpResp.Read(bufio.NewReader(conn)); err != nil { + httpResp := &fasthttp.Response{} + hstate.reader.SetReadDeadline(deadline) + + if err = httpResp.Read(bufio.NewReader(hstate.reader)); err != nil { + hstate.close() h.pool.Put(conn, true) return nil, common.Error(common.UnavailebleCode, err.Error()) } - body := httpResp.Body() h.pool.Put(conn, false) + body := httpResp.Body() if httpResp.StatusCode() != http.StatusOK { return nil, common.Errorf(common.UnavailebleCode, "http response code %d, body: %s", httpResp.StatusCode(), string(body)) } - data := make([]byte, len(body)) - copy(data, body) rpcResp := &rpc.RPCResponse{ ContentType: string(httpResp.Header.ContentType()), - Data: data, + Data: body, Header: map[string][]string{}, } httpResp.Header.VisitAll(func(key, value []byte) { @@ -131,3 +160,13 @@ func (h *httpChannel) constructReq(req *rpc.RPCRequest) *fasthttp.Request { httpReq.Header.Set("id", req.Id) return httpReq } + +func (h *httpChannel) onData(conn *wrapConn) error { + hstate := conn.state.(*hstate) + return hstate.onData(conn.buf) +} + +func (h *httpChannel) cleanup(conn *wrapConn, err error) { + hstate := conn.state.(*hstate) + hstate.close() +} diff --git a/components/rpc/invoker/mosn/channel/httpchannel_test.go b/components/rpc/invoker/mosn/channel/httpchannel_test.go index 5d9fe2f8a6..5bf3ca62b4 100644 --- a/components/rpc/invoker/mosn/channel/httpchannel_test.go +++ b/components/rpc/invoker/mosn/channel/httpchannel_test.go @@ -19,11 +19,13 @@ package channel import ( "bufio" "context" + "log" "net" "strconv" "strings" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/valyala/fasthttp" @@ -51,6 +53,8 @@ func (ts *testhttpServer) readLoop(conn net.Conn) { switch content { case "close": return + case "timeout": + time.Sleep(2*time.Second) default: } @@ -58,6 +62,7 @@ func (ts *testhttpServer) readLoop(conn net.Conn) { resp.SetBody(req.Body()) if _, err := resp.WriteTo(conn); err != nil { + log.Println("test server err:", err.Error()) break } } @@ -96,6 +101,36 @@ func TestRenewHttpConn(t *testing.T) { assert.Equal(t, "hello", string(resp.Data)) } +func TestManyRequests(t *testing.T) { + startTestHttpServer() + + channel, err := newHttpChannel(ChannelConfig{Size: 1}) + assert.Nil(t, err) + + for i:=0;i<100;i++{ + req := &rpc.RPCRequest{Ctx: context.TODO(), Id: "foo", Method: "bar", Data: []byte("hello"), Timeout: 1000} + _, err = channel.Do(req) + assert.Nil(t, err) + } +} + +func TestResponseTimeout(t *testing.T) { + startTestHttpServer() + + channel, err := newHttpChannel(ChannelConfig{Size: 1}) + assert.Nil(t, err) + + req := &rpc.RPCRequest{Ctx: context.TODO(), Id: "foo", Method: "bar", Data: []byte("timeout"), Timeout: 1000} + _, err = channel.Do(req) + assert.Error(t, err) + + for i:=0;i<100;i++{ + req = &rpc.RPCRequest{Ctx: context.TODO(), Id: "foo", Method: "bar", Data: []byte("hello"), Timeout: 1000} + _, err = channel.Do(req) + assert.Nil(t, err) + } +} + func TestConcurrent(t *testing.T) { startTestHttpServer() diff --git a/pkg/grpc/api.go b/pkg/grpc/api.go index 54957e2049..e56c414380 100644 --- a/pkg/grpc/api.go +++ b/pkg/grpc/api.go @@ -230,9 +230,11 @@ func (a *api) InvokeService(ctx context.Context, in *runtimev1pb.InvokeServiceRe if resp.Header != nil { header := metadata.Pairs() for k, values := range resp.Header { - for _, v := range values { - header.Append(k, v) + // fix https://github.com/mosn/layotto/issues/285 + if strings.EqualFold("content-length", k) { + continue } + header.Set(k, values...) } grpc.SetHeader(ctx, header) }