diff --git a/cmd/examples/http/server.go b/cmd/examples/http/server.go index 957166d..97725b6 100644 --- a/cmd/examples/http/server.go +++ b/cmd/examples/http/server.go @@ -22,7 +22,6 @@ import ( "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/jhttp" "github.com/creachadair/jrpc2/metrics" - "github.com/creachadair/jrpc2/server" ) var port = flag.Int("port", 0, "Service port") @@ -34,16 +33,17 @@ func main() { } // Start a local server with a single trivial method and bridge it to HTTP. - local := server.NewLocal(handler.Map{ + srv := jrpc2.NewServer(handler.Map{ "Ping": handler.New(func(ctx context.Context, msg ...string) string { return "OK: " + strings.Join(msg, ", ") }), - }, &server.LocalOptions{ - Server: &jrpc2.ServerOptions{ - Logger: log.New(os.Stderr, "[jhttp.Bridge] ", log.LstdFlags|log.Lshortfile), - Metrics: metrics.New(), - }, + }, &jrpc2.ServerOptions{ + Logger: log.New(os.Stderr, "[jhttp.Bridge] ", log.LstdFlags|log.Lshortfile), + Metrics: metrics.New(), }) - http.Handle("/rpc", jhttp.NewBridge(local.Client)) + bridge := jhttp.NewBridge(srv, nil) + defer bridge.Close() + + http.Handle("/rpc", bridge) log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), nil)) } diff --git a/jhttp/bridge.go b/jhttp/bridge.go index f537da0..5b25295 100644 --- a/jhttp/bridge.go +++ b/jhttp/bridge.go @@ -3,17 +3,16 @@ package jhttp import ( - "context" - "encoding/json" "fmt" "io/ioutil" "net/http" "strconv" "github.com/creachadair/jrpc2" + "github.com/creachadair/jrpc2/channel" ) -// A Bridge is a http.Handler that bridges requests to a JSON-RPC client. +// A Bridge is a http.Handler that bridges requests to a JSON-RPC server. // // The body of the HTTP POST request must contain the complete JSON-RPC request // message, encoded with Content-Type: application/json. Either a single @@ -26,20 +25,19 @@ import ( // If the HTTP request method is not "POST", the bridge reports 405 (Method Not // Allowed). If the Content-Type is not application/json, the bridge reports // 415 (Unsupported Media Type). -// -// The bridge attaches the inbound HTTP request to the context passed to the -// client, allowing an EncodeContext callback to retrieve state from the HTTP -// headers. Use jhttp.HTTPRequest to retrieve the request from the context. type Bridge struct { - cli *jrpc2.Client + ch channel.Channel + srv *jrpc2.Server + checkType func(string) bool } // ServeHTTP implements the required method of http.Handler. -func (b *Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.Method != "POST" { w.WriteHeader(http.StatusMethodNotAllowed) return - } else if req.Header.Get("Content-Type") != "application/json" { + } + if !b.checkType(req.Header.Get("Content-Type")) { w.WriteHeader(http.StatusUnsupportedMediaType) return } @@ -49,70 +47,38 @@ func (b *Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } -func (b *Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { +func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { body, err := ioutil.ReadAll(req.Body) if err != nil { return err } + + // The HTTP request requires a response, but the server will not reply if + // all the requests are notifications. Check whether we have any calls + // needing a response, and choose whether to wait for a reply based on that. jreq, err := jrpc2.ParseRequests(body) if err != nil { return err } - - // Because the bridge shares the JSON-RPC client between potentially many - // HTTP clients, we must virtualize the ID space for requests to preserve - // the HTTP client's assignment of IDs. - // - // To do this, we keep track of the inbound ID for each request so that we - // can map the responses back. This takes advantage of the fact that the - // *jrpc2.Client detangles batch order so that responses come back in the - // same order (modulo notifications) even if the server response did not - // preserve order. - - // Generate request specifications for the client. - var inboundID []string // for requests - spec := make([]jrpc2.Spec, len(jreq)) // requests & notifications - for i, req := range jreq { - spec[i] = jrpc2.Spec{ - Method: req.Method(), - Notify: req.IsNotification(), - } - if req.HasParams() { - var p json.RawMessage - req.UnmarshalParams(&p) - spec[i].Params = p - } - if !spec[i].Notify { - inboundID = append(inboundID, req.ID()) + var hasCall bool + for _, req := range jreq { + if !req.IsNotification() { + hasCall = true + break } } - - ctx := context.WithValue(req.Context(), httpReqKey{}, req) - rsps, err := b.cli.Batch(ctx, spec) - if err != nil { + if err := b.ch.Send(body); err != nil { return err } - // If all the requests were notifications, report success without responses. - if len(rsps) == 0 { + // If there are only notifications, report success without responses. + if !hasCall { w.WriteHeader(http.StatusNoContent) return nil } - // Otherwise, map the responses back to their original IDs, and marshal the - // response back into the body. - for i, rsp := range rsps { - rsp.SetID(inboundID[i]) - } - - // If the original request was a single message, make sure we encode the - // response the same way. - var reply []byte - if len(rsps) == 1 && (len(body) == 0 || body[0] != '[') { - reply, err = json.Marshal(rsps[0]) - } else { - reply, err = json.Marshal(rsps) - } + // Wait for the server to reply. + reply, err := b.ch.Recv() if err != nil { return err } @@ -122,23 +88,35 @@ func (b *Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { return nil } -// Close shuts down the client associated with b and reports the result from -// its Close method. -func (b *Bridge) Close() error { return b.cli.Close() } +// Close closes the channel to the server, waits for the server to exit, and +// reports the exit status of the server. +func (b Bridge) Close() error { b.ch.Close(); return b.srv.Wait() } -// NewBridge constructs a new Bridge that dispatches requests through c. It is -// safe for the caller to continue to use c concurrently with the bridge, as -// long as it does not close the client. -func NewBridge(c *jrpc2.Client) *Bridge { return &Bridge{cli: c} } +// NewBridge starts srv constructs a new Bridge that dispatches HTTP requests +// to it. The server must be unstarted, or NewBridge will panic. The server +// will run until the bridge is closed. +func NewBridge(srv *jrpc2.Server, opts *BridgeOptions) Bridge { + cch, sch := channel.Direct() + return Bridge{ + ch: cch, + srv: srv.Start(sch), + checkType: opts.checkContentType(), + } +} -type httpReqKey struct{} +// BridgeOptions are optional settings for a Bridge. A nil pointer is ready for +// use and provides default values as described. +type BridgeOptions struct { + // If non-nil, this function is called to check whether the HTTP request's + // declared content-type is valid. If this function returns false, the + // request is rejected. If nil, the default check requires a content type of + // "application/json". + CheckContentType func(contentType string) bool +} -// HTTPRequest returns the HTTP request associated with ctx, or nil. The -// context passed to the JSON-RPC client by the Bridge will contain this value. -func HTTPRequest(ctx context.Context) *http.Request { - req, ok := ctx.Value(httpReqKey{}).(*http.Request) - if ok { - return req +func (o *BridgeOptions) checkContentType() func(string) bool { + if o == nil || o.CheckContentType == nil { + return func(ctype string) bool { return ctype == "application/json" } } - return nil + return o.CheckContentType } diff --git a/jhttp/example_test.go b/jhttp/example_test.go index 569c29d..a3d5551 100644 --- a/jhttp/example_test.go +++ b/jhttp/example_test.go @@ -9,21 +9,20 @@ import ( "net/http/httptest" "strings" + "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/jhttp" - "github.com/creachadair/jrpc2/server" ) func Example() { // Set up a local server to demonstrate the API. - loc := server.NewLocal(handler.Map{ + srv := jrpc2.NewServer(handler.Map{ "Test": handler.New(func(ctx context.Context, ss ...string) (string, error) { return strings.Join(ss, " "), nil }), }, nil) - defer loc.Close() - b := jhttp.NewBridge(loc.Client) + b := jhttp.NewBridge(srv, nil) defer b.Close() hsrv := httptest.NewServer(b) diff --git a/jhttp/jhttp_test.go b/jhttp/jhttp_test.go index 04cb1d0..2566490 100644 --- a/jhttp/jhttp_test.go +++ b/jhttp/jhttp_test.go @@ -3,7 +3,7 @@ package jhttp_test import ( "context" "encoding/json" - "errors" + "fmt" "io" "io/ioutil" "net/http" @@ -14,30 +14,24 @@ import ( "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/handler" "github.com/creachadair/jrpc2/jhttp" - "github.com/creachadair/jrpc2/server" ) +var testService = handler.Map{ + "Test1": handler.New(func(ctx context.Context, ss []string) int { + return len(ss) + }), + "Test2": handler.New(func(ctx context.Context, req json.RawMessage) int { + return len(req) + }), +} + func TestBridge(t *testing.T) { // Set up a JSON-RPC server to answer requests bridged from HTTP. - loc := server.NewLocal(handler.Map{ - "Test": handler.New(func(ctx context.Context, ss ...string) (string, error) { - return strings.Join(ss, " "), nil - }), - }, &server.LocalOptions{ - Client: &jrpc2.ClientOptions{ - EncodeContext: func(ctx context.Context, _ string, p json.RawMessage) (json.RawMessage, error) { - if jhttp.HTTPRequest(ctx) == nil { - return nil, errors.New("no HTTP request in context") - } - return p, nil - }, - }, - }) - defer loc.Close() + srv := jrpc2.NewServer(testService, nil) // Bridge HTTP to the JSON-RPC server. - b := jhttp.NewBridge(loc.Client) - defer b.Close() + b := jhttp.NewBridge(srv, nil) + defer checkClose(t, b) // Create an HTTP test server to call into the bridge. hsrv := httptest.NewServer(b) @@ -48,7 +42,7 @@ func TestBridge(t *testing.T) { rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ "jsonrpc": "2.0", "id": 1, - "method": "Test", + "method": "Test1", "params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"] } `)) @@ -62,7 +56,7 @@ func TestBridge(t *testing.T) { t.Errorf("Reading POST body: %v", err) } - const want = `{"jsonrpc":"2.0","id":1,"result":"a foolish consistency is the hobgoblin"}` + const want = `{"jsonrpc":"2.0","id":1,"result":6}` if got := string(body); got != want { t.Errorf("POST body: got %#q, want %#q", got, want) } @@ -71,8 +65,8 @@ func TestBridge(t *testing.T) { // Verify that the bridge will accept a batch. t.Run("PostBatchOK", func(t *testing.T) { rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`[ - {"jsonrpc":"2.0", "id": 3, "method": "Test", "params": ["first"]}, - {"jsonrpc":"2.0", "id": 7, "method": "Test", "params": ["among", "equals"]} + {"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]}, + {"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]} ] `)) if err != nil { @@ -85,8 +79,8 @@ func TestBridge(t *testing.T) { t.Errorf("Reading POST body: %v", err) } - const want = `[{"jsonrpc":"2.0","id":3,"result":"first"},` + - `{"jsonrpc":"2.0","id":7,"result":"among equals"}]` + const want = `[{"jsonrpc":"2.0","id":3,"result":1},` + + `{"jsonrpc":"2.0","id":7,"result":2}]` if got := string(body); got != want { t.Errorf("POST body: got %#q, want %#q", got, want) } @@ -159,16 +153,47 @@ func TestBridge(t *testing.T) { }) } +// Verify that the content-type check hook works. +func TestBridge_contentTypeCheck(t *testing.T) { + srv := jrpc2.NewServer(testService, nil) + + b := jhttp.NewBridge(srv, &jhttp.BridgeOptions{ + CheckContentType: func(ctype string) bool { + return ctype == "application/octet-stream" + }, + }) + defer checkClose(t, b) + + hsrv := httptest.NewServer(b) + defer hsrv.Close() + + const reqTemplate = `{"jsonrpc":"2.0","id":%q,"method":"Test1","params":["a","b","c"]}` + t.Run("ContentTypeOK", func(t *testing.T) { + rsp, err := http.Post(hsrv.URL, "application/octet-stream", + strings.NewReader(fmt.Sprintf(reqTemplate, "ok"))) + if err != nil { + t.Fatalf("POST request failed: %v", err) + } else if got, want := rsp.StatusCode, http.StatusOK; got != want { + t.Errorf("POST response code: got %v, want %v", got, want) + } + }) + + t.Run("ContentTypeBad", func(t *testing.T) { + rsp, err := http.Post(hsrv.URL, "text/plain", + strings.NewReader(fmt.Sprintf(reqTemplate, "bad"))) + if err != nil { + t.Fatalf("POST request failed: %v", err) + } else if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want { + t.Errorf("POST response code: got %v, want %v", got, want) + } + }) +} + func TestChannel(t *testing.T) { - loc := server.NewLocal(handler.Map{ - "Test": handler.New(func(ctx context.Context, arg json.RawMessage) (int, error) { - return len(arg), nil - }), - }, nil) - defer loc.Close() - - b := jhttp.NewBridge(loc.Client) - defer b.Close() + srv := jrpc2.NewServer(testService, nil) + + b := jhttp.NewBridge(srv, nil) + defer checkClose(t, b) hsrv := httptest.NewServer(b) defer hsrv.Close() @@ -191,7 +216,7 @@ func TestChannel(t *testing.T) { for _, test := range tests { var got int - if err := cli.CallResult(ctx, "Test", test.params, &got); err != nil { + if err := cli.CallResult(ctx, "Test2", test.params, &got); err != nil { t.Errorf("Call Test(%v): unexpected error: %v", test.params, err) } else if got != test.want { t.Errorf("Call Test(%v): got %d, want %d", test.params, got, test.want) @@ -225,3 +250,10 @@ func (c counter) Do(req *http.Request) (*http.Response, error) { defer func() { *c.z++ }() return c.c.Do(req) } + +func checkClose(t *testing.T, c io.Closer) { + t.Helper() + if err := c.Close(); err != nil { + t.Errorf("Error in Close: %v", err) + } +}