Skip to content

Commit

Permalink
Improve filtering for paid RPC methods in payment proxy (#16)
Browse files Browse the repository at this point in the history
* Check for Content-Type and request fields for paid methods

* Include the RPC method in response handler log
  • Loading branch information
prathamesh0 authored and neerajvijay1997 committed Oct 11, 2023
1 parent ecdec9d commit d54438d
Showing 1 changed file with 49 additions and 31 deletions.
80 changes: 49 additions & 31 deletions paymentproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const (
CHANNEL_ID_VOUCHER_PARAM = "channelId"
SIGNATURE_VOUCHER_PARAM = "signature"

VOUCHER_CONTEXT_ARG contextKey = "voucher"
VOUCHER_CONTEXT_ARG contextKey = "voucher"
RPC_METHOD_CONTEXT_ARG contextKey = "rpcMethod"

ErrPayment = types.ConstError("payment error")
)
Expand Down Expand Up @@ -113,36 +114,10 @@ func (p *PaymentProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {

queryParams := r.URL.Query()
requiresPayment := true
var rpcMethod string

if p.enablePaidRpcMethods {

var ReqBody struct {
Method string `json:"method"`
}

bodyBytes, _ := io.ReadAll(r.Body)
// TODO: Check for content type
err := json.Unmarshal(bodyBytes, &ReqBody)
if err != nil {
p.handleError(w, r, createPaymentError(fmt.Errorf("could not unmarshall request body: %w", err)))
return
}

slog.Debug("Serving RPC request", "method", ReqBody.Method)

// Reassign request body as io.ReadAll consumes it
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

rpcMethod := ReqBody.Method
requiresPayment = false

// Check if payment is required for RPC method
for _, paidRPCMethod := range paidRPCMethods {
if paidRPCMethod == rpcMethod {
requiresPayment = true
break
}
}
requiresPayment, rpcMethod = isPaymentRequired(r)
}

if requiresPayment {
Expand All @@ -154,8 +129,9 @@ func (p *PaymentProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {

removeVoucher(r)

// We add the voucher to the request context so we can access it in the response handler
// We add the voucher and rpcMethod to the request context so we can access them in the response handler
r = r.WithContext(context.WithValue(r.Context(), VOUCHER_CONTEXT_ARG, v))
r = r.WithContext(context.WithValue(r.Context(), RPC_METHOD_CONTEXT_ARG, rpcMethod))
}

p.reverseProxy.ServeHTTP(w, r)
Expand Down Expand Up @@ -192,7 +168,12 @@ func (p *PaymentProxy) handleDestinationResponse(r *http.Response) error {
}
cost := p.costPerByte * contentLength

slog.Debug("Request cost", "cost-per-byte", p.costPerByte, "response-length", contentLength, "cost", cost)
rpcMethod, ok := r.Request.Context().Value(RPC_METHOD_CONTEXT_ARG).(string)
if ok {
slog.Debug("Request cost", "cost-per-byte", p.costPerByte, "response-length", contentLength, "cost", cost, "method", rpcMethod)
} else {
slog.Debug("Request cost", "cost-per-byte", p.costPerByte, "response-length", contentLength, "cost", cost)
}

s, err := p.nitroClient.ReceiveVoucher(v)
if err != nil {
Expand Down Expand Up @@ -252,6 +233,43 @@ func (p *PaymentProxy) Stop() error {
return p.nitroClient.Close()
}

// Helper method to parse request and determine whether it qualifies for a payment
// Payment is required for a request if:
// - "Content-Type" header is set to "application/json"
// - Request body has non-empty "jsonrpc" and "method" fields
func isPaymentRequired(r *http.Request) (bool, string) {
if r.Header.Get("Content-Type") != "application/json" {
return false, ""
}

var ReqBody struct {
JsonRpc string `json:"jsonrpc"`
Method string `json:"method"`
}
bodyBytes, _ := io.ReadAll(r.Body)

err := json.Unmarshal(bodyBytes, &ReqBody)
if err != nil || ReqBody.JsonRpc == "" || ReqBody.Method == "" {
return false, ""
}

slog.Debug("Serving RPC request", "method", ReqBody.Method)

// Reassign request body as io.ReadAll consumes it
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

rpcMethod := ReqBody.Method

// Check if payment is required for RPC method
for _, paidRPCMethod := range paidRPCMethods {
if paidRPCMethod == rpcMethod {
return true, rpcMethod
}
}

return false, ""
}

// parseVoucher takes in an a collection of query params and parses out a voucher.
func parseVoucher(params url.Values) (payments.Voucher, error) {
rawChId := params.Get(CHANNEL_ID_VOUCHER_PARAM)
Expand Down

0 comments on commit d54438d

Please sign in to comment.