From d54438d20358f4bc37e7165239ae7539705a5bb0 Mon Sep 17 00:00:00 2001 From: prathamesh0 <42446521+prathamesh0@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:26:35 +0530 Subject: [PATCH] Improve filtering for paid RPC methods in payment proxy (#16) * Check for Content-Type and request fields for paid methods * Include the RPC method in response handler log --- paymentproxy/proxy.go | 80 ++++++++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/paymentproxy/proxy.go b/paymentproxy/proxy.go index 22eb03af7..9a9a34682 100644 --- a/paymentproxy/proxy.go +++ b/paymentproxy/proxy.go @@ -28,7 +28,8 @@ const ( CHANNEL_ID_VOUCHER_PARAM = "channelId" SIGNATURE_VOUCHER_PARAM = "signature" - VOUCHER_CONTEXT_ARG contextKey = "voucher" + VOUCHER_CONTEXT_ARG contextKey = "voucher" + RPC_METHOD_CONTEXT_ARG contextKey = "rpcMethod" ErrPayment = types.ConstError("payment error") ) @@ -113,36 +114,10 @@ func (p *PaymentProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { queryParams := r.URL.Query() requiresPayment := true + var rpcMethod string if p.enablePaidRpcMethods { - - var ReqBody struct { - Method string `json:"method"` - } - - bodyBytes, _ := io.ReadAll(r.Body) - // TODO: Check for content type - err := json.Unmarshal(bodyBytes, &ReqBody) - if err != nil { - p.handleError(w, r, createPaymentError(fmt.Errorf("could not unmarshall request body: %w", err))) - return - } - - slog.Debug("Serving RPC request", "method", ReqBody.Method) - - // Reassign request body as io.ReadAll consumes it - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - - rpcMethod := ReqBody.Method - requiresPayment = false - - // Check if payment is required for RPC method - for _, paidRPCMethod := range paidRPCMethods { - if paidRPCMethod == rpcMethod { - requiresPayment = true - break - } - } + requiresPayment, rpcMethod = isPaymentRequired(r) } if requiresPayment { @@ -154,8 +129,9 @@ func (p *PaymentProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { removeVoucher(r) - // We add the voucher to the request context so we can access it in the response handler + // We add the voucher and rpcMethod to the request context so we can access them in the response handler r = r.WithContext(context.WithValue(r.Context(), VOUCHER_CONTEXT_ARG, v)) + r = r.WithContext(context.WithValue(r.Context(), RPC_METHOD_CONTEXT_ARG, rpcMethod)) } p.reverseProxy.ServeHTTP(w, r) @@ -192,7 +168,12 @@ func (p *PaymentProxy) handleDestinationResponse(r *http.Response) error { } cost := p.costPerByte * contentLength - slog.Debug("Request cost", "cost-per-byte", p.costPerByte, "response-length", contentLength, "cost", cost) + rpcMethod, ok := r.Request.Context().Value(RPC_METHOD_CONTEXT_ARG).(string) + if ok { + slog.Debug("Request cost", "cost-per-byte", p.costPerByte, "response-length", contentLength, "cost", cost, "method", rpcMethod) + } else { + slog.Debug("Request cost", "cost-per-byte", p.costPerByte, "response-length", contentLength, "cost", cost) + } s, err := p.nitroClient.ReceiveVoucher(v) if err != nil { @@ -252,6 +233,43 @@ func (p *PaymentProxy) Stop() error { return p.nitroClient.Close() } +// Helper method to parse request and determine whether it qualifies for a payment +// Payment is required for a request if: +// - "Content-Type" header is set to "application/json" +// - Request body has non-empty "jsonrpc" and "method" fields +func isPaymentRequired(r *http.Request) (bool, string) { + if r.Header.Get("Content-Type") != "application/json" { + return false, "" + } + + var ReqBody struct { + JsonRpc string `json:"jsonrpc"` + Method string `json:"method"` + } + bodyBytes, _ := io.ReadAll(r.Body) + + err := json.Unmarshal(bodyBytes, &ReqBody) + if err != nil || ReqBody.JsonRpc == "" || ReqBody.Method == "" { + return false, "" + } + + slog.Debug("Serving RPC request", "method", ReqBody.Method) + + // Reassign request body as io.ReadAll consumes it + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + rpcMethod := ReqBody.Method + + // Check if payment is required for RPC method + for _, paidRPCMethod := range paidRPCMethods { + if paidRPCMethod == rpcMethod { + return true, rpcMethod + } + } + + return false, "" +} + // parseVoucher takes in an a collection of query params and parses out a voucher. func parseVoucher(params url.Values) (payments.Voucher, error) { rawChId := params.Get(CHANNEL_ID_VOUCHER_PARAM)