From d4fd6ceb884036a6fd5956fbafa8d840b477e78f Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Sun, 12 Apr 2020 12:47:39 -0400 Subject: [PATCH 01/18] refactor WalletService, wallet errors --- Makefile | 4 + api/handlers.go | 15 ---- api/routes.go | 5 +- app/proxy/client.go | 19 +--- app/proxy/client_test.go | 26 ++---- app/proxy/handlers.go | 2 +- app/proxy/proxy_test.go | 2 +- app/publish/handler_test.go | 10 +-- app/publish/testing.go | 150 ++++++++++++++++++++++++++++++- app/sdkrouter/sdkrouter.go | 16 +++- app/sdkrouter/sdkrouter_test.go | 8 +- app/users/remote.go | 15 ++-- app/users/users.go | 108 ++++++++-------------- app/users/users_test.go | 6 +- internal/lbrynet/errors.go | 115 +++++++----------------- internal/lbrynet/errors_test.go | 10 +-- internal/lbrynet/lbrynet.go | 101 ++++++++++----------- internal/lbrynet/lbrynet_test.go | 28 +++--- internal/lbrynet/testing.go | 147 ------------------------------ util/wallet/wallet.go | 10 --- 20 files changed, 327 insertions(+), 470 deletions(-) delete mode 100644 api/handlers.go delete mode 100644 internal/lbrynet/testing.go delete mode 100644 util/wallet/wallet.go diff --git a/Makefile b/Makefile index 7933f321..f6faf3e6 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,10 @@ prepare_test: test: go test -cover ./... +.PHONY: test_race +test_race: + go test -race -gcflags=all=-d=checkptr=0 ./... + .PHONY: test_circleci test_circleci: scripts/wait_for_wallet.sh diff --git a/api/handlers.go b/api/handlers.go deleted file mode 100644 index b3b38f02..00000000 --- a/api/handlers.go +++ /dev/null @@ -1,15 +0,0 @@ -package api - -import ( - "net/http" - - "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/monitor" -) - -var logger = monitor.NewModuleLogger("api") - -// Index serves a blank home page -func Index(w http.ResponseWriter, req *http.Request) { - http.Redirect(w, req, config.GetProjectURL(), http.StatusSeeOther) -} diff --git a/api/routes.go b/api/routes.go index 3e32005d..09d02cd7 100644 --- a/api/routes.go +++ b/api/routes.go @@ -10,6 +10,7 @@ import ( "github.com/lbryio/lbrytv/app/publish" "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/status" @@ -28,7 +29,9 @@ func InstallRoutes(proxyService *proxy.ProxyService, r *mux.Router) { r.Use(methodTimer) - r.HandleFunc("/", Index) + r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + http.Redirect(w, req, config.GetProjectURL(), http.StatusSeeOther) + }) v1Router := r.PathPrefix("/api/v1").Subrouter() v1Router.HandleFunc("/proxy", proxyHandler.HandleOptions).Methods(http.MethodOptions) diff --git a/app/proxy/client.go b/app/proxy/client.go index cb32baca..e293ff75 100644 --- a/app/proxy/client.go +++ b/app/proxy/client.go @@ -41,7 +41,6 @@ func NewClient(endpoint string, wallet string, timeout time.Duration) LbrynetCli func (c Client) Call(q *Query) (*jsonrpc.RPCResponse, error) { var ( - i int r *jsonrpc.RPCResponse err error duration float64 @@ -50,7 +49,7 @@ func (c Client) Call(q *Query) (*jsonrpc.RPCResponse, error) { callMetrics := metrics.ProxyCallDurations.WithLabelValues(q.Method(), c.endpoint) failureMetrics := metrics.ProxyCallFailedDurations.WithLabelValues(q.Method(), c.endpoint) - for i = 0; i < walletLoadRetries; i++ { + for i := 0; i < walletLoadRetries; i++ { start := time.Now() r, err = c.rpcClient.CallRaw(q.Request) @@ -102,21 +101,9 @@ func (c Client) Call(q *Query) (*jsonrpc.RPCResponse, error) { } func (c *Client) isWalletNotLoaded(r *jsonrpc.RPCResponse) bool { - if r.Error != nil { - wErr := lbrynet.NewWalletError(0, errors.New(r.Error.Message)) - if errors.As(wErr, &lbrynet.WalletNotLoaded{}) { - return true - } - } - return false + return r.Error != nil && errors.Is(lbrynet.NewWalletError(0, errors.New(r.Error.Message)), lbrynet.ErrWalletNotLoaded) } func (c *Client) isWalletAlreadyLoaded(r *jsonrpc.RPCResponse) bool { - if r.Error != nil { - wErr := lbrynet.NewWalletError(0, errors.New(r.Error.Message)) - if errors.As(wErr, &lbrynet.WalletAlreadyLoaded{}) { - return true - } - } - return false + return r.Error != nil && errors.Is(lbrynet.NewWalletError(0, errors.New(r.Error.Message)), lbrynet.ErrWalletAlreadyLoaded) } diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index 779b1fba..24f0d4d8 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -8,7 +8,6 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/lbrynet" - "github.com/lbryio/lbrytv/util/wallet" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" @@ -52,23 +51,19 @@ func (c MockRPCClient) CallBatchRaw(requests jsonrpc.RPCRequests) (jsonrpc.RPCRe } func TestClientCallDoesReloadWallet(t *testing.T) { - var ( - r *jsonrpc.RPCResponse - ) - rand.Seed(time.Now().UnixNano()) dummyUserID := rand.Intn(100) rt := sdkrouter.New(config.GetLbrynetServers()) - _, wid, _ := lbrynet.InitializeWallet(rt, dummyUserID) - _, err := lbrynet.WalletRemove(rt, dummyUserID) + wid, _ := lbrynet.InitializeWallet(rt, dummyUserID) + err := lbrynet.UnloadWallet(rt, dummyUserID) require.NoError(t, err) c := NewClient(rt.GetServer(wid).Address, wid, time.Second*1) q, _ := NewQuery(newRawRequest(t, "wallet_balance", nil)) q.SetWalletID(wid) - r, err = c.Call(q) + r, err := c.Call(q) // err = json.Unmarshal(result, response) require.NoError(t, err) @@ -76,17 +71,13 @@ func TestClientCallDoesReloadWallet(t *testing.T) { } func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { - var ( - r *jsonrpc.RPCResponse - ) - rand.Seed(time.Now().UnixNano()) - wid := wallet.MakeID(rand.Intn(100)) + walletID := sdkrouter.WalletID(rand.Intn(100)) mc := NewMockRPCClient() c := &Client{rpcClient: mc} q, _ := NewQuery(newRawRequest(t, "wallet_balance", nil)) - q.SetWalletID(wid) + q.SetWalletID(walletID) mc.AddNextResponse(&jsonrpc.RPCResponse{ JSONRPC: "2.0", @@ -107,12 +98,8 @@ func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { } func TestClientCallDoesNotReloadWalletIfAlreadyLoaded(t *testing.T) { - var ( - r *jsonrpc.RPCResponse - ) - rand.Seed(time.Now().UnixNano()) - wid := wallet.MakeID(rand.Intn(100)) + wid := sdkrouter.WalletID(rand.Intn(100)) mc := NewMockRPCClient() c := &Client{rpcClient: mc} @@ -137,6 +124,7 @@ func TestClientCallDoesNotReloadWalletIfAlreadyLoaded(t *testing.T) { }) r, err := c.Call(q) + require.NoError(t, err) require.Nil(t, r.Error) require.Equal(t, `"99999.00"`, r.Result) diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index 2ab1c6da..0bc65292 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -43,8 +43,8 @@ func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { if err != nil || !methodInList(q.Method(), relaxedMethods) { retriever := users.NewWalletService(rh.SDKRouter) auth := users.NewAuthenticator(retriever) - walletID, err = auth.GetWalletID(r) + walletID, err = auth.GetWalletID(r) if err != nil { responses.JSONRPCError(w, err.Error(), ErrAuthFailed) monitor.CaptureRequestError(err, r, w) diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 19801b77..f7c4b8b1 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -140,7 +140,7 @@ func TestCallerCallWalletBalance(t *testing.T) { dummyUserID := rand.Intn(10^6-10^3) + 10 ^ 3 rt := sdkrouter.New(config.GetLbrynetServers()) - _, wid, err := lbrynet.InitializeWallet(rt, dummyUserID) + wid, err := lbrynet.InitializeWallet(rt, dummyUserID) require.NoError(t, err) svc := NewService(Opts{SDKRouter: rt}) diff --git a/app/publish/handler_test.go b/app/publish/handler_test.go index 1e8a2683..6b0facfa 100644 --- a/app/publish/handler_test.go +++ b/app/publish/handler_test.go @@ -13,8 +13,6 @@ import ( "testing" "github.com/lbryio/lbrytv/app/users" - "github.com/lbryio/lbrytv/internal/lbrynet" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" @@ -32,7 +30,7 @@ func (p *DummyPublisher) Publish(filePath, accountID string, rawQuery []byte) [] p.filePath = filePath p.accountID = accountID p.rawQuery = rawQuery - return []byte(lbrynet.ExampleStreamCreateResponse) + return []byte(expectedStreamCreateResponse) } func TestUploadHandler(t *testing.T) { @@ -50,13 +48,13 @@ func TestUploadHandler(t *testing.T) { respBody, _ := ioutil.ReadAll(response.Body) assert.Equal(t, http.StatusOK, response.StatusCode) - assert.Equal(t, lbrynet.ExampleStreamCreateResponse, string(respBody)) + assert.Equal(t, expectedStreamCreateResponse, string(respBody)) require.True(t, publisher.called) expectedPath := path.Join(os.TempDir(), "UPldrAcc", ".*_lbry_auto_test_file") assert.Regexp(t, expectedPath, publisher.filePath) assert.Equal(t, "UPldrAcc", publisher.accountID) - assert.Equal(t, lbrynet.ExampleStreamCreateRequest, string(publisher.rawQuery)) + assert.Equal(t, expectedStreamCreateRequest, string(publisher.rawQuery)) _, err = os.Stat(publisher.filePath) assert.True(t, os.IsNotExist(err)) @@ -99,7 +97,7 @@ func TestUploadHandlerSystemError(t *testing.T) { jsonPayload, err := writer.CreateFormField(JSONRPCFieldName) require.NoError(t, err) - jsonPayload.Write([]byte(lbrynet.ExampleStreamCreateRequest)) + jsonPayload.Write([]byte(expectedStreamCreateRequest)) // <--- Not calling writer.Close() here to create an unexpected EOF diff --git a/app/publish/testing.go b/app/publish/testing.go index 004f9013..f927c3e2 100644 --- a/app/publish/testing.go +++ b/app/publish/testing.go @@ -7,8 +7,6 @@ import ( "net/http" "testing" - "github.com/lbryio/lbrytv/internal/lbrynet" - "github.com/stretchr/testify/require" ) @@ -26,7 +24,7 @@ func CreatePublishRequest(t *testing.T, data []byte) *http.Request { jsonPayload, err := writer.CreateFormField(JSONRPCFieldName) require.NoError(t, err) - jsonPayload.Write([]byte(lbrynet.ExampleStreamCreateRequest)) + jsonPayload.Write([]byte(expectedStreamCreateRequest)) writer.Close() @@ -36,3 +34,149 @@ func CreatePublishRequest(t *testing.T, data []byte) *http.Request { req.Header.Set("Content-Type", writer.FormDataContentType()) return req } + +var expectedStreamCreateResponse = ` +{ + "jsonrpc": "2.0", + "result": { + "height": -2, + "hex": "0100000001b25ac56e2fda6353b732863e338e205a19d1d2f4e38145048ee501e373fd8585010000006a4730440220205c1cea74188145c8d3200ef2914b5852c8a3b151876c9d9431e9b52e82b3e0022061169e87088e2fd0759d457d0a444a9445d404b64358d5cbac08c5ab950dca6c012103ebc2c0ec16d9e24b5ebcb4bf957ddc9fd7a80376d1cff0d79f5d65e381d7fe42ffffffff0200e1f50500000000fddc01b50b626c616e6b2d696d6167654db1010127876157202060e91daaf771f57c2b78c254f9cb24eda15eb1995dfe4ea874fa93396c62e1fe82612e6b9b786ea0c55166e98e7880da5e3b48ef29ab4d1a9c83f71482c22a4acad548c27a5f5643550d0434f3b00ae6010a82010a306c7df435d412c603390f593ef658c199817c7830ba3f16b7eadd8f99fa50e85dbd0d2b3dc61eadc33fe096e3872d1545120f746d706e6b745f343962712e706e6718632209696d6167652f706e673230eda7090b2d59beb0d77de489961cb73bbc73bbbb80d2c3c0e5f547b8c07dc0eded9627ce12872ca86a20a51d54ae3c4b120650696361736f1a0d5075626c696320446f6d61696e2218687474703a2f2f7075626c69632d646f6d61696e2e6f72672880f1c3ea053222080112196f147b27d1c70b5fb7ff1560d32bfda68507a89a0f214e74e0188087a70e520408051007420b426c616e6b20496d6167654a184120626c616e6b20504e472074686174206973203578372e52252a23687474703a2f2f736d616c6c6d656469612e636f6d2f7468756d626e61696c2e6a70675a05626c616e6b5a03617274620208016a1308ec0112024e481a0a4d616e636865737465726d7576a914147b27d1c70b5fb7ff1560d32bfda68507a89a0f88acac5e7d1d000000001976a914d7d23f1f17bdd156052ea8c496a95070157fb6ab88ac00000000", + "inputs": [ + { + "address": "n4SAW6U5NeYRqQTdos4cLMgtbWRBFW8X16", + "amount": "5.969662", + "confirmations": 2, + "height": 213, + "is_change": true, + "is_mine": true, + "nout": 1, + "timestamp": 1565587608, + "txid": "8585fd73e301e58e044581e3f4d2d1195a208e333e8632b75363da2f6ec55ab2", + "type": "payment" + } + ], + "outputs": [ + { + "address": "mhPFLtT7YzmNfMuQYr4PQXAJdtaTKWRLFy", + "amount": "1.0", + "claim_id": "5cfb92c3e6a80aedee5282c3f64b565bc6965562", + "claim_op": "create", + "confirmations": -2, + "height": -2, + "is_channel_signature_valid": true, + "meta": {}, + "name": "blank-image", + "normalized_name": "blank-image", + "nout": 0, + "permanent_url": "lbry://blank-image#5cfb92c3e6a80aedee5282c3f64b565bc6965562", + "signing_channel": { + "address": "mvE3pR2rH5mP1Hx8UEipnPt3Atp89tXqVw", + "amount": "1.0", + "claim_id": "cbf954c2782b7cf571f7aa1de960202057618727", + "claim_op": "update", + "confirmations": 5, + "height": 210, + "is_change": false, + "is_mine": true, + "meta": {}, + "name": "@channel", + "normalized_name": "@channel", + "nout": 0, + "permanent_url": "lbry://@channel#cbf954c2782b7cf571f7aa1de960202057618727", + "timestamp": 1565587607, + "txid": "794fc94e7ac645d5fc06c14e5ac9be9d9afa53cd540a349ee276662b23e21396", + "type": "claim", + "value": { + "public_key": "3056301006072a8648ce3d020106052b8104000a0342000404b644588c6a32f425fa8c2c3b0404898c79d405d1e90783adcf9a2bdbad505012f1e6be38f7837b69d5f2a1a1959135701780f01fc91c396158c4b1b9b1e304", + "public_key_id": "mrPWGtFam2wwv7D1QRgXXrXePLqUGdKaCb", + "title": "New Channel" + }, + "value_type": "channel" + }, + "timestamp": null, + "txid": "474e26f1aceebbdbbbad02afd37dd39aa3eb221098fa8a4073b1117264422e98", + "type": "claim", + "value": { + "author": "Picaso", + "description": "A blank PNG that is 5x7.", + "fee": { + "address": "mhPFLtT7YzmNfMuQYr4PQXAJdtaTKWRLFy", + "amount": "0.3", + "currency": "LBC" + }, + "image": { + "height": 7, + "width": 5 + }, + "languages": [ + "en" + ], + "license": "Public Domain", + "license_url": "http://public-domain.org", + "locations": [ + { + "city": "Manchester", + "country": "US", + "state": "NH" + } + ], + "release_time": "1565587584", + "source": { + "hash": "6c7df435d412c603390f593ef658c199817c7830ba3f16b7eadd8f99fa50e85dbd0d2b3dc61eadc33fe096e3872d1545", + "media_type": "image/png", + "name": "tmpnkt_49bq.png", + "sd_hash": "eda7090b2d59beb0d77de489961cb73bbc73bbbb80d2c3c0e5f547b8c07dc0eded9627ce12872ca86a20a51d54ae3c4b", + "size": "99" + }, + "stream_type": "image", + "tags": [ + "blank", + "art" + ], + "thumbnail": { + "url": "http://smallmedia.com/thumbnail.jpg" + }, + "title": "Blank Image" + }, + "value_type": "stream" + }, + { + "address": "n1C7SV6XSvTgHK84pMQ23KZLszCsm53T3Q", + "amount": "4.947555", + "confirmations": -2, + "height": -2, + "nout": 1, + "timestamp": null, + "txid": "474e26f1aceebbdbbbad02afd37dd39aa3eb221098fa8a4073b1117264422e98", + "type": "payment" + } + ], + "total_fee": "0.022107", + "total_input": "5.969662", + "total_output": "5.947555", + "txid": "474e26f1aceebbdbbbad02afd37dd39aa3eb221098fa8a4073b1117264422e98" + } + } +` + +var expectedStreamCreateRequest = ` +{ + "jsonrpc": "2.0", + "method": "stream_create", + "params": { + "name": "test", + "title": "test", + "description": "test description", + "bid": "0.10000000", + "languages": [ + "en" + ], + "tags": [], + "thumbnail_url": "http://smallmedia.com/thumbnail.jpg", + "license": "None", + "release_time": 1567580184, + "file_path": "/Users/silence/Desktop/tenor.gif" + }, + "id": 1567580184168 +} + ` diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index 4eaf4afd..ae384833 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -3,6 +3,7 @@ package sdkrouter import ( "database/sql" "errors" + "fmt" "math/rand" "regexp" "sort" @@ -134,8 +135,8 @@ func (r *Router) reloadServersFromDB() { func (r *Router) setServers(servers []*models.LbrynetServer) { if len(servers) == 0 { - logger.Log().Fatal("Setting servers to empty list") - // TODO: fatal? really? maybe just don't update the servers in this case? + logger.Log().Error("Setting servers to empty list") + return } // we do this partially to make sure that ids are assigned to servers more consistently, @@ -219,3 +220,14 @@ func getUserID(walletID string) int { func getServerForUserID(userID, numServers int) int { return userID % numServers } + +func (r *Router) Client(userID int) *ljsonrpc.Client { + c := ljsonrpc.NewClient(r.GetServer(WalletID(userID)).Address) + //c.SetRPCTimeout(5 * time.Second) + return c +} + +// WalletID formats user ID to use as an LbrynetServer wallet ID. +func WalletID(userID int) string { + return fmt.Sprintf("lbrytv-id.%d.wallet", userID) +} diff --git a/app/sdkrouter/sdkrouter_test.go b/app/sdkrouter/sdkrouter_test.go index 0d094c7c..6479b2bc 100644 --- a/app/sdkrouter/sdkrouter_test.go +++ b/app/sdkrouter/sdkrouter_test.go @@ -8,8 +8,6 @@ import ( "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/internal/test" - "github.com/lbryio/lbrytv/util/wallet" - "github.com/stretchr/testify/assert" ) @@ -42,7 +40,7 @@ func TestServerOrder(t *testing.T) { sdkRouter := New(servers) for i := 0; i < 100; i++ { - server := sdkRouter.GetServer(wallet.MakeID(i)).Address + server := sdkRouter.GetServer(WalletID(i)).Address assert.Equal(t, fmt.Sprintf("%d", i%len(servers)), server) } } @@ -51,7 +49,7 @@ func TestOverrideLbrynetDefaultConf(t *testing.T) { address := "http://space.com:1234" config.Override("LbrynetServers", map[string]string{"x": address}) defer config.RestoreOverridden() - server := New(config.GetLbrynetServers()).GetServer(wallet.MakeID(343465345)) + server := New(config.GetLbrynetServers()).GetServer(WalletID(343465345)) assert.Equal(t, address, server.Address) } @@ -60,7 +58,7 @@ func TestOverrideLbrynetConf(t *testing.T) { config.Override("Lbrynet", address) config.Override("LbrynetServers", map[string]string{}) defer config.RestoreOverridden() - server := New(config.GetLbrynetServers()).GetServer(wallet.MakeID(1343465345)) + server := New(config.GetLbrynetServers()).GetServer(WalletID(1343465345)) assert.Equal(t, address, server.Address) } diff --git a/app/users/remote.go b/app/users/remote.go index fd26afa4..ce797326 100644 --- a/app/users/remote.go +++ b/app/users/remote.go @@ -9,14 +9,13 @@ import ( "github.com/lbryio/lbry.go/v2/extras/lbryinc" ) -// RemoteUser encapsulates internal-apis user data -type RemoteUser struct { +// remoteUser encapsulates internal-apis user data +type remoteUser struct { ID int HasVerifiedEmail bool } -func getRemoteUser(token string, remoteIP string) (*RemoteUser, error) { - u := &RemoteUser{} +func getRemoteUser(token string, remoteIP string) (*remoteUser, error) { c := lbryinc.NewClient(token, &lbryinc.ClientOpts{ ServerAddress: config.GetInternalAPIHost(), RemoteIP: remoteIP, @@ -31,9 +30,11 @@ func getRemoteUser(token string, remoteIP string) (*RemoteUser, error) { metrics.IAPIAuthFailedDurations.Observe(duration) return nil, err } + metrics.IAPIAuthSuccessDurations.Observe(duration) - u.ID = int(r["user_id"].(float64)) - u.HasVerifiedEmail = r["has_verified_email"].(bool) - return u, nil + return &remoteUser{ + ID: int(r["user_id"].(float64)), + HasVerifiedEmail: r["has_verified_email"].(bool), + }, nil } diff --git a/app/users/users.go b/app/users/users.go index 15db15b9..9784011c 100644 --- a/app/users/users.go +++ b/app/users/users.go @@ -23,14 +23,8 @@ type WalletService struct { // TokenHeader is the name of HTTP header which is supplied by client and should contain internal-api auth_token. const TokenHeader string = "X-Lbry-Auth-Token" -const idPrefix string = "id:" const errUniqueViolation = "23505" -type savedAccFields struct { - ID string - PublicKey string -} - // Retriever is an interface for user retrieval by internal-apis auth token type Retriever interface { Retrieve(query Query) (*models.User, error) @@ -44,21 +38,14 @@ type Query struct { // NewWalletService returns WalletService instance for retrieving or creating wallet-based user records and accounts. func NewWalletService(r *sdkrouter.Router) *WalletService { - s := &WalletService{Logger: monitor.NewModuleLogger("users"), Router: r} - return s -} - -func (s *WalletService) getDBUser(id int) (*models.User, error) { - return models.Users(models.UserWhere.ID.EQ(id)).OneG() + return &WalletService{Logger: monitor.NewModuleLogger("users"), Router: r} } func (s *WalletService) createDBUser(id int) (*models.User, error) { log := s.Logger.LogF(monitor.F{"id": id}) - u := &models.User{} - u.ID = id + u := &models.User{ID: id} err := u.InsertG(boil.Infer()) - if err != nil { // Check if we encountered a primary key violation, it would mean another routine // fired from another request has managed to create a user before us so we should try retrieving it again. @@ -66,11 +53,7 @@ func (s *WalletService) createDBUser(id int) (*models.User, error) { case *pq.Error: if baseErr.Code == errUniqueViolation && baseErr.Column == "users_pkey" { log.Debug("user creation conflict, trying to retrieve the local user again") - u, retryErr := s.getDBUser(id) - if retryErr != nil { - return nil, retryErr - } - return u, nil + return getDBUser(id) } default: log.Error("unknown error encountered while creating user: ", err) @@ -82,63 +65,38 @@ func (s *WalletService) createDBUser(id int) (*models.User, error) { // Retrieve gets user by internal-apis auth token provided in the supplied Query. func (s *WalletService) Retrieve(q Query) (*models.User, error) { - var ( - localUser *models.User - lbrynetServer *models.LbrynetServer - wid string - ) - token := q.Token - log := s.Logger.LogF(monitor.F{monitor.TokenF: token}) remoteUser, err := getRemoteUser(token, q.MetaRemoteIP) if err != nil { - return nil, s.LogErrorAndReturn(log, "cannot authenticate user with internal-apis: %v", err) + msg := "cannot authenticate user with internal-apis: %v" + log.Errorf(msg, err) + return nil, fmt.Errorf(msg, err) } - - // Update log entry with extra context data - log = s.Logger.LogF(monitor.F{ - monitor.TokenF: token, - "id": remoteUser.ID, - "has_email": remoteUser.HasVerifiedEmail, - }) if !remoteUser.HasVerifiedEmail { return nil, nil } - localUser, errStorage := s.getDBUser(remoteUser.ID) - if errStorage == sql.ErrNoRows { + log.Data["id"] = remoteUser.ID + log.Data["has_email"] = remoteUser.HasVerifiedEmail + + localUser, err := getDBUser(remoteUser.ID) + if err != nil && err != sql.ErrNoRows { + return nil, err + } else if err == sql.ErrNoRows { log.Infof("user not found in the database, creating") localUser, err = s.createDBUser(remoteUser.ID) if err != nil { return nil, err } - - lbrynetServer, wid, err = s.createWallet(localUser) - if err != nil { - return nil, err - } - - err := s.postCreateUpdate(localUser, lbrynetServer, wid) - if err != nil { - return nil, err - } - - log.Data["wallet_id"] = wid - } else if errStorage != nil { - return nil, errStorage + } else if localUser.WalletID == "" { + // This scenario may happen for legacy users who are present in the database but don't have a wallet yet + log.Warnf("user %d doesn't have wallet ID set", localUser.ID) } - // This scenario may happen for legacy users who are present in the database but don't have a wallet yet if localUser.WalletID == "" { - log.Warn("user doesn't have wallet ID set") - lbrynetServer, wid, err = s.createWallet(localUser) - if err != nil { - return nil, err - } - - err := s.postCreateUpdate(localUser, lbrynetServer, wid) + err := createWalletForUser(localUser, s.Router, log) if err != nil { return nil, err } @@ -147,24 +105,28 @@ func (s *WalletService) Retrieve(q Query) (*models.User, error) { return localUser, nil } -func (s *WalletService) createWallet(u *models.User) (*models.LbrynetServer, string, error) { - return lbrynet.InitializeWallet(s.Router, u.ID) -} +// TODO: this is the function where users are assigned to SDKs. assign them randomly +func createWalletForUser(user *models.User, router *sdkrouter.Router, log *logrus.Entry) error { + // either a new user or a legacy user without a wallet + walletID, err := lbrynet.InitializeWallet(router, user.ID) + if err != nil { + return err + } -func (s *WalletService) postCreateUpdate(u *models.User, server *models.LbrynetServer, wid string) error { - s.Logger.LogF(monitor.F{"id": u.ID, "wallet_id": wid}).Info("saving wallet ID to user record") - u.WalletID = wid - if server.ID > 0 { //Ensure server is from DB - u.LbrynetServerID.SetValid(server.ID) + log.Data["wallet_id"] = walletID + log.Info("saving wallet ID to user record") + + user.WalletID = walletID + + server := router.GetServer(sdkrouter.WalletID(user.ID)) + if server.ID > 0 { // Ensure server is from DB + user.LbrynetServerID.SetValid(server.ID) } - _, err := u.UpdateG(boil.Infer()) + _, err = user.UpdateG(boil.Infer()) return err } -// LogErrorAndReturn logs error with rich context and returns an error object -// so it can be returned from the function -func (s *WalletService) LogErrorAndReturn(log *logrus.Entry, message string, a ...interface{}) error { - log.Errorf(message, a...) - return fmt.Errorf(message, a...) +func getDBUser(id int) (*models.User, error) { + return models.Users(models.UserWhere.ID.EQ(id)).OneG() } diff --git a/app/users/users_test.go b/app/users/users_test.go index 22577e08..7ee64d27 100644 --- a/app/users/users_test.go +++ b/app/users/users_test.go @@ -14,8 +14,6 @@ import ( "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/models" - "github.com/lbryio/lbrytv/util/wallet" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -58,7 +56,7 @@ func setupCleanupDummyUser(rt *sdkrouter.Router, uidParam ...int) func() { return func() { ts.Close() config.RestoreOverridden() - lbrynet.WalletRemove(rt, uid) + lbrynet.UnloadWallet(rt, uid) } } @@ -67,7 +65,7 @@ func TestWalletServiceRetrieveNewUser(t *testing.T) { setupDBTables() defer setupCleanupDummyUser(rt)() - wid := wallet.MakeID(dummyUserID) + wid := sdkrouter.WalletID(dummyUserID) svc := NewWalletService(rt) u, err := svc.Retrieve(Query{Token: "abc"}) require.NoError(t, err, errors.Unwrap(err)) diff --git a/internal/lbrynet/errors.go b/internal/lbrynet/errors.go index dec2f222..13fcdccc 100644 --- a/internal/lbrynet/errors.go +++ b/internal/lbrynet/errors.go @@ -1,101 +1,48 @@ package lbrynet import ( + "errors" "fmt" "regexp" ) -type AccountNotFound struct { - UID int - Err error -} - -type AccountConflict struct { - UID int - Err error -} - type WalletError struct { - error - UID int - Err error -} - -type WalletExists struct { - WalletError -} - -type WalletNeedsLoading struct { - WalletError -} - -type WalletAlreadyLoaded struct { - WalletError -} - -type WalletNotFound struct { - WalletError -} - -type WalletNotLoaded struct { - WalletError -} - -func (e AccountNotFound) Error() string { - return fmt.Sprintf("couldn't find account for %v in lbrynet", e.UID) -} - -func (e AccountConflict) Error() string { - return fmt.Sprintf("account for %v already registered with lbrynet", e.UID) -} - -// Workaround for non-existent SDK error codes -var reWalletExists = regexp.MustCompile(`Wallet at path .+ already exists and is loaded`) -var reWalletNeedsLoading = regexp.MustCompile(`Wallet at path .+ already exists, use 'wallet_add' to load wallet`) -var reWalletAlreadyLoaded = regexp.MustCompile(`Wallet at path .+ is already loaded`) -var reWalletNotFound = regexp.MustCompile(`Wallet at path .+ was not found`) -var reWalletNotLoaded = regexp.MustCompile(`Couldn't find wallet:`) + UserID int + Err error +} + +func (e WalletError) Error() string { return fmt.Sprintf("user %d: %s", e.UserID, e.Err.Error()) } +func (e WalletError) Unwrap() error { return e.Err } + +var ( + ErrWalletNotFound = errors.New("wallet not found") + ErrWalletExists = errors.New("wallet exists and is loaded") + ErrWalletNeedsLoading = errors.New("wallet exists and needs to be loaded") + ErrWalletNotLoaded = errors.New("wallet is not loaded") + ErrWalletAlreadyLoaded = errors.New("wallet is already loaded") + + // Workaround for non-existent SDK error codes + reWalletNotFound = regexp.MustCompile(`Wallet at path .+ was not found`) + reWalletExists = regexp.MustCompile(`Wallet at path .+ already exists and is loaded`) + reWalletNeedsLoading = regexp.MustCompile(`Wallet at path .+ already exists, use 'wallet_add' to load wallet`) + reWalletNotLoaded = regexp.MustCompile(`Couldn't find wallet:`) + reWalletAlreadyLoaded = regexp.MustCompile(`Wallet at path .+ is already loaded`) +) // NewWalletError converts plain SDK error to the typed one -func NewWalletError(uid int, err error) error { - wErr := WalletError{UID: uid, Err: err} - +func NewWalletError(userID int, err error) error { switch { + case reWalletNotFound.MatchString(err.Error()): + return WalletError{UserID: userID, Err: ErrWalletNotFound} case reWalletExists.MatchString(err.Error()): - return WalletExists{wErr} + return WalletError{UserID: userID, Err: ErrWalletExists} case reWalletNeedsLoading.MatchString(err.Error()): - return WalletNeedsLoading{wErr} - case reWalletAlreadyLoaded.MatchString(err.Error()): - return WalletAlreadyLoaded{wErr} - case reWalletNotFound.MatchString(err.Error()): - return WalletNotFound{wErr} + return WalletError{UserID: userID, Err: ErrWalletNeedsLoading} case reWalletNotLoaded.MatchString(err.Error()): - return WalletNotLoaded{wErr} + return WalletError{UserID: userID, Err: ErrWalletNotLoaded} + case reWalletAlreadyLoaded.MatchString(err.Error()): + return WalletError{UserID: userID, Err: ErrWalletAlreadyLoaded} default: - return wErr + return WalletError{UserID: userID, Err: err} } } - -func (e WalletError) Unwrap() error { - return e.Err -} - -func (e WalletError) Error() string { - return fmt.Sprintf("unknown wallet error: %v", e.Unwrap()) -} - -func (e WalletExists) Error() string { - return "wallet is already loaded" -} - -func (e WalletNeedsLoading) Error() string { - return "wallet already exists but is not loaded" -} - -func (e WalletAlreadyLoaded) Error() string { - return "wallet is already loaded" -} - -func (e WalletNotLoaded) Error() string { - return "wallet not found" -} diff --git a/internal/lbrynet/errors_test.go b/internal/lbrynet/errors_test.go index 8ad6c255..fa5bb7f5 100644 --- a/internal/lbrynet/errors_test.go +++ b/internal/lbrynet/errors_test.go @@ -2,17 +2,15 @@ package lbrynet import ( "errors" - "fmt" "testing" "github.com/stretchr/testify/assert" ) func TestWalletAlreadyLoaded(t *testing.T) { - origErr := fmt.Errorf("Wallet at path /tmp/123 is already loaded") - walletErr := &WalletAlreadyLoaded{} - err := NewWalletError(123, origErr) - + walletErr := &WalletError{} + err := NewWalletError(123, errors.New("Wallet at path /tmp/123 is already loaded")) + assert.True(t, errors.Is(err, ErrWalletAlreadyLoaded)) assert.True(t, errors.As(err, walletErr)) - assert.Equal(t, 123, walletErr.UID) + assert.Equal(t, 123, walletErr.UserID) } diff --git a/internal/lbrynet/lbrynet.go b/internal/lbrynet/lbrynet.go index 16ac1fd6..6795812d 100644 --- a/internal/lbrynet/lbrynet.go +++ b/internal/lbrynet/lbrynet.go @@ -5,47 +5,47 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/internal/monitor" - "github.com/lbryio/lbrytv/models" - "github.com/lbryio/lbrytv/util/wallet" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" ) -const accountNamePrefix = "lbrytv-user-id:" -const accountNameTemplate = accountNamePrefix + "%v" - -var defaultWalletOpts = ljsonrpc.WalletCreateOpts{SkipOnStartup: true, CreateAccount: true, SingleKey: true} - var Logger = monitor.NewModuleLogger("lbrynet") -// InitializeWallet creates a wallet that can be immediately used -// in subsequent commands. +// InitializeWallet creates a wallet that can be immediately used in subsequent commands. // It can recover from errors like existing wallets, but if a wallet is known to exist -// (eg. a wallet ID stored in the database already), AddWallet should be called instead. -func InitializeWallet(rt *sdkrouter.Router, uid int) (*models.LbrynetServer, string, error) { - wid := wallet.MakeID(uid) - log := Logger.LogF(monitor.F{"wallet_id": wid, "user_id": uid}) - wallet, lbrynetServer, err := CreateWallet(rt, uid) - if err != nil { - if errors.As(err, &WalletExists{}) { - log.Warn(err.Error()) - return lbrynetServer, wid, nil - } else if errors.As(err, &WalletNeedsLoading{}) { - log.Info(err.Error()) - wallet, err = AddWallet(rt, uid) - if err != nil && errors.As(err, &WalletAlreadyLoaded{}) { +// (eg. a wallet ID stored in the database already), loadWallet() should be called instead. +func InitializeWallet(rt *sdkrouter.Router, userID int) (string, error) { + w, err := createWallet(rt, userID) + if err == nil { + return w.ID, nil + } + + walletID := sdkrouter.WalletID(userID) + log := Logger.LogF(monitor.F{"user_id": userID}) + + if errors.Is(err, ErrWalletExists) { + log.Warn(err.Error()) + return walletID, nil + } + + if errors.Is(err, ErrWalletNeedsLoading) { + log.Info(err.Error()) + w, err = loadWallet(rt, userID) + if err != nil { + if errors.Is(err, ErrWalletAlreadyLoaded) { log.Info(err.Error()) - return lbrynetServer, wid, nil + return walletID, nil } - } else { - log.Error("don't know how to recover from error: ", err) - return lbrynetServer, "", err + return "", err } + return w.ID, nil } - return lbrynetServer, wallet.ID, nil + + log.Errorf("don't know how to recover from error: %v", err) + return "", err } -// CreateWallet creates a new wallet with the LbrynetServer. +// createWallet creates a new wallet on the LbrynetServer. // Returned error doesn't necessarily mean that the wallet is not operational: // // if errors.Is(err, lbrynet.WalletExists) { @@ -53,49 +53,40 @@ func InitializeWallet(rt *sdkrouter.Router, uid int) (*models.LbrynetServer, str // } // // if errors.Is(err, lbrynet.WalletNeedsLoading) { -// // AddWallet() needs to be called before the wallet can be used +// // loadWallet() needs to be called before the wallet can be used // } -func CreateWallet(rt *sdkrouter.Router, uid int) (*ljsonrpc.Wallet, *models.LbrynetServer, error) { - wid := wallet.MakeID(uid) - log := Logger.LogF(monitor.F{"wallet_id": wid, "user_id": uid}) - lbrynetServer := rt.GetServer(wid) - client := ljsonrpc.NewClient(lbrynetServer.Address) - wallet, err := client.WalletCreate(wid, &defaultWalletOpts) +func createWallet(rt *sdkrouter.Router, userID int) (*ljsonrpc.Wallet, error) { + wallet, err := rt.Client(userID).WalletCreate(sdkrouter.WalletID(userID), &ljsonrpc.WalletCreateOpts{ + SkipOnStartup: true, CreateAccount: true, SingleKey: true}) if err != nil { - return nil, lbrynetServer, NewWalletError(uid, err) + return nil, NewWalletError(userID, err) } - log.Info("wallet created") - return wallet, lbrynetServer, nil + Logger.LogF(monitor.F{"user_id": userID}).Info("wallet created") + return wallet, nil } -// AddWallet loads an existing wallet in the LbrynetServer. +// loadWallet loads an existing wallet in the LbrynetServer. // May return errors: // WalletAlreadyLoaded - wallet is already loaded and operational // WalletNotFound - wallet file does not exist and won't be loaded. -func AddWallet(rt *sdkrouter.Router, uid int) (*ljsonrpc.Wallet, error) { - wid := wallet.MakeID(uid) - log := Logger.LogF(monitor.F{"wallet_id": wid, "user_id": uid}) - client := ljsonrpc.NewClient(rt.GetServer(wid).Address) - wallet, err := client.WalletAdd(wid) +func loadWallet(rt *sdkrouter.Router, userID int) (*ljsonrpc.Wallet, error) { + wallet, err := rt.Client(userID).WalletAdd(sdkrouter.WalletID(userID)) if err != nil { - return nil, NewWalletError(uid, err) + return nil, NewWalletError(userID, err) } - log.Info("wallet loaded") + Logger.LogF(monitor.F{"user_id": userID}).Info("wallet loaded") return wallet, nil } -// WalletRemove loads an existing wallet in the LbrynetServer. +// UnloadWallet unloads an existing wallet from the LbrynetServer. // May return errors: // WalletAlreadyLoaded - wallet is already loaded and operational // WalletNotFound - wallet file does not exist and won't be loaded. -func WalletRemove(rt *sdkrouter.Router, uid int) (*ljsonrpc.Wallet, error) { - wid := wallet.MakeID(uid) - log := Logger.LogF(monitor.F{"wallet_id": wid, "user_id": uid}) - client := ljsonrpc.NewClient(rt.GetServer(wid).Address) - wallet, err := client.WalletRemove(wid) +func UnloadWallet(rt *sdkrouter.Router, userID int) error { + _, err := rt.Client(userID).WalletRemove(sdkrouter.WalletID(userID)) if err != nil { - return nil, NewWalletError(uid, err) + return NewWalletError(userID, err) } - log.Info("wallet removed") - return wallet, nil + Logger.LogF(monitor.F{"user_id": userID}).Info("wallet unloaded") + return nil } diff --git a/internal/lbrynet/lbrynet_test.go b/internal/lbrynet/lbrynet_test.go index adb7381b..473c198d 100644 --- a/internal/lbrynet/lbrynet_test.go +++ b/internal/lbrynet/lbrynet_test.go @@ -9,8 +9,6 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/util/wallet" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,34 +23,34 @@ func TestInitializeWallet(t *testing.T) { uid := rand.Int() r := sdkrouter.New(config.GetLbrynetServers()) - _, wid, err := InitializeWallet(r, uid) + wid, err := InitializeWallet(r, uid) require.NoError(t, err) - assert.Equal(t, wid, wallet.MakeID(uid)) + assert.Equal(t, wid, sdkrouter.WalletID(uid)) - _, err = WalletRemove(r, uid) + err = UnloadWallet(r, uid) require.NoError(t, err) - _, wid, err = InitializeWallet(r, uid) + wid, err = InitializeWallet(r, uid) require.NoError(t, err) - assert.Equal(t, wid, wallet.MakeID(uid)) + assert.Equal(t, wid, sdkrouter.WalletID(uid)) } -func TestCreateWalletAddWallet(t *testing.T) { +func TestCreateWalletLoadWallet(t *testing.T) { uid := rand.Int() r := sdkrouter.New(config.GetLbrynetServers()) - w, _, err := CreateWallet(r, uid) + w, err := createWallet(r, uid) require.NoError(t, err) - assert.Equal(t, w.ID, wallet.MakeID(uid)) + assert.Equal(t, w.ID, sdkrouter.WalletID(uid)) - _, _, err = CreateWallet(r, uid) + _, err = createWallet(r, uid) require.NotNil(t, err) - assert.True(t, errors.As(err, &WalletExists{})) + assert.True(t, errors.Is(err, ErrWalletExists)) - _, err = WalletRemove(r, uid) + err = UnloadWallet(r, uid) require.NoError(t, err) - w, err = AddWallet(r, uid) + w, err = loadWallet(r, uid) require.NoError(t, err) - assert.Equal(t, w.ID, wallet.MakeID(uid)) + assert.Equal(t, w.ID, sdkrouter.WalletID(uid)) } diff --git a/internal/lbrynet/testing.go b/internal/lbrynet/testing.go deleted file mode 100644 index 4f61ceeb..00000000 --- a/internal/lbrynet/testing.go +++ /dev/null @@ -1,147 +0,0 @@ -package lbrynet - -var ExampleStreamCreateResponse = ` -{ - "jsonrpc": "2.0", - "result": { - "height": -2, - "hex": "0100000001b25ac56e2fda6353b732863e338e205a19d1d2f4e38145048ee501e373fd8585010000006a4730440220205c1cea74188145c8d3200ef2914b5852c8a3b151876c9d9431e9b52e82b3e0022061169e87088e2fd0759d457d0a444a9445d404b64358d5cbac08c5ab950dca6c012103ebc2c0ec16d9e24b5ebcb4bf957ddc9fd7a80376d1cff0d79f5d65e381d7fe42ffffffff0200e1f50500000000fddc01b50b626c616e6b2d696d6167654db1010127876157202060e91daaf771f57c2b78c254f9cb24eda15eb1995dfe4ea874fa93396c62e1fe82612e6b9b786ea0c55166e98e7880da5e3b48ef29ab4d1a9c83f71482c22a4acad548c27a5f5643550d0434f3b00ae6010a82010a306c7df435d412c603390f593ef658c199817c7830ba3f16b7eadd8f99fa50e85dbd0d2b3dc61eadc33fe096e3872d1545120f746d706e6b745f343962712e706e6718632209696d6167652f706e673230eda7090b2d59beb0d77de489961cb73bbc73bbbb80d2c3c0e5f547b8c07dc0eded9627ce12872ca86a20a51d54ae3c4b120650696361736f1a0d5075626c696320446f6d61696e2218687474703a2f2f7075626c69632d646f6d61696e2e6f72672880f1c3ea053222080112196f147b27d1c70b5fb7ff1560d32bfda68507a89a0f214e74e0188087a70e520408051007420b426c616e6b20496d6167654a184120626c616e6b20504e472074686174206973203578372e52252a23687474703a2f2f736d616c6c6d656469612e636f6d2f7468756d626e61696c2e6a70675a05626c616e6b5a03617274620208016a1308ec0112024e481a0a4d616e636865737465726d7576a914147b27d1c70b5fb7ff1560d32bfda68507a89a0f88acac5e7d1d000000001976a914d7d23f1f17bdd156052ea8c496a95070157fb6ab88ac00000000", - "inputs": [ - { - "address": "n4SAW6U5NeYRqQTdos4cLMgtbWRBFW8X16", - "amount": "5.969662", - "confirmations": 2, - "height": 213, - "is_change": true, - "is_mine": true, - "nout": 1, - "timestamp": 1565587608, - "txid": "8585fd73e301e58e044581e3f4d2d1195a208e333e8632b75363da2f6ec55ab2", - "type": "payment" - } - ], - "outputs": [ - { - "address": "mhPFLtT7YzmNfMuQYr4PQXAJdtaTKWRLFy", - "amount": "1.0", - "claim_id": "5cfb92c3e6a80aedee5282c3f64b565bc6965562", - "claim_op": "create", - "confirmations": -2, - "height": -2, - "is_channel_signature_valid": true, - "meta": {}, - "name": "blank-image", - "normalized_name": "blank-image", - "nout": 0, - "permanent_url": "lbry://blank-image#5cfb92c3e6a80aedee5282c3f64b565bc6965562", - "signing_channel": { - "address": "mvE3pR2rH5mP1Hx8UEipnPt3Atp89tXqVw", - "amount": "1.0", - "claim_id": "cbf954c2782b7cf571f7aa1de960202057618727", - "claim_op": "update", - "confirmations": 5, - "height": 210, - "is_change": false, - "is_mine": true, - "meta": {}, - "name": "@channel", - "normalized_name": "@channel", - "nout": 0, - "permanent_url": "lbry://@channel#cbf954c2782b7cf571f7aa1de960202057618727", - "timestamp": 1565587607, - "txid": "794fc94e7ac645d5fc06c14e5ac9be9d9afa53cd540a349ee276662b23e21396", - "type": "claim", - "value": { - "public_key": "3056301006072a8648ce3d020106052b8104000a0342000404b644588c6a32f425fa8c2c3b0404898c79d405d1e90783adcf9a2bdbad505012f1e6be38f7837b69d5f2a1a1959135701780f01fc91c396158c4b1b9b1e304", - "public_key_id": "mrPWGtFam2wwv7D1QRgXXrXePLqUGdKaCb", - "title": "New Channel" - }, - "value_type": "channel" - }, - "timestamp": null, - "txid": "474e26f1aceebbdbbbad02afd37dd39aa3eb221098fa8a4073b1117264422e98", - "type": "claim", - "value": { - "author": "Picaso", - "description": "A blank PNG that is 5x7.", - "fee": { - "address": "mhPFLtT7YzmNfMuQYr4PQXAJdtaTKWRLFy", - "amount": "0.3", - "currency": "LBC" - }, - "image": { - "height": 7, - "width": 5 - }, - "languages": [ - "en" - ], - "license": "Public Domain", - "license_url": "http://public-domain.org", - "locations": [ - { - "city": "Manchester", - "country": "US", - "state": "NH" - } - ], - "release_time": "1565587584", - "source": { - "hash": "6c7df435d412c603390f593ef658c199817c7830ba3f16b7eadd8f99fa50e85dbd0d2b3dc61eadc33fe096e3872d1545", - "media_type": "image/png", - "name": "tmpnkt_49bq.png", - "sd_hash": "eda7090b2d59beb0d77de489961cb73bbc73bbbb80d2c3c0e5f547b8c07dc0eded9627ce12872ca86a20a51d54ae3c4b", - "size": "99" - }, - "stream_type": "image", - "tags": [ - "blank", - "art" - ], - "thumbnail": { - "url": "http://smallmedia.com/thumbnail.jpg" - }, - "title": "Blank Image" - }, - "value_type": "stream" - }, - { - "address": "n1C7SV6XSvTgHK84pMQ23KZLszCsm53T3Q", - "amount": "4.947555", - "confirmations": -2, - "height": -2, - "nout": 1, - "timestamp": null, - "txid": "474e26f1aceebbdbbbad02afd37dd39aa3eb221098fa8a4073b1117264422e98", - "type": "payment" - } - ], - "total_fee": "0.022107", - "total_input": "5.969662", - "total_output": "5.947555", - "txid": "474e26f1aceebbdbbbad02afd37dd39aa3eb221098fa8a4073b1117264422e98" - } - } -` - -var ExampleStreamCreateRequest = ` -{ - "jsonrpc": "2.0", - "method": "stream_create", - "params": { - "name": "test", - "title": "test", - "description": "test description", - "bid": "0.10000000", - "languages": [ - "en" - ], - "tags": [], - "thumbnail_url": "http://smallmedia.com/thumbnail.jpg", - "license": "None", - "release_time": 1567580184, - "file_path": "/Users/silence/Desktop/tenor.gif" - }, - "id": 1567580184168 -} - ` diff --git a/util/wallet/wallet.go b/util/wallet/wallet.go deleted file mode 100644 index c58ad64b..00000000 --- a/util/wallet/wallet.go +++ /dev/null @@ -1,10 +0,0 @@ -package wallet - -import "fmt" - -const walletNameTemplate string = "lbrytv-id.%v.wallet" - -// MakeID formats user ID to use as an LbrynetServer wallet ID. -func MakeID(uid int) string { - return fmt.Sprintf(walletNameTemplate, uid) -} From 004771046471e987e43d45c4f2d871fc1c23bd03 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Sun, 12 Apr 2020 19:35:50 -0400 Subject: [PATCH 02/18] get rid of ghetto mock rpc server in tests so we can stop using Caller interface --- api/benchmarks_test.go | 5 +- api/routes.go | 2 +- api/routes_test.go | 6 +- app/proxy/accounts_test.go | 2 +- app/proxy/client.go | 24 ++- app/proxy/client_test.go | 53 +++--- app/proxy/handlers.go | 16 +- app/proxy/main_test.go | 4 +- app/proxy/proxy.go | 248 ++++++---------------------- app/proxy/proxy_test.go | 181 ++++++++++---------- app/proxy/query.go | 154 +++++++++++++++++ app/publish/handler_test.go | 2 +- app/publish/publish.go | 10 +- app/publish/publish_test.go | 2 +- app/sdkrouter/concurrency_test.go | 14 +- app/sdkrouter/sdkrouter.go | 122 +++++++++++--- app/sdkrouter/sdkrouter_test.go | 70 ++++++-- app/users/users.go | 6 +- app/users/users_test.go | 5 +- cmd/serve.go | 2 +- internal/environment/environment.go | 4 +- internal/lbrynet/lbrynet.go | 92 ----------- internal/lbrynet/lbrynet_test.go | 56 ------- internal/metrics/routes_test.go | 2 +- internal/monitor/sentry.go | 2 +- internal/test/test.go | 76 ++++++--- internal/test/test_test.go | 6 +- server/server.go | 4 +- server/server_test.go | 4 +- 29 files changed, 587 insertions(+), 587 deletions(-) create mode 100644 app/proxy/query.go delete mode 100644 internal/lbrynet/lbrynet.go delete mode 100644 internal/lbrynet/lbrynet_test.go diff --git a/api/benchmarks_test.go b/api/benchmarks_test.go index 526ceb10..aaa418ea 100644 --- a/api/benchmarks_test.go +++ b/api/benchmarks_test.go @@ -18,7 +18,6 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/users" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/models" @@ -98,7 +97,7 @@ func BenchmarkWalletCommands(b *testing.B) { svc := users.NewWalletService(rt) svc.Logger.Disable() - lbrynet.Logger.Disable() + sdkrouter.DisableLogger() log.SetOutput(ioutil.Discard) rand.Seed(time.Now().UnixNano()) @@ -111,7 +110,7 @@ func BenchmarkWalletCommands(b *testing.B) { wallets[i] = u } - handler := proxy.NewRequestHandler(proxy.NewService(proxy.Opts{SDKRouter: rt})) + handler := proxy.NewRequestHandler(proxy.NewService(rt)) b.SetParallelism(30) b.ResetTimer() diff --git a/api/routes.go b/api/routes.go index 09d02cd7..e0becb03 100644 --- a/api/routes.go +++ b/api/routes.go @@ -19,7 +19,7 @@ import ( ) // InstallRoutes sets up global API handlers -func InstallRoutes(proxyService *proxy.ProxyService, r *mux.Router) { +func InstallRoutes(proxyService *proxy.Service, r *mux.Router) { authenticator := users.NewAuthenticator(users.NewWalletService(proxyService.SDKRouter)) proxyHandler := proxy.NewRequestHandler(proxyService) upHandler, err := publish.NewUploadHandler(publish.UploadOpts{ProxyService: proxyService}) diff --git a/api/routes_test.go b/api/routes_test.go index ddb21bce..63a663a9 100644 --- a/api/routes_test.go +++ b/api/routes_test.go @@ -18,7 +18,7 @@ import ( func TestRoutesProxy(t *testing.T) { r := mux.NewRouter() - proxy := proxy.NewService(proxy.Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) + proxy := proxy.NewService(sdkrouter.New(config.GetLbrynetServers())) req, err := http.NewRequest("POST", "/api/v1/proxy", bytes.NewBuffer([]byte(`{"method": "status"}`))) require.NoError(t, err) @@ -33,7 +33,7 @@ func TestRoutesProxy(t *testing.T) { func TestRoutesPublish(t *testing.T) { r := mux.NewRouter() - proxy := proxy.NewService(proxy.Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) + proxy := proxy.NewService(sdkrouter.New(config.GetLbrynetServers())) req := publish.CreatePublishRequest(t, []byte("test file")) rr := httptest.NewRecorder() @@ -49,7 +49,7 @@ func TestRoutesPublish(t *testing.T) { func TestRoutesOptions(t *testing.T) { r := mux.NewRouter() - proxy := proxy.NewService(proxy.Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) + proxy := proxy.NewService(sdkrouter.New(config.GetLbrynetServers())) req, err := http.NewRequest("OPTIONS", "/api/v1/proxy", nil) require.NoError(t, err) diff --git a/app/proxy/accounts_test.go b/app/proxy/accounts_test.go index daa292cd..ee60b6a3 100644 --- a/app/proxy/accounts_test.go +++ b/app/proxy/accounts_test.go @@ -101,5 +101,5 @@ func TestAccountSpecificWithoutToken(t *testing.T) { err := json.Unmarshal(rr.Body.Bytes(), &response) require.NoError(t, err) require.NotNil(t, response.Error) - require.Equal(t, "account identificator required", response.Error.Message) + require.Equal(t, "account identifier required", response.Error.Message) } diff --git a/app/proxy/client.go b/app/proxy/client.go index e293ff75..35773a3d 100644 --- a/app/proxy/client.go +++ b/app/proxy/client.go @@ -19,23 +19,20 @@ const walletLoadRetryWait = time.Millisecond * 100 var ClientLogger = monitor.NewModuleLogger("proxy_client") -type LbrynetClient interface { - Call(q *Query) (*jsonrpc.RPCResponse, error) -} - type Client struct { rpcClient jsonrpc.RPCClient endpoint string - wallet string + walletID string retries int } -func NewClient(endpoint string, wallet string, timeout time.Duration) LbrynetClient { +func NewClient(endpoint string, walletID string, timeout time.Duration) Client { return Client{ endpoint: endpoint, rpcClient: jsonrpc.NewClientWithOpts(endpoint, &jsonrpc.RPCClientOpts{ - HTTPClient: &http.Client{Timeout: time.Second * timeout}}), - wallet: wallet, + HTTPClient: &http.Client{Timeout: timeout}, + }), + walletID: walletID, } } @@ -69,16 +66,17 @@ func (c Client) Call(q *Query) (*jsonrpc.RPCResponse, error) { time.Sleep(walletLoadRetryWait) // Using LBRY JSON-RPC client here for easier request/response processing client := ljsonrpc.NewClient(c.endpoint) - _, err := client.WalletAdd(c.wallet) + _, err := client.WalletAdd(c.walletID) // Alert sentry on the last failed wallet load attempt if err != nil && i >= walletLoadRetries-1 { errMsg := "gave up on manually adding a wallet: %v" ClientLogger.WithFields(monitor.F{ - "wallet_id": c.wallet, "endpoint": c.endpoint, + "wallet_id": c.walletID, + "endpoint": c.endpoint, }).Errorf(errMsg, err) monitor.CaptureException( fmt.Errorf(errMsg, err), map[string]string{ - "wallet_id": c.wallet, + "wallet_id": c.walletID, "endpoint": c.endpoint, "retries": fmt.Sprintf("%v", i), }) @@ -91,10 +89,10 @@ func (c Client) Call(q *Query) (*jsonrpc.RPCResponse, error) { } if (r != nil && r.Error != nil) || err != nil { - Logger.LogFailedQuery(q.Method(), c.endpoint, c.wallet, duration, q.Params(), r.Error) + Logger.LogFailedQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r.Error) failureMetrics.Observe(duration) } else { - Logger.LogSuccessfulQuery(q.Method(), c.endpoint, c.wallet, duration, q.Params(), r) + Logger.LogSuccessfulQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r) } return r, err diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index 24f0d4d8..76c1b3cc 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -7,8 +7,7 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/lbrynet" - + "github.com/lbryio/lbrytv/internal/test" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" ) @@ -55,14 +54,15 @@ func TestClientCallDoesReloadWallet(t *testing.T) { dummyUserID := rand.Intn(100) rt := sdkrouter.New(config.GetLbrynetServers()) - wid, _ := lbrynet.InitializeWallet(rt, dummyUserID) - err := lbrynet.UnloadWallet(rt, dummyUserID) + walletID, err := rt.InitializeWallet(dummyUserID) + require.NoError(t, err) + err = rt.UnloadWallet(dummyUserID) require.NoError(t, err) - c := NewClient(rt.GetServer(wid).Address, wid, time.Second*1) + c := NewClient(rt.GetServer(dummyUserID).Address, walletID, 1*time.Second) q, _ := NewQuery(newRawRequest(t, "wallet_balance", nil)) - q.SetWalletID(wid) + q.SetWalletID(walletID) r, err := c.Call(q) // err = json.Unmarshal(result, response) @@ -74,23 +74,30 @@ func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { rand.Seed(time.Now().UnixNano()) walletID := sdkrouter.WalletID(rand.Intn(100)) - mc := NewMockRPCClient() - c := &Client{rpcClient: mc} - q, _ := NewQuery(newRawRequest(t, "wallet_balance", nil)) + srv := test.MockHTTPServer(nil) + defer srv.Close() + + c := NewClient(srv.URL, "", 0) + q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) + require.NoError(t, err) q.SetWalletID(walletID) - mc.AddNextResponse(&jsonrpc.RPCResponse{ - JSONRPC: "2.0", - Error: &jsonrpc.RPCError{ - Message: "Couldn't find wallet: //", - }, - }) - mc.AddNextResponse(&jsonrpc.RPCResponse{ - JSONRPC: "2.0", - Error: &jsonrpc.RPCError{ - Message: "Wallet at path // was not found", - }, - }) + go func() { + srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ + JSONRPC: "2.0", + Error: &jsonrpc.RPCError{ + Message: "Couldn't find wallet: //", + }, + }) + srv.NextResponse <- "" // for the wallet_add call + srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ + JSONRPC: "2.0", + Error: &jsonrpc.RPCError{ + Message: "Wallet at path // was not found", + }, + }) + srv.NoMoreResponses() + }() r, err := c.Call(q) require.NoError(t, err) @@ -99,12 +106,12 @@ func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { func TestClientCallDoesNotReloadWalletIfAlreadyLoaded(t *testing.T) { rand.Seed(time.Now().UnixNano()) - wid := sdkrouter.WalletID(rand.Intn(100)) + walletID := sdkrouter.WalletID(rand.Intn(100)) mc := NewMockRPCClient() c := &Client{rpcClient: mc} q, _ := NewQuery(newRawRequest(t, "wallet_balance", nil)) - q.SetWalletID(wid) + q.SetWalletID(walletID) mc.AddNextResponse(&jsonrpc.RPCResponse{ JSONRPC: "2.0", diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index 0bc65292..6df640a0 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -9,16 +9,16 @@ import ( "github.com/lbryio/lbrytv/internal/responses" ) -var logger = monitor.NewModuleLogger("proxy_handlers") +var proxyHandlerLogger = monitor.NewModuleLogger("proxy_handlers") -// RequestHandler is a wrapper for passing proxy.ProxyService instance to proxy HTTP handler. +// RequestHandler is a wrapper for passing proxy.Service instance to proxy HTTP handler. type RequestHandler struct { - *ProxyService + *Service } -// NewRequestHandler initializes request handler with a provided Proxy ProxyService instance -func NewRequestHandler(svc *ProxyService) *RequestHandler { - return &RequestHandler{ProxyService: svc} +// NewRequestHandler initializes request handler with a provided Proxy Service instance +func NewRequestHandler(svc *Service) *RequestHandler { + return &RequestHandler{Service: svc} } // Handle forwards client JSON-RPC request to proxy. @@ -26,14 +26,14 @@ func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { if r.Body == nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("empty request body")) - logger.Log().Errorf("empty request body") + proxyHandlerLogger.Log().Errorf("empty request body") return } body, err := ioutil.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("error reading request body")) - logger.Log().Errorf("error reading request body: %v", err.Error()) + proxyHandlerLogger.Log().Errorf("error reading request body: %v", err.Error()) return } diff --git a/app/proxy/main_test.go b/app/proxy/main_test.go index a7f892c8..06f82b20 100644 --- a/app/proxy/main_test.go +++ b/app/proxy/main_test.go @@ -19,12 +19,12 @@ const dummyServerURL = "http://127.0.0.1:59999" const proxySuffix = "/api/v1/proxy" const testSetupWait = 200 * time.Millisecond -var svc *ProxyService +var svc *Service func TestMain(m *testing.M) { rand.Seed(time.Now().UnixNano()) - svc = NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) + svc = NewService(sdkrouter.New(config.GetLbrynetServers())) dbConfig := config.GetDatabase() params := storage.ConnParams{ diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index aff99cd3..1d09a910 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -15,26 +15,23 @@ import ( "encoding/json" "errors" "fmt" - "strings" "time" "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/internal/monitor" - ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" - "github.com/ybbus/jsonrpc" ) -const defaultRPCTimeout = time.Second * 30 +const defaultRPCTimeout = 30 * time.Second var Logger = monitor.NewProxyLogger() type Preprocessor func(q *Query) -// ProxyService generates Caller objects and keeps execution time metrics +// Service generates Caller objects and keeps execution time metrics // for all calls proxied through those objects. -type ProxyService struct { +type Service struct { SDKRouter *sdkrouter.Router rpcTimeout time.Duration logger monitor.QueryMonitor @@ -45,197 +42,53 @@ type ProxyService struct { type Caller struct { walletID string query *jsonrpc.RPCRequest - client LbrynetClient + client Client endpoint string - service *ProxyService + service *Service preprocessor Preprocessor } -// Query is a wrapper around client JSON-RPC query for easier (un)marshaling and processing. -type Query struct { - Request *jsonrpc.RPCRequest - rawRequest []byte - walletID string -} - -// Opts is initialization parameters for NewService / proxy.ProxyService -type Opts struct { - SDKRouter *sdkrouter.Router - RPCTimeout time.Duration -} - // NewService is the entry point to proxy module. -// Normally only one instance of ProxyService should be created per running server. -func NewService(opts Opts) *ProxyService { - s := ProxyService{ - SDKRouter: opts.SDKRouter, - rpcTimeout: opts.RPCTimeout, - } - if s.rpcTimeout == 0 { - s.rpcTimeout = defaultRPCTimeout +// Normally only one instance of Service should be created per running server. +func NewService(router *sdkrouter.Router) *Service { + return &Service{ + SDKRouter: router, + rpcTimeout: defaultRPCTimeout, } - return &s +} + +func (ps *Service) SetRPCTimeout(timeout time.Duration) { + ps.rpcTimeout = timeout } // NewCaller returns an instance of Caller ready to proxy requests. // Note that `SetWalletID` needs to be called if an authenticated user is making this call. -func (ps *ProxyService) NewCaller(walletID string) *Caller { - endpoint := ps.SDKRouter.GetServer(walletID).Address - client := NewClient(endpoint, walletID, ps.rpcTimeout) - c := Caller{ +func (ps *Service) NewCaller(walletID string) *Caller { + endpoint := ps.SDKRouter.GetServer(sdkrouter.UserID(walletID)).Address + return &Caller{ walletID: walletID, - client: client, + client: NewClient(endpoint, walletID, ps.rpcTimeout), endpoint: endpoint, service: ps, } - return &c -} - -// NewQuery initializes Query object with JSON-RPC request supplied as bytes. -// The object is immediately usable and returns an error in case request parsing fails. -func NewQuery(r []byte) (*Query, error) { - q := &Query{rawRequest: r, Request: &jsonrpc.RPCRequest{}} - err := q.unmarshal() - if err != nil { - return nil, err - } - return q, nil } -func (q *Query) unmarshal() error { - err := json.Unmarshal(q.rawRequest, q.Request) +// Call method processes a raw query received from JSON-RPC client and forwards it to LbrynetServer. +// It returns a response that is ready to be sent back to the JSON-RPC client as is. +func (c *Caller) Call(rawQuery []byte) []byte { + r, err := c.call(rawQuery) if err != nil { - return err - } - if strings.TrimSpace(q.Request.Method) == "" { - return errors.New("invalid JSON-RPC request") - } - return nil -} - -// Method is a shortcut for query method. -func (q *Query) Method() string { - return q.Request.Method -} - -// Params is a shortcut for query params. -func (q *Query) Params() interface{} { - return q.Request.Params -} - -// ParamsAsMap returns query params converted to plain map. -func (q *Query) ParamsAsMap() map[string]interface{} { - if paramsMap, ok := q.Params().(map[string]interface{}); ok { - return paramsMap - } - return nil -} - -// ParamsToStruct returns query params parsed into a supplied structure. -func (q *Query) ParamsToStruct(targetStruct interface{}) error { - return ljsonrpc.Decode(q.Params(), targetStruct) -} - -// cacheHit returns true if we got a resolve query with more than `cacheResolveLongerThan` urls in it. -func (q *Query) isCacheable() bool { - if q.Method() == MethodResolve && q.Params() != nil { - paramsMap := q.Params().(map[string]interface{}) - if urls, ok := paramsMap[paramUrls].([]interface{}); ok { - if len(urls) > cacheResolveLongerThan { - return true - } - } - } else if q.Method() == MethodClaimSearch { - return true - } - return false -} - -func (q *Query) newResponse() *jsonrpc.RPCResponse { - var r jsonrpc.RPCResponse - r.ID = q.Request.ID - r.JSONRPC = q.Request.JSONRPC - return &r -} - -func (q *Query) SetWalletID(id string) { - q.walletID = id -} - -// cacheHit returns cached response or nil in case it's a miss or query shouldn't be cacheable. -func (q *Query) cacheHit() *jsonrpc.RPCResponse { - if q.isCacheable() { - if cached := responseCache.Retrieve(q.Method(), q.Params()); cached != nil { - // TODO: Temporary hack to find out why the following line doesn't work - // if mResp, ok := cResp.(map[string]interface{}); ok { - s, _ := json.Marshal(cached) - response := q.newResponse() - err := json.Unmarshal(s, &response) - if err == nil { - monitor.LogCachedQuery(q.Method()) - return response - } - } - } - return nil -} - -func (q *Query) predefinedResponse() *jsonrpc.RPCResponse { - if q.Method() == MethodStatus { - response := q.newResponse() - response.Result = getStatusResponse() - return response - } - return nil -} - -func (q *Query) validate() CallError { - if !methodInList(q.Method(), relaxedMethods) && !methodInList(q.Method(), walletSpecificMethods) { - return NewMethodError(errors.New("forbidden method")) - } - if q.ParamsAsMap() != nil { - if _, ok := q.ParamsAsMap()[forbiddenParam]; ok { - return NewParamsError(fmt.Errorf("forbidden parameter supplied: %v", forbiddenParam)) - } - } - - if !methodInList(q.Method(), relaxedMethods) { - if q.walletID == "" { - return NewParamsError(errors.New("account identificator required")) - } - if p := q.ParamsAsMap(); p != nil { - p[paramWalletID] = q.walletID - q.Request.Params = p - } else { - q.Request.Params = map[string]interface{}{paramWalletID: q.walletID} + if !errors.As(err, &InputError{}) { + monitor.CaptureException(err, map[string]string{"query": string(rawQuery), "response": fmt.Sprintf("%v", r)}) + Logger.Errorf("error calling lbrynet: %v, query: %s", err, rawQuery) } + return c.marshalError(err) } - - return nil -} - -// SetPreprocessor applies provided function to query before it's sent to the LbrynetServer. -func (c *Caller) SetPreprocessor(p Preprocessor) { - c.preprocessor = p -} - -// WalletID is an LbrynetServer wallet ID for the client this caller instance is serving. -func (c *Caller) WalletID() string { - return c.walletID -} - -func (c *Caller) marshal(r *jsonrpc.RPCResponse) ([]byte, CallError) { - serialized, err := json.MarshalIndent(r, "", " ") - if err != nil { - return nil, NewError(err) - } - return serialized, nil -} - -func (c *Caller) marshalError(e CallError) []byte { - serialized, err := json.MarshalIndent(e.AsRPCResponse(), "", " ") + serialized, err := c.marshal(r) if err != nil { - return []byte(err.Error()) + monitor.CaptureException(err) + Logger.Errorf("error marshaling response: %v", err) + return c.marshalError(err) } return serialized } @@ -246,20 +99,20 @@ func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, CallError) { return nil, NewInputError(err) } - if c.WalletID() != "" { - q.SetWalletID(c.WalletID()) + if c.walletID != "" { + q.SetWalletID(c.walletID) } - // Check for account identificator (wallet ID) for account-specific methods happens here + // Check for account identifier (wallet ID) for account-specific methods happens here if err := q.validate(); err != nil { return nil, err } - if cachedResponse := q.cacheHit(); cachedResponse != nil { - return cachedResponse, nil + if cached := q.cacheHit(); cached != nil { + return cached, nil } - if predefinedResponse := q.predefinedResponse(); predefinedResponse != nil { - return predefinedResponse, nil + if pr := q.predefinedResponse(); pr != nil { + return pr, nil } if c.preprocessor != nil { @@ -282,22 +135,23 @@ func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, CallError) { return r, nil } -// Call method processes a raw query received from JSON-RPC client and forwards it to LbrynetServer. -// It returns a response that is ready to be sent back to the JSON-RPC client as is. -func (c *Caller) Call(rawQuery []byte) []byte { - r, err := c.call(rawQuery) +func (c *Caller) marshal(r *jsonrpc.RPCResponse) ([]byte, CallError) { + serialized, err := json.MarshalIndent(r, "", " ") if err != nil { - if !errors.As(err, &InputError{}) { - monitor.CaptureException(err, map[string]string{"query": string(rawQuery), "response": fmt.Sprintf("%v", r)}) - Logger.Errorf("error calling lbrynet: %v, query: %s", err, rawQuery) - } - return c.marshalError(err) + return nil, NewError(err) } - serialized, err := c.marshal(r) + return serialized, nil +} + +func (c *Caller) marshalError(e CallError) []byte { + serialized, err := json.MarshalIndent(e.AsRPCResponse(), "", " ") if err != nil { - monitor.CaptureException(err) - Logger.Errorf("error marshaling response: %v", err) - return c.marshalError(err) + return []byte(err.Error()) } return serialized } + +// SetPreprocessor applies provided function to query before it's sent to the LbrynetServer. +func (c *Caller) SetPreprocessor(p Preprocessor) { + c.preprocessor = p +} diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index f7c4b8b1..899a5486 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -12,8 +12,8 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/responses" + "github.com/lbryio/lbrytv/internal/test" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" @@ -23,8 +23,6 @@ import ( "github.com/ybbus/jsonrpc" ) -const endpoint = "http://localhost:5279/" - func newRawRequest(t *testing.T, method string, params interface{}) []byte { var ( body []byte @@ -43,27 +41,11 @@ func newRawRequest(t *testing.T, method string, params interface{}) []byte { func parseRawResponse(t *testing.T, rawCallReponse []byte, destinationVar interface{}) { var rpcResponse jsonrpc.RPCResponse - assert.NotNil(t, rawCallReponse) - json.Unmarshal(rawCallReponse, &rpcResponse) rpcResponse.GetObject(destinationVar) } -type MockClient struct { - Delay time.Duration - LastRequest jsonrpc.RPCRequest -} - -func (c *MockClient) Call(q *Query) (*jsonrpc.RPCResponse, error) { - c.LastRequest = *q.Request - time.Sleep(c.Delay) - return &jsonrpc.RPCResponse{ - JSONRPC: "2.0", - Result: "0.0", - }, nil -} - func TestNewQuery(t *testing.T) { for _, rawQ := range []string{``, ` `, `{}`, `[]`, `[{}]`, `[""]`, `""`, `" "`, `{"method": " "}`} { t.Run(rawQ, func(t *testing.T) { @@ -80,7 +62,7 @@ func TestNewCaller(t *testing.T) { "first": "http://lbrynet1", "second": "http://lbrynet2", } - svc := NewService(Opts{SDKRouter: sdkrouter.New(servers)}) + svc := NewService(sdkrouter.New(servers)) c := svc.NewCaller("") assert.Equal(t, svc, c.service) @@ -95,7 +77,7 @@ func TestNewCaller(t *testing.T) { } func TestCallerСall(t *testing.T) { - c := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}).NewCaller("abc") + c := NewService(sdkrouter.New(config.GetLbrynetServers())).NewCaller("abc") for _, rawQ := range []string{``, ` `, `{}`, `[]`, `[{}]`, `[""]`, `""`, `" "`, `{"method": " "}`} { t.Run(rawQ, func(t *testing.T) { r := c.Call([]byte(rawQ)) @@ -106,7 +88,7 @@ func TestCallerСall(t *testing.T) { } func TestCallerSetWalletID(t *testing.T) { - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) + svc := NewService(sdkrouter.New(config.GetLbrynetServers())) c := svc.NewCaller("abc") assert.Equal(t, "abc", c.walletID) } @@ -117,14 +99,13 @@ func TestCallerCallResolve(t *testing.T) { resolveResponse ljsonrpc.ResolveResponse ) - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) - c := svc.NewCaller("") + svc := NewService(sdkrouter.New(config.GetLbrynetServers())) resolvedURL := "what#6769855a9aa43b67086f9ff3c1a5bacb5698a27a" resolvedClaimID := "6769855a9aa43b67086f9ff3c1a5bacb5698a27a" request := newRawRequest(t, "resolve", map[string]string{"urls": resolvedURL}) - rawCallReponse := c.Call(request) + rawCallReponse := svc.NewCaller("").Call(request) err := json.Unmarshal(rawCallReponse, &errorResponse) require.NoError(t, err) require.Nil(t, errorResponse.Error) @@ -140,90 +121,89 @@ func TestCallerCallWalletBalance(t *testing.T) { dummyUserID := rand.Intn(10^6-10^3) + 10 ^ 3 rt := sdkrouter.New(config.GetLbrynetServers()) - wid, err := lbrynet.InitializeWallet(rt, dummyUserID) + walletID, err := rt.InitializeWallet(dummyUserID) require.NoError(t, err) - svc := NewService(Opts{SDKRouter: rt}) + svc := NewService(rt) request := newRawRequest(t, "wallet_balance", nil) - c := svc.NewCaller("") - result := c.Call(request) - assert.Contains(t, string(result), `"message": "account identificator required"`) + result := svc.NewCaller("").Call(request) + assert.Contains(t, string(result), `"message": "account identifier required"`) - c = svc.NewCaller(wid) hook := logrusTest.NewLocal(Logger.Logger()) - result = c.Call(request) + result = svc.NewCaller(walletID).Call(request) parseRawResponse(t, result, &accountBalanceResponse) assert.EqualValues(t, "0", fmt.Sprintf("%v", accountBalanceResponse.Available)) - assert.Equal(t, map[string]interface{}{"wallet_id": fmt.Sprintf("%v", wid)}, hook.LastEntry().Data["params"]) + assert.Equal(t, map[string]interface{}{"wallet_id": fmt.Sprintf("%v", walletID)}, hook.LastEntry().Data["params"]) assert.Equal(t, "wallet_balance", hook.LastEntry().Data["method"]) } func TestCallerCallRelaxedMethods(t *testing.T) { + reqChan := make(chan *test.RequestData, 1) + srv := test.MockHTTPServer(reqChan) + defer srv.Close() + srv.NoMoreResponses() + caller := &Caller{ + client: NewClient(srv.URL, "", time.Second), + service: NewService(sdkrouter.New(config.GetLbrynetServers())), + } + for _, m := range relaxedMethods { t.Run(m, func(t *testing.T) { if m == MethodStatus { return } - mockClient := &MockClient{} - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) - c := Caller{ - client: mockClient, - service: svc, - } - request := newRawRequest(t, m, nil) - result := c.Call(request) - expectedRequest := jsonrpc.RPCRequest{ + caller.Call(newRawRequest(t, m, nil)) + receivedRequest := <-reqChan + expectedRequest := test.ReqToStr(t, jsonrpc.RPCRequest{ Method: m, Params: nil, JSONRPC: "2.0", - } - assert.EqualValues(t, expectedRequest, mockClient.LastRequest, string(result)) + }) + assert.EqualValues(t, expectedRequest, receivedRequest.Body) }) } } func TestCallerCallNonRelaxedMethods(t *testing.T) { + caller := &Caller{ + client: NewClient("", "", 0), + service: NewService(sdkrouter.New(config.GetLbrynetServers())), + } for _, m := range walletSpecificMethods { - mockClient := &MockClient{} - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) - c := Caller{ - client: mockClient, - service: svc, - } - request := newRawRequest(t, m, nil) - result := c.Call(request) - assert.Contains(t, string(result), `"message": "account identificator required"`) + result := caller.Call(newRawRequest(t, m, nil)) + assert.Contains(t, string(result), `"message": "account identifier required"`) } } func TestCallerCallForbiddenMethod(t *testing.T) { - mockClient := &MockClient{} - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) - c := Caller{ - client: mockClient, - service: svc, + caller := &Caller{ + client: NewClient("", "", 0), + service: NewService(sdkrouter.New(config.GetLbrynetServers())), } - request := newRawRequest(t, "stop", nil) - result := c.Call(request) + result := caller.Call(newRawRequest(t, "stop", nil)) assert.Contains(t, string(result), `"message": "forbidden method"`) } func TestCallerCallAttachesWalletID(t *testing.T) { - mockClient := &MockClient{} - rand.Seed(time.Now().UnixNano()) dummyWalletID := "abc123321" - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) - c := Caller{ + reqChan := make(chan *test.RequestData, 1) + srv := test.MockHTTPServer(reqChan) + defer srv.Close() + srv.NoMoreResponses() + caller := &Caller{ walletID: dummyWalletID, - client: mockClient, - service: svc, + client: NewClient(srv.URL, dummyWalletID, time.Second), + service: NewService(sdkrouter.New(config.GetLbrynetServers())), } - c.Call([]byte(newRawRequest(t, "channel_create", map[string]string{"name": "test", "bid": "0.1"}))) - expectedRequest := jsonrpc.RPCRequest{ + + caller.Call(newRawRequest(t, "channel_create", map[string]string{"name": "test", "bid": "0.1"})) + receivedRequest := <-reqChan + + expectedRequest := test.ReqToStr(t, jsonrpc.RPCRequest{ Method: "channel_create", Params: map[string]interface{}{ "name": "test", @@ -231,16 +211,18 @@ func TestCallerCallAttachesWalletID(t *testing.T) { "wallet_id": dummyWalletID, }, JSONRPC: "2.0", - } - assert.EqualValues(t, expectedRequest, mockClient.LastRequest) + }) + assert.EqualValues(t, expectedRequest, receivedRequest.Body) } func TestCallerSetPreprocessor(t *testing.T) { - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) - client := &MockClient{} - c := Caller{ - client: client, - service: svc, + reqChan := make(chan *test.RequestData, 1) + srv := test.MockHTTPServer(reqChan) + defer srv.Close() + srv.NoMoreResponses() + c := &Caller{ + client: NewClient(srv.URL, "", time.Second), + service: NewService(sdkrouter.New(config.GetLbrynetServers())), } c.SetPreprocessor(func(q *Query) { @@ -253,15 +235,16 @@ func TestCallerSetPreprocessor(t *testing.T) { } }) - c.Call([]byte(newRawRequest(t, relaxedMethods[0], nil))) - p, ok := client.LastRequest.Params.(map[string]string) - assert.True(t, ok) - assert.Equal(t, "123", p["param"]) + c.Call(newRawRequest(t, relaxedMethods[0], nil)) + req := <-reqChan + lastRequest := test.StrToReq(t, req.Body) + + p, ok := lastRequest.Params.(map[string]interface{}) + assert.True(t, ok, req.Body) + assert.Equal(t, "123", p["param"], req.Body) } func TestCallerCallSDKError(t *testing.T) { - var rpcResponse jsonrpc.RPCResponse - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { responses.PrepareJSONWriter(w) w.Write([]byte(` @@ -293,11 +276,12 @@ func TestCallerCallSDKError(t *testing.T) { } `)) })) - svc := NewService(Opts{SDKRouter: sdkrouter.New(map[string]string{"sdk": ts.URL})}) - c := svc.NewCaller("") + svc := NewService(sdkrouter.New(map[string]string{"sdk": ts.URL})) + c := svc.NewCaller("") hook := logrusTest.NewLocal(Logger.Logger()) - response := c.Call([]byte(newRawRequest(t, "resolve", map[string]string{"urls": "what"}))) + response := c.Call(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) + var rpcResponse jsonrpc.RPCResponse json.Unmarshal(response, &rpcResponse) assert.Equal(t, rpcResponse.Error.Code, -32500) assert.Equal(t, "proxy", hook.LastEntry().Data["module"]) @@ -305,16 +289,15 @@ func TestCallerCallSDKError(t *testing.T) { } func TestCallerCallClientJSONError(t *testing.T) { - var rpcResponse jsonrpc.RPCResponse - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { responses.PrepareJSONWriter(w) w.Write([]byte(`{"method":"version}`)) })) - svc := NewService(Opts{SDKRouter: sdkrouter.New(map[string]string{"sdk": ts.URL})}) + svc := NewService(sdkrouter.New(map[string]string{"sdk": ts.URL})) c := svc.NewCaller("") response := c.Call([]byte(`{"method":"version}`)) + var rpcResponse jsonrpc.RPCResponse json.Unmarshal(response, &rpcResponse) assert.Equal(t, "2.0", rpcResponse.JSONRPC) assert.Equal(t, ErrJSONParse, rpcResponse.Error.Code) @@ -322,18 +305,20 @@ func TestCallerCallClientJSONError(t *testing.T) { } func TestQueryParamsAsMap(t *testing.T) { - var q *Query - - q, _ = NewQuery(newRawRequest(t, "version", nil)) + q, err := NewQuery(newRawRequest(t, "version", nil)) + require.NoError(t, err) assert.Nil(t, q.ParamsAsMap()) - q, _ = NewQuery(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) + q, err = NewQuery(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) + require.NoError(t, err) assert.Equal(t, map[string]interface{}{"urls": "what"}, q.ParamsAsMap()) - q, _ = NewQuery(newRawRequest(t, "account_balance", nil)) + q, err = NewQuery(newRawRequest(t, "account_balance", nil)) + require.NoError(t, err) + q.SetWalletID("123") - err := q.validate() - require.Nil(t, err, errors.Unwrap(err)) + err = q.validate() + require.NoError(t, err, errors.Unwrap(err)) assert.Equal(t, map[string]interface{}{"wallet_id": "123"}, q.ParamsAsMap()) searchParams := map[string]interface{}{ @@ -342,18 +327,18 @@ func TestQueryParamsAsMap(t *testing.T) { "gaming", "music", "news", "science", "sports", "technology", }, } - q, _ = NewQuery(newRawRequest(t, "claim_search", searchParams)) + q, err = NewQuery(newRawRequest(t, "claim_search", searchParams)) + require.NoError(t, err) assert.Equal(t, searchParams, q.ParamsAsMap()) } func TestSDKMethodStatus(t *testing.T) { - var rpcResponse jsonrpc.RPCResponse - - svc := NewService(Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}) + svc := NewService(sdkrouter.New(config.GetLbrynetServers())) c := svc.NewCaller("") request := newRawRequest(t, "status", nil) callResult := c.Call(request) + var rpcResponse jsonrpc.RPCResponse json.Unmarshal(callResult, &rpcResponse) result := rpcResponse.Result.(map[string]interface{}) assert.Equal(t, diff --git a/app/proxy/query.go b/app/proxy/query.go new file mode 100644 index 00000000..a6df17b5 --- /dev/null +++ b/app/proxy/query.go @@ -0,0 +1,154 @@ +package proxy + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + + ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" + "github.com/lbryio/lbrytv/internal/monitor" + "github.com/ybbus/jsonrpc" +) + +// Query is a wrapper around client JSON-RPC query for easier (un)marshaling and processing. +type Query struct { + Request *jsonrpc.RPCRequest + rawRequest []byte + walletID string +} + +// NewQuery initializes Query object with JSON-RPC request supplied as bytes. +// The object is immediately usable and returns an error in case request parsing fails. +func NewQuery(r []byte) (*Query, error) { + q := &Query{rawRequest: r, Request: &jsonrpc.RPCRequest{}} + err := q.unmarshal() + if err != nil { + return nil, err + } + return q, nil +} + +func (q *Query) unmarshal() error { + err := json.Unmarshal(q.rawRequest, q.Request) + if err != nil { + return err + } + if strings.TrimSpace(q.Request.Method) == "" { + return errors.New("invalid JSON-RPC request") + } + return nil +} + +func (q *Query) validate() CallError { + if !methodInList(q.Method(), relaxedMethods) && !methodInList(q.Method(), walletSpecificMethods) { + return NewMethodError(errors.New("forbidden method")) + } + + if q.ParamsAsMap() != nil { + if _, ok := q.ParamsAsMap()[forbiddenParam]; ok { + return NewParamsError(fmt.Errorf("forbidden parameter supplied: %v", forbiddenParam)) + } + } + + if !methodInList(q.Method(), relaxedMethods) { + if q.walletID == "" { + return NewParamsError(errors.New("account identifier required")) + } + if p := q.ParamsAsMap(); p != nil { + p[paramWalletID] = q.walletID + q.Request.Params = p + } else { + q.Request.Params = map[string]interface{}{paramWalletID: q.walletID} + } + } + + return nil +} + +// Method is a shortcut for query method. +func (q *Query) Method() string { + return q.Request.Method +} + +// Params is a shortcut for query params. +func (q *Query) Params() interface{} { + return q.Request.Params +} + +// ParamsAsMap returns query params converted to plain map. +func (q *Query) ParamsAsMap() map[string]interface{} { + if paramsMap, ok := q.Params().(map[string]interface{}); ok { + return paramsMap + } + return nil +} + +// ParamsToStruct returns query params parsed into a supplied structure. +func (q *Query) ParamsToStruct(targetStruct interface{}) error { + return ljsonrpc.Decode(q.Params(), targetStruct) +} + +// cacheHit returns true if we got a resolve query with more than `cacheResolveLongerThan` urls in it. +func (q *Query) isCacheable() bool { + if q.Method() == MethodResolve && q.Params() != nil { + paramsMap := q.Params().(map[string]interface{}) + if urls, ok := paramsMap[paramUrls].([]interface{}); ok { + if len(urls) > cacheResolveLongerThan { + return true + } + } + } else if q.Method() == MethodClaimSearch { + return true + } + return false +} + +func (q *Query) newResponse() *jsonrpc.RPCResponse { + return &jsonrpc.RPCResponse{ + JSONRPC: q.Request.JSONRPC, + ID: q.Request.ID, + } +} + +func (q *Query) SetWalletID(id string) { + q.walletID = id +} + +// cacheHit returns cached response or nil in case it's a miss or query shouldn't be cacheable. +func (q *Query) cacheHit() *jsonrpc.RPCResponse { + if !q.isCacheable() { + return nil + } + + cached := responseCache.Retrieve(q.Method(), q.Params()) + if cached == nil { + return nil + } + + s, err := json.Marshal(cached) + if err != nil { + Logger.Errorf("error marshalling cached response") + return nil + } + + response := q.newResponse() + err = json.Unmarshal(s, &response) + if err != nil { + return nil + } + + monitor.LogCachedQuery(q.Method()) + return response +} + +func (q *Query) predefinedResponse() *jsonrpc.RPCResponse { + switch q.Method() { + case MethodStatus: + response := q.newResponse() + response.Result = getStatusResponse() + return response + default: + return nil + } +} diff --git a/app/publish/handler_test.go b/app/publish/handler_test.go index 6b0facfa..a39c7f06 100644 --- a/app/publish/handler_test.go +++ b/app/publish/handler_test.go @@ -126,6 +126,6 @@ func TestUploadHandlerSystemError(t *testing.T) { func TestNewUploadHandler(t *testing.T) { h, err := NewUploadHandler(UploadOpts{}) - assert.Error(t, err, "need either a ProxyService or a Publisher instance") + assert.Error(t, err, "need either a Service or a Publisher instance") assert.Nil(t, h) } diff --git a/app/publish/publish.go b/app/publish/publish.go index f8a4ce87..6956eda7 100644 --- a/app/publish/publish.go +++ b/app/publish/publish.go @@ -35,7 +35,7 @@ type Publisher interface { // LbrynetPublisher is an implementation of SDK publisher. type LbrynetPublisher struct { - *proxy.ProxyService + *proxy.Service } // UploadHandler glues HTTP uploads to the Publisher. @@ -47,7 +47,7 @@ type UploadHandler struct { type UploadOpts struct { Path string Publisher Publisher - ProxyService *proxy.ProxyService + ProxyService *proxy.Service } // NewUploadHandler returns a HTTP upload handler object. @@ -57,11 +57,11 @@ func NewUploadHandler(opts UploadOpts) (*UploadHandler, error) { uploadPath string ) if opts.ProxyService != nil { - publisher = &LbrynetPublisher{ProxyService: opts.ProxyService} + publisher = &LbrynetPublisher{Service: opts.ProxyService} } else if opts.Publisher != nil { publisher = opts.Publisher } else { - return nil, errors.New("need either a ProxyService or a Publisher instance") + return nil, errors.New("need either a Service or a Publisher instance") } if opts.Path == "" { @@ -79,7 +79,7 @@ func NewUploadHandler(opts UploadOpts) (*UploadHandler, error) { // patches the query and sends it to the SDK for processing. // Resulting response is then returned back as a slice of bytes. func (p *LbrynetPublisher) Publish(filePath, walletID string, rawQuery []byte) []byte { - c := p.ProxyService.NewCaller(walletID) + c := p.Service.NewCaller(walletID) c.SetPreprocessor(func(q *proxy.Query) { params := q.ParamsAsMap() params[fileNameParam] = filePath diff --git a/app/publish/publish_test.go b/app/publish/publish_test.go index cec2bce1..36fc11e7 100644 --- a/app/publish/publish_test.go +++ b/app/publish/publish_test.go @@ -45,7 +45,7 @@ func TestLbrynetPublisher(t *testing.T) { defer config.RestoreOverridden() rt := sdkrouter.New(config.GetLbrynetServers()) - p := &LbrynetPublisher{proxy.NewService(proxy.Opts{SDKRouter: rt})} + p := &LbrynetPublisher{proxy.NewService(rt)} walletSvc := users.NewWalletService(rt) u, err := walletSvc.Retrieve(users.Query{Token: authToken}) require.NoError(t, err) diff --git a/app/sdkrouter/concurrency_test.go b/app/sdkrouter/concurrency_test.go index b898b4db..c0d524da 100644 --- a/app/sdkrouter/concurrency_test.go +++ b/app/sdkrouter/concurrency_test.go @@ -14,9 +14,13 @@ import ( ) func TestRouterConcurrency(t *testing.T) { - rpcServer, nextResp := test.MockJSONRPCServer(nil) + rpcServer := test.MockHTTPServer(nil) defer rpcServer.Close() - nextResp(`{"result": {"items": [], "page": 1, "page_size": 1, "total_pages": 10}}`) // mock WalletList response + go func() { + for { + rpcServer.NextResponse <- `{"result": {"items": [], "page": 1, "page_size": 1, "total_pages": 10}}` // mock WalletList response + } + }() r := New(map[string]string{"srv": rpcServer.URL}) servers := r.servers @@ -43,13 +47,13 @@ func TestRouterConcurrency(t *testing.T) { case 0: r.RandomServer() r.GetAll() - r.GetServer("yutwns.123.wallet") + r.GetServer(123) case 1: r.GetAll() - r.GetServer("yutwns.123.wallet") + r.GetServer(123) r.RandomServer() case 2: - r.GetServer("yutwns.123.wallet") + r.GetServer(123) r.RandomServer() r.GetAll() } diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index ae384833..4b1ede9b 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" @@ -23,6 +24,8 @@ import ( var logger = monitor.NewModuleLogger("sdkrouter") +func DisableLogger() { logger.Disable() } // for testing + type Router struct { mu sync.RWMutex servers []*models.LbrynetServer @@ -61,26 +64,25 @@ func (r *Router) GetAll() []*models.LbrynetServer { return r.servers } -func (r *Router) GetServer(walletID string) *models.LbrynetServer { +func (r *Router) GetServer(userID int) *models.LbrynetServer { r.reloadServersFromDB() var sdk *models.LbrynetServer - if walletID == "" { + if userID == 0 { sdk = r.LeastLoaded() } else { - sdk = r.serverForWallet(walletID) + sdk = r.serverForUser(userID) if sdk.Address == "" { - logger.Log().Errorf("wallet [%s] is set but there is no server associated with it.", walletID) + logger.Log().Errorf("user %d has server but server has no address.", userID) sdk = r.RandomServer() } } - logger.Log().Tracef("Using [%s] server for wallet [%s]", sdk.Address, walletID) + logger.Log().Tracef("Using server %s for user %d", sdk.Address, userID) return sdk } -func (r *Router) serverForWallet(walletID string) *models.LbrynetServer { - userID := getUserID(walletID) +func (r *Router) serverForUser(userID int) *models.LbrynetServer { var user *models.User var err error if boil.GetDB() != nil { @@ -95,19 +97,19 @@ func (r *Router) serverForWallet(walletID string) *models.LbrynetServer { if user == nil || user.R == nil || user.R.LbrynetServer == nil { srv := r.servers[getServerForUserID(userID, len(r.servers))] - logger.Log().Debugf("User %d has no wallet in db. Giving them %s", userID, srv.Address) + logger.Log().Debugf("User %d has no server assigned in db. Giving them server %s", userID, srv.Address) return srv } for _, s := range r.servers { if s.ID == user.R.LbrynetServer.ID { - logger.Log().Debugf("User %d has wallet %s set in db", userID, s.Address) + logger.Log().Debugf("User %d has server %s assigned in db", userID, s.Address) return s } } srv := r.servers[getServerForUserID(userID, len(r.servers))] - logger.Log().Errorf("Server for user %d is set in db but is not in current servers list. Giving them %s", userID, srv.Address) + logger.Log().Errorf("User %d has server assigned in db but is not in current servers list. Giving them server %s", userID, srv.Address) return srv } @@ -209,25 +211,105 @@ func (r *Router) LeastLoaded() *models.LbrynetServer { return best } -func getUserID(walletID string) int { - userID, err := strconv.ParseInt(regexp.MustCompile(`\d+`).FindString(walletID), 10, 64) +func (r *Router) Client(userID int) *ljsonrpc.Client { + c := ljsonrpc.NewClient(r.GetServer(userID).Address) + //c.SetRPCTimeout(5 * time.Second) + return c +} + +// InitializeWallet creates a wallet that can be immediately used in subsequent commands. +// It can recover from errors like existing wallets, but if a wallet is known to exist +// (eg. a wallet ID stored in the database already), loadWallet() should be called instead. +func (r *Router) InitializeWallet(userID int) (string, error) { + wallet, err := r.createWallet(userID) + if err == nil { + return wallet.ID, nil + } + + walletID := WalletID(userID) + log := logger.LogF(monitor.F{"user_id": userID}) + + if errors.Is(err, lbrynet.ErrWalletExists) { + log.Warn(err.Error()) + return walletID, nil + } + + if errors.Is(err, lbrynet.ErrWalletNeedsLoading) { + log.Info(err.Error()) + wallet, err = r.loadWallet(userID) + if err != nil { + if errors.Is(err, lbrynet.ErrWalletAlreadyLoaded) { + log.Info(err.Error()) + return walletID, nil + } + return "", err + } + return wallet.ID, nil + } + + log.Errorf("don't know how to recover from error: %v", err) + return "", err +} + +// createWallet creates a new wallet on the LbrynetServer. +// Returned error doesn't necessarily mean that the wallet is not operational: +// +// if errors.Is(err, lbrynet.WalletExists) { +// // Okay to proceed with the account +// } +// +// if errors.Is(err, lbrynet.WalletNeedsLoading) { +// // loadWallet() needs to be called before the wallet can be used +// } +func (r *Router) createWallet(userID int) (*ljsonrpc.Wallet, error) { + wallet, err := r.Client(userID).WalletCreate(WalletID(userID), &ljsonrpc.WalletCreateOpts{ + SkipOnStartup: true, CreateAccount: true, SingleKey: true}) if err != nil { - return 0 + return nil, lbrynet.NewWalletError(userID, err) } - return int(userID) + logger.LogF(monitor.F{"user_id": userID}).Info("wallet created") + return wallet, nil } -func getServerForUserID(userID, numServers int) int { - return userID % numServers +// loadWallet loads an existing wallet in the LbrynetServer. +// May return errors: +// WalletAlreadyLoaded - wallet is already loaded and operational +// WalletNotFound - wallet file does not exist and won't be loaded. +func (r *Router) loadWallet(userID int) (*ljsonrpc.Wallet, error) { + wallet, err := r.Client(userID).WalletAdd(WalletID(userID)) + if err != nil { + return nil, lbrynet.NewWalletError(userID, err) + } + logger.LogF(monitor.F{"user_id": userID}).Info("wallet loaded") + return wallet, nil } -func (r *Router) Client(userID int) *ljsonrpc.Client { - c := ljsonrpc.NewClient(r.GetServer(WalletID(userID)).Address) - //c.SetRPCTimeout(5 * time.Second) - return c +// UnloadWallet unloads an existing wallet from the LbrynetServer. +// May return errors: +// WalletAlreadyLoaded - wallet is already loaded and operational +// WalletNotFound - wallet file does not exist and won't be loaded. +func (r *Router) UnloadWallet(userID int) error { + _, err := r.Client(userID).WalletRemove(WalletID(userID)) + if err != nil { + return lbrynet.NewWalletError(userID, err) + } + logger.LogF(monitor.F{"user_id": userID}).Info("wallet unloaded") + return nil } // WalletID formats user ID to use as an LbrynetServer wallet ID. func WalletID(userID int) string { return fmt.Sprintf("lbrytv-id.%d.wallet", userID) } + +func UserID(walletID string) int { + userID, err := strconv.ParseInt(regexp.MustCompile(`\d+`).FindString(walletID), 10, 64) + if err != nil { + return 0 + } + return int(userID) +} + +func getServerForUserID(userID, numServers int) int { + return userID % numServers +} diff --git a/app/sdkrouter/sdkrouter_test.go b/app/sdkrouter/sdkrouter_test.go index 6479b2bc..da4968d0 100644 --- a/app/sdkrouter/sdkrouter_test.go +++ b/app/sdkrouter/sdkrouter_test.go @@ -1,14 +1,20 @@ package sdkrouter import ( + "errors" "fmt" + "math/rand" "os" "testing" + "time" "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -25,8 +31,8 @@ func TestMain(m *testing.M) { } func TestInitializeWithYML(t *testing.T) { - sdkRouter := New(config.GetLbrynetServers()) - assert.True(t, len(sdkRouter.GetAll()) > 0, "No servers") + r := New(config.GetLbrynetServers()) + assert.True(t, len(r.GetAll()) > 0, "No servers") } func TestServerOrder(t *testing.T) { @@ -37,10 +43,10 @@ func TestServerOrder(t *testing.T) { "d": "3", "c": "2", } - sdkRouter := New(servers) + r := New(servers) - for i := 0; i < 100; i++ { - server := sdkRouter.GetServer(WalletID(i)).Address + for i := 1; i < 100; i++ { + server := r.GetServer(i).Address assert.Equal(t, fmt.Sprintf("%d", i%len(servers)), server) } } @@ -49,7 +55,7 @@ func TestOverrideLbrynetDefaultConf(t *testing.T) { address := "http://space.com:1234" config.Override("LbrynetServers", map[string]string{"x": address}) defer config.RestoreOverridden() - server := New(config.GetLbrynetServers()).GetServer(WalletID(343465345)) + server := New(config.GetLbrynetServers()).GetServer(343465345) assert.Equal(t, address, server.Address) } @@ -58,18 +64,16 @@ func TestOverrideLbrynetConf(t *testing.T) { config.Override("Lbrynet", address) config.Override("LbrynetServers", map[string]string{}) defer config.RestoreOverridden() - server := New(config.GetLbrynetServers()).GetServer(WalletID(1343465345)) + server := New(config.GetLbrynetServers()).GetServer(1343465345) assert.Equal(t, address, server.Address) } func TestGetUserID(t *testing.T) { - userID := getUserID("sjdfkjhsdkjs.1234235.sdfsgf") - assert.Equal(t, 1234235, userID) + assert.Equal(t, 1234235, UserID("sjdfkjhsdkjs.1234235.sdfsgf")) } func TestLeastLoaded(t *testing.T) { - reqChan := make(chan *test.RequestData, 1) - rpcServer, nextResp := test.MockJSONRPCServer(reqChan) + rpcServer := test.MockHTTPServer(nil) defer rpcServer.Close() servers := map[string]string{ @@ -82,8 +86,7 @@ func TestLeastLoaded(t *testing.T) { // try doing the load in increasing order go func() { for i := 0; i < len(servers); i++ { - nextResp(fmt.Sprintf(`{"result":{"total_pages":%d}}`, i)) - <-reqChan + rpcServer.NextResponse <- fmt.Sprintf(`{"result":{"total_pages":%d}}`, i) } }() r.updateLoadAndMetrics() @@ -92,11 +95,48 @@ func TestLeastLoaded(t *testing.T) { // now do the load in decreasing order go func() { for i := 0; i < len(servers); i++ { - nextResp(fmt.Sprintf(`{"result":{"total_pages":%d}}`, len(servers)-i)) - <-reqChan + rpcServer.NextResponse <- fmt.Sprintf(`{"result":{"total_pages":%d}}`, len(servers)-i) } }() r.updateLoadAndMetrics() assert.Equal(t, "srv3", r.LeastLoaded().Name) } + +func TestInitializeWallet(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + userID := rand.Int() + r := New(config.GetLbrynetServers()) + + walletID, err := r.InitializeWallet(userID) + require.NoError(t, err) + assert.Equal(t, walletID, WalletID(userID)) + + err = r.UnloadWallet(userID) + require.NoError(t, err) + + walletID, err = r.InitializeWallet(userID) + require.NoError(t, err) + assert.Equal(t, walletID, WalletID(userID)) +} + +func TestCreateWalletLoadWallet(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + userID := rand.Int() + r := New(config.GetLbrynetServers()) + + wallet, err := r.createWallet(userID) + require.NoError(t, err) + assert.Equal(t, wallet.ID, WalletID(userID)) + + wallet, err = r.createWallet(userID) + require.NotNil(t, err) + assert.True(t, errors.Is(err, lbrynet.ErrWalletExists)) + + err = r.UnloadWallet(userID) + require.NoError(t, err) + + wallet, err = r.loadWallet(userID) + require.NoError(t, err) + assert.Equal(t, wallet.ID, WalletID(userID)) +} diff --git a/app/users/users.go b/app/users/users.go index 9784011c..4238237a 100644 --- a/app/users/users.go +++ b/app/users/users.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" @@ -105,10 +104,9 @@ func (s *WalletService) Retrieve(q Query) (*models.User, error) { return localUser, nil } -// TODO: this is the function where users are assigned to SDKs. assign them randomly func createWalletForUser(user *models.User, router *sdkrouter.Router, log *logrus.Entry) error { // either a new user or a legacy user without a wallet - walletID, err := lbrynet.InitializeWallet(router, user.ID) + walletID, err := router.InitializeWallet(user.ID) if err != nil { return err } @@ -118,7 +116,7 @@ func createWalletForUser(user *models.User, router *sdkrouter.Router, log *logru user.WalletID = walletID - server := router.GetServer(sdkrouter.WalletID(user.ID)) + server := router.GetServer(user.ID) if server.ID > 0 { // Ensure server is from DB user.LbrynetServerID.SetValid(server.ID) } diff --git a/app/users/users_test.go b/app/users/users_test.go index 7ee64d27..bdbf27af 100644 --- a/app/users/users_test.go +++ b/app/users/users_test.go @@ -11,7 +11,6 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/models" log "github.com/sirupsen/logrus" @@ -56,7 +55,7 @@ func setupCleanupDummyUser(rt *sdkrouter.Router, uidParam ...int) func() { return func() { ts.Close() config.RestoreOverridden() - lbrynet.UnloadWallet(rt, uid) + rt.UnloadWallet(uid) } } @@ -167,7 +166,7 @@ func BenchmarkWalletCommands(b *testing.B) { cl := jsonrpc.NewClient(sdkRouter.RandomServer().Address) svc.Logger.Disable() - lbrynet.Logger.Disable() + sdkrouter.DisableLogger() log.SetOutput(ioutil.Discard) rand.Seed(time.Now().UnixNano()) diff --git a/cmd/serve.go b/cmd/serve.go index 3dee490b..b42239c8 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -25,7 +25,7 @@ var rootCmd = &cobra.Command{ s := server.NewServer(server.Options{ Address: config.GetAddress(), - ProxyService: proxy.NewService(proxy.Opts{SDKRouter: sdkRouter}), + ProxyService: proxy.NewService(sdkRouter), }) err := s.Start() if err != nil { diff --git a/internal/environment/environment.go b/internal/environment/environment.go index 3486ec1c..e0fae9b1 100644 --- a/internal/environment/environment.go +++ b/internal/environment/environment.go @@ -10,10 +10,10 @@ type Env struct { *monitor.ModuleLogger *config.ConfigWrapper - proxy *proxy.ProxyService + proxy *proxy.Service } -func NewEnvironment(logger *monitor.ModuleLogger, config *config.ConfigWrapper, ps *proxy.ProxyService) *Env { +func NewEnvironment(logger *monitor.ModuleLogger, config *config.ConfigWrapper, ps *proxy.Service) *Env { if logger == nil { logger = &monitor.ModuleLogger{} } diff --git a/internal/lbrynet/lbrynet.go b/internal/lbrynet/lbrynet.go deleted file mode 100644 index 6795812d..00000000 --- a/internal/lbrynet/lbrynet.go +++ /dev/null @@ -1,92 +0,0 @@ -package lbrynet - -import ( - "errors" - - "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/internal/monitor" - - ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" -) - -var Logger = monitor.NewModuleLogger("lbrynet") - -// InitializeWallet creates a wallet that can be immediately used in subsequent commands. -// It can recover from errors like existing wallets, but if a wallet is known to exist -// (eg. a wallet ID stored in the database already), loadWallet() should be called instead. -func InitializeWallet(rt *sdkrouter.Router, userID int) (string, error) { - w, err := createWallet(rt, userID) - if err == nil { - return w.ID, nil - } - - walletID := sdkrouter.WalletID(userID) - log := Logger.LogF(monitor.F{"user_id": userID}) - - if errors.Is(err, ErrWalletExists) { - log.Warn(err.Error()) - return walletID, nil - } - - if errors.Is(err, ErrWalletNeedsLoading) { - log.Info(err.Error()) - w, err = loadWallet(rt, userID) - if err != nil { - if errors.Is(err, ErrWalletAlreadyLoaded) { - log.Info(err.Error()) - return walletID, nil - } - return "", err - } - return w.ID, nil - } - - log.Errorf("don't know how to recover from error: %v", err) - return "", err -} - -// createWallet creates a new wallet on the LbrynetServer. -// Returned error doesn't necessarily mean that the wallet is not operational: -// -// if errors.Is(err, lbrynet.WalletExists) { -// // Okay to proceed with the account -// } -// -// if errors.Is(err, lbrynet.WalletNeedsLoading) { -// // loadWallet() needs to be called before the wallet can be used -// } -func createWallet(rt *sdkrouter.Router, userID int) (*ljsonrpc.Wallet, error) { - wallet, err := rt.Client(userID).WalletCreate(sdkrouter.WalletID(userID), &ljsonrpc.WalletCreateOpts{ - SkipOnStartup: true, CreateAccount: true, SingleKey: true}) - if err != nil { - return nil, NewWalletError(userID, err) - } - Logger.LogF(monitor.F{"user_id": userID}).Info("wallet created") - return wallet, nil -} - -// loadWallet loads an existing wallet in the LbrynetServer. -// May return errors: -// WalletAlreadyLoaded - wallet is already loaded and operational -// WalletNotFound - wallet file does not exist and won't be loaded. -func loadWallet(rt *sdkrouter.Router, userID int) (*ljsonrpc.Wallet, error) { - wallet, err := rt.Client(userID).WalletAdd(sdkrouter.WalletID(userID)) - if err != nil { - return nil, NewWalletError(userID, err) - } - Logger.LogF(monitor.F{"user_id": userID}).Info("wallet loaded") - return wallet, nil -} - -// UnloadWallet unloads an existing wallet from the LbrynetServer. -// May return errors: -// WalletAlreadyLoaded - wallet is already loaded and operational -// WalletNotFound - wallet file does not exist and won't be loaded. -func UnloadWallet(rt *sdkrouter.Router, userID int) error { - _, err := rt.Client(userID).WalletRemove(sdkrouter.WalletID(userID)) - if err != nil { - return NewWalletError(userID, err) - } - Logger.LogF(monitor.F{"user_id": userID}).Info("wallet unloaded") - return nil -} diff --git a/internal/lbrynet/lbrynet_test.go b/internal/lbrynet/lbrynet_test.go deleted file mode 100644 index 473c198d..00000000 --- a/internal/lbrynet/lbrynet_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package lbrynet - -import ( - "errors" - "math/rand" - "os" - "testing" - "time" - - "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMain(m *testing.M) { - rand.Seed(time.Now().UnixNano()) - code := m.Run() - os.Exit(code) -} - -func TestInitializeWallet(t *testing.T) { - uid := rand.Int() - r := sdkrouter.New(config.GetLbrynetServers()) - - wid, err := InitializeWallet(r, uid) - require.NoError(t, err) - assert.Equal(t, wid, sdkrouter.WalletID(uid)) - - err = UnloadWallet(r, uid) - require.NoError(t, err) - - wid, err = InitializeWallet(r, uid) - require.NoError(t, err) - assert.Equal(t, wid, sdkrouter.WalletID(uid)) -} - -func TestCreateWalletLoadWallet(t *testing.T) { - uid := rand.Int() - r := sdkrouter.New(config.GetLbrynetServers()) - - w, err := createWallet(r, uid) - require.NoError(t, err) - assert.Equal(t, w.ID, sdkrouter.WalletID(uid)) - - _, err = createWallet(r, uid) - require.NotNil(t, err) - assert.True(t, errors.Is(err, ErrWalletExists)) - - err = UnloadWallet(r, uid) - require.NoError(t, err) - - w, err = loadWallet(r, uid) - require.NoError(t, err) - assert.Equal(t, w.ID, sdkrouter.WalletID(uid)) -} diff --git a/internal/metrics/routes_test.go b/internal/metrics/routes_test.go index b8717c75..902fdc2d 100644 --- a/internal/metrics/routes_test.go +++ b/internal/metrics/routes_test.go @@ -57,7 +57,7 @@ func testMetricUIEvent(t *testing.T, method, name, value string) *httptest.Respo req.URL.RawQuery = q.Encode() r := mux.NewRouter() - api.InstallRoutes(proxy.NewService(proxy.Opts{}), r) + api.InstallRoutes(proxy.NewService(nil), r) rr := httptest.NewRecorder() r.ServeHTTP(rr, req) return rr diff --git a/internal/monitor/sentry.go b/internal/monitor/sentry.go index 404b63f2..df1ad30e 100644 --- a/internal/monitor/sentry.go +++ b/internal/monitor/sentry.go @@ -9,7 +9,7 @@ import ( ) var IgnoredExceptions = []string{ - "account identificator required", + "account identifier required", } func configureSentry(release, env string) { diff --git a/internal/test/test.go b/internal/test/test.go index fe2ab0a0..3abb0ee9 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,41 +1,69 @@ package test import ( + "encoding/json" "fmt" "io/ioutil" "net/http" "net/http/httptest" - "sync" + "testing" + + "github.com/ybbus/jsonrpc" ) +type MockServer struct { + *httptest.Server + NextResponse chan<- string +} + +func (m *MockServer) NoMoreResponses() { + close(m.NextResponse) +} + type RequestData struct { - Request *http.Request - Body string + R *http.Request + Body string } -// MockJSONRPCServer creates a JSONRPC server that can be used to test clients +// MockHTTPServer creates an http server that can be used to test clients // NOTE: if you want to make sure that you get requests in your requestChan one by one, limit the // channel to a buffer size of 1. then writes to the chan will block until you read it -func MockJSONRPCServer(requestChan chan *RequestData) (*httptest.Server, func(string)) { - var mu sync.RWMutex - // needed to retrieve requests that arrived at httpServer for further investigation - presetResponse := "" - setNextResponse := func(s string) { - mu.Lock() - defer mu.Unlock() - presetResponse = s +func MockHTTPServer(requestChan chan *RequestData) *MockServer { + next := make(chan string, 1) + return &MockServer{ + NextResponse: next, + Server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, _ := ioutil.ReadAll(r.Body) + defer r.Body.Close() + if requestChan != nil { + requestChan <- &RequestData{r, string(data)} // store the request for inspection + } + fmt.Fprintf(w, <-next) + })), } +} - httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - data, _ := ioutil.ReadAll(r.Body) - defer r.Body.Close() - if requestChan != nil { - requestChan <- &RequestData{r, string(data)} // store the request for inspection - } - mu.RLock() - defer mu.RUnlock() - fmt.Fprintf(w, presetResponse) // respond with the preset response - })) - - return httpServer, setNextResponse +func ReqToStr(t *testing.T, req jsonrpc.RPCRequest) string { + r, err := json.Marshal(req) + if err != nil { + t.Fatal(err) + } + return string(r) +} + +func StrToReq(t *testing.T, req string) jsonrpc.RPCRequest { + var r jsonrpc.RPCRequest + err := json.Unmarshal([]byte(req), &r) + if err != nil { + t.Fatal(err) + } + return r +} + +func ResToStr(t *testing.T, res jsonrpc.RPCResponse) string { + r, err := json.Marshal(res) + if err != nil { + t.Fatal(err) + } + return string(r) } diff --git a/internal/test/test_test.go b/internal/test/test_test.go index c0552396..f93152c2 100644 --- a/internal/test/test_test.go +++ b/internal/test/test_test.go @@ -11,9 +11,9 @@ import ( func TestMockRPCServer(t *testing.T) { reqChan := make(chan *RequestData, 1) - rpcServer, nextResp := MockJSONRPCServer(reqChan) + rpcServer := MockHTTPServer(reqChan) defer rpcServer.Close() - nextResp(`{"result": {"items": [], "page": 1, "page_size": 2, "total_pages": 3}}`) + rpcServer.NextResponse <- `{"result": {"items": [], "page": 1, "page_size": 2, "total_pages": 3}}` rsp, err := ljsonrpc.NewClient(rpcServer.URL).WalletList("", 1, 2) if err != nil { @@ -21,7 +21,7 @@ func TestMockRPCServer(t *testing.T) { } req := <-reqChan // read the request for inspection - assert.Equal(t, req.Request.Method, http.MethodPost) + assert.Equal(t, req.R.Method, http.MethodPost) assert.Equal(t, req.Body, `{"method":"wallet_list","params":{"page":1,"page_size":2},"id":0,"jsonrpc":"2.0"}`) assert.Equal(t, rsp.Page, uint64(1)) diff --git a/server/server.go b/server/server.go index a1931223..8198fc7b 100644 --- a/server/server.go +++ b/server/server.go @@ -20,7 +20,7 @@ var logger = monitor.NewModuleLogger("server") // Server holds entities that can be used to control the web server type Server struct { defaultHeaders map[string]string - proxyService *proxy.ProxyService + proxyService *proxy.Service stopChan chan os.Signal stopWait time.Duration address string @@ -31,7 +31,7 @@ type Server struct { // Options holds basic web server settings. type Options struct { Address string - ProxyService *proxy.ProxyService + ProxyService *proxy.Service StopWaitSeconds int } diff --git a/server/server_test.go b/server/server_test.go index 79ac4cb3..7c8061bc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -17,7 +17,7 @@ import ( func TestStartAndServeUntilShutdown(t *testing.T) { server := NewServer(Options{ Address: "localhost:40080", - ProxyService: proxy.NewService(proxy.Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}), + ProxyService: proxy.NewService(sdkrouter.New(config.GetLbrynetServers())), }) server.Start() go server.ServeUntilShutdown() @@ -49,7 +49,7 @@ func TestHeaders(t *testing.T) { server := NewServer(Options{ Address: "localhost:40080", - ProxyService: proxy.NewService(proxy.Opts{SDKRouter: sdkrouter.New(config.GetLbrynetServers())}), + ProxyService: proxy.NewService(sdkrouter.New(config.GetLbrynetServers())), }) server.Start() go server.ServeUntilShutdown() From d262891b19d0500566fbe29ff2305429f6c42ad2 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Sun, 12 Apr 2020 20:59:36 -0400 Subject: [PATCH 03/18] drop proxy Client, merge it into Caller --- app/proxy/{proxy.go => caller.go} | 124 ++++++++++++++------ app/proxy/{proxy_test.go => caller_test.go} | 100 ++++------------ app/proxy/client.go | 107 ----------------- app/proxy/client_test.go | 100 ++++++---------- app/proxy/query.go | 34 +++++- app/proxy/query_filters.go | 38 ------ app/proxy/query_test.go | 37 ++++++ app/proxy/service.go | 45 +++++++ internal/test/test.go | 24 ++-- internal/test/test_test.go | 12 +- 10 files changed, 279 insertions(+), 342 deletions(-) rename app/proxy/{proxy.go => caller.go} (52%) rename app/proxy/{proxy_test.go => caller_test.go} (76%) delete mode 100644 app/proxy/client.go delete mode 100644 app/proxy/query_filters.go create mode 100644 app/proxy/query_test.go create mode 100644 app/proxy/service.go diff --git a/app/proxy/proxy.go b/app/proxy/caller.go similarity index 52% rename from app/proxy/proxy.go rename to app/proxy/caller.go index 1d09a910..e9158e4a 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/caller.go @@ -17,59 +17,37 @@ import ( "fmt" "time" - "github.com/lbryio/lbrytv/app/sdkrouter" + "github.com/lbryio/lbrytv/internal/lbrynet" + "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/monitor" + ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" + + "github.com/sirupsen/logrus" "github.com/ybbus/jsonrpc" ) -const defaultRPCTimeout = 30 * time.Second - var Logger = monitor.NewProxyLogger() -type Preprocessor func(q *Query) - -// Service generates Caller objects and keeps execution time metrics -// for all calls proxied through those objects. -type Service struct { - SDKRouter *sdkrouter.Router - rpcTimeout time.Duration - logger monitor.QueryMonitor -} +const ( + walletLoadRetries = 3 + walletLoadRetryWait = 100 * time.Millisecond +) // Caller patches through JSON-RPC requests from clients, doing pre/post-processing, // account processing and validation. type Caller struct { + client jsonrpc.RPCClient walletID string - query *jsonrpc.RPCRequest - client Client endpoint string - service *Service - preprocessor Preprocessor -} - -// NewService is the entry point to proxy module. -// Normally only one instance of Service should be created per running server. -func NewService(router *sdkrouter.Router) *Service { - return &Service{ - SDKRouter: router, - rpcTimeout: defaultRPCTimeout, - } + preprocessor func(q *Query) } -func (ps *Service) SetRPCTimeout(timeout time.Duration) { - ps.rpcTimeout = timeout -} - -// NewCaller returns an instance of Caller ready to proxy requests. -// Note that `SetWalletID` needs to be called if an authenticated user is making this call. -func (ps *Service) NewCaller(walletID string) *Caller { - endpoint := ps.SDKRouter.GetServer(sdkrouter.UserID(walletID)).Address +func NewCaller(endpoint, walletID string) *Caller { return &Caller{ - walletID: walletID, - client: NewClient(endpoint, walletID, ps.rpcTimeout), + client: jsonrpc.NewClient(endpoint), endpoint: endpoint, - service: ps, + walletID: walletID, } } @@ -119,7 +97,7 @@ func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, CallError) { c.preprocessor(q) } - r, err := c.client.Call(q) + r, err := c.callQueryWithRetry(q) if err != nil { return r, NewInternalError(err) } @@ -135,6 +113,68 @@ func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, CallError) { return r, nil } +func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) { + var ( + r *jsonrpc.RPCResponse + err error + duration float64 + ) + + callMetrics := metrics.ProxyCallDurations.WithLabelValues(q.Method(), c.endpoint) + failureMetrics := metrics.ProxyCallFailedDurations.WithLabelValues(q.Method(), c.endpoint) + + for i := 0; i < walletLoadRetries; i++ { + start := time.Now() + + r, err = c.client.CallRaw(q.Request) + + duration = time.Since(start).Seconds() + callMetrics.Observe(duration) + + // Generally a HTTP transport failure (connect error etc) + if err != nil { + Logger.Errorf("error sending query to %v: %v", c.endpoint, err) + return nil, err + } + + // This checks if LbrynetServer responded with missing wallet error and tries to reload it, + // then repeats the request again. + if isErrWalletNotLoaded(r) { + time.Sleep(walletLoadRetryWait) + // Using LBRY JSON-RPC client here for easier request/response processing + client := ljsonrpc.NewClient(c.endpoint) + _, err := client.WalletAdd(c.walletID) + // Alert sentry on the last failed wallet load attempt + if err != nil && i >= walletLoadRetries-1 { + errMsg := "gave up on manually adding a wallet: %v" + Logger.Logger().WithFields(logrus.Fields{ + "wallet_id": c.walletID, + "endpoint": c.endpoint, + }).Errorf(errMsg, err) + monitor.CaptureException( + fmt.Errorf(errMsg, err), map[string]string{ + "wallet_id": c.walletID, + "endpoint": c.endpoint, + "retries": fmt.Sprintf("%v", i), + }) + } + } else if isErrWalletAlreadyLoaded(r) { + continue + } else { + break + } + } + + if (r != nil && r.Error != nil) || err != nil { + Logger.LogFailedQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r.Error) + failureMetrics.Observe(duration) + } else { + Logger.LogSuccessfulQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r) + } + + return r, err +} + func (c *Caller) marshal(r *jsonrpc.RPCResponse) ([]byte, CallError) { serialized, err := json.MarshalIndent(r, "", " ") if err != nil { @@ -152,6 +192,14 @@ func (c *Caller) marshalError(e CallError) []byte { } // SetPreprocessor applies provided function to query before it's sent to the LbrynetServer. -func (c *Caller) SetPreprocessor(p Preprocessor) { +func (c *Caller) SetPreprocessor(p func(q *Query)) { c.preprocessor = p } + +func isErrWalletNotLoaded(r *jsonrpc.RPCResponse) bool { + return r.Error != nil && errors.Is(lbrynet.NewWalletError(0, errors.New(r.Error.Message)), lbrynet.ErrWalletNotLoaded) +} + +func isErrWalletAlreadyLoaded(r *jsonrpc.RPCResponse) bool { + return r.Error != nil && errors.Is(lbrynet.NewWalletError(0, errors.New(r.Error.Message)), lbrynet.ErrWalletAlreadyLoaded) +} diff --git a/app/proxy/proxy_test.go b/app/proxy/caller_test.go similarity index 76% rename from app/proxy/proxy_test.go rename to app/proxy/caller_test.go index 899a5486..1f1b4bcf 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/caller_test.go @@ -2,7 +2,6 @@ package proxy import ( "encoding/json" - "errors" "fmt" "math/rand" "net/http" @@ -63,9 +62,6 @@ func TestNewCaller(t *testing.T) { "second": "http://lbrynet2", } svc := NewService(sdkrouter.New(servers)) - c := svc.NewCaller("") - assert.Equal(t, svc, c.service) - sList := svc.SDKRouter.GetAll() rand.Seed(time.Now().UnixNano()) for i := 1; i <= 100; i++ { @@ -140,20 +136,17 @@ func TestCallerCallWalletBalance(t *testing.T) { } func TestCallerCallRelaxedMethods(t *testing.T) { - reqChan := make(chan *test.RequestData, 1) + reqChan := test.ReqChan() srv := test.MockHTTPServer(reqChan) defer srv.Close() - srv.NoMoreResponses() - caller := &Caller{ - client: NewClient(srv.URL, "", time.Second), - service: NewService(sdkrouter.New(config.GetLbrynetServers())), - } + caller := NewCaller(srv.URL, "") for _, m := range relaxedMethods { t.Run(m, func(t *testing.T) { if m == MethodStatus { return } + srv.NextResponse <- "" caller.Call(newRawRequest(t, m, nil)) receivedRequest := <-reqChan expectedRequest := test.ReqToStr(t, jsonrpc.RPCRequest{ @@ -167,10 +160,7 @@ func TestCallerCallRelaxedMethods(t *testing.T) { } func TestCallerCallNonRelaxedMethods(t *testing.T) { - caller := &Caller{ - client: NewClient("", "", 0), - service: NewService(sdkrouter.New(config.GetLbrynetServers())), - } + caller := NewCaller("", "") for _, m := range walletSpecificMethods { result := caller.Call(newRawRequest(t, m, nil)) assert.Contains(t, string(result), `"message": "account identifier required"`) @@ -178,10 +168,7 @@ func TestCallerCallNonRelaxedMethods(t *testing.T) { } func TestCallerCallForbiddenMethod(t *testing.T) { - caller := &Caller{ - client: NewClient("", "", 0), - service: NewService(sdkrouter.New(config.GetLbrynetServers())), - } + caller := NewCaller("", "") result := caller.Call(newRawRequest(t, "stop", nil)) assert.Contains(t, string(result), `"message": "forbidden method"`) } @@ -190,16 +177,11 @@ func TestCallerCallAttachesWalletID(t *testing.T) { rand.Seed(time.Now().UnixNano()) dummyWalletID := "abc123321" - reqChan := make(chan *test.RequestData, 1) + reqChan := test.ReqChan() srv := test.MockHTTPServer(reqChan) defer srv.Close() - srv.NoMoreResponses() - caller := &Caller{ - walletID: dummyWalletID, - client: NewClient(srv.URL, dummyWalletID, time.Second), - service: NewService(sdkrouter.New(config.GetLbrynetServers())), - } - + srv.NextResponse <- "" + caller := NewCaller(srv.URL, dummyWalletID) caller.Call(newRawRequest(t, "channel_create", map[string]string{"name": "test", "bid": "0.1"})) receivedRequest := <-reqChan @@ -216,14 +198,11 @@ func TestCallerCallAttachesWalletID(t *testing.T) { } func TestCallerSetPreprocessor(t *testing.T) { - reqChan := make(chan *test.RequestData, 1) + reqChan := test.ReqChan() srv := test.MockHTTPServer(reqChan) defer srv.Close() - srv.NoMoreResponses() - c := &Caller{ - client: NewClient(srv.URL, "", time.Second), - service: NewService(sdkrouter.New(config.GetLbrynetServers())), - } + + c := NewCaller(srv.URL, "") c.SetPreprocessor(func(q *Query) { params := q.ParamsAsMap() @@ -235,6 +214,8 @@ func TestCallerSetPreprocessor(t *testing.T) { } }) + srv.NextResponse <- "" + c.Call(newRawRequest(t, relaxedMethods[0], nil)) req := <-reqChan lastRequest := test.StrToReq(t, req.Body) @@ -245,9 +226,9 @@ func TestCallerSetPreprocessor(t *testing.T) { } func TestCallerCallSDKError(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - responses.PrepareJSONWriter(w) - w.Write([]byte(` + srv := test.MockHTTPServer(nil) + defer srv.Close() + srv.NextResponse <- ` { "jsonrpc": "2.0", "error": { @@ -273,12 +254,9 @@ func TestCallerCallSDKError(t *testing.T) { ] }, "id": 0 - } - `)) - })) + }` - svc := NewService(sdkrouter.New(map[string]string{"sdk": ts.URL})) - c := svc.NewCaller("") + c := NewCaller(srv.URL, "") hook := logrusTest.NewLocal(Logger.Logger()) response := c.Call(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) var rpcResponse jsonrpc.RPCResponse @@ -293,9 +271,7 @@ func TestCallerCallClientJSONError(t *testing.T) { responses.PrepareJSONWriter(w) w.Write([]byte(`{"method":"version}`)) })) - svc := NewService(sdkrouter.New(map[string]string{"sdk": ts.URL})) - c := svc.NewCaller("") - + c := NewCaller(ts.URL, "") response := c.Call([]byte(`{"method":"version}`)) var rpcResponse jsonrpc.RPCResponse json.Unmarshal(response, &rpcResponse) @@ -304,44 +280,12 @@ func TestCallerCallClientJSONError(t *testing.T) { assert.Equal(t, "unexpected end of JSON input", rpcResponse.Error.Message) } -func TestQueryParamsAsMap(t *testing.T) { - q, err := NewQuery(newRawRequest(t, "version", nil)) - require.NoError(t, err) - assert.Nil(t, q.ParamsAsMap()) - - q, err = NewQuery(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) - require.NoError(t, err) - assert.Equal(t, map[string]interface{}{"urls": "what"}, q.ParamsAsMap()) - - q, err = NewQuery(newRawRequest(t, "account_balance", nil)) - require.NoError(t, err) - - q.SetWalletID("123") - err = q.validate() - require.NoError(t, err, errors.Unwrap(err)) - assert.Equal(t, map[string]interface{}{"wallet_id": "123"}, q.ParamsAsMap()) - - searchParams := map[string]interface{}{ - "any_tags": []interface{}{ - "art", "automotive", "blockchain", "comedy", "economics", "education", - "gaming", "music", "news", "science", "sports", "technology", - }, - } - q, err = NewQuery(newRawRequest(t, "claim_search", searchParams)) - require.NoError(t, err) - assert.Equal(t, searchParams, q.ParamsAsMap()) -} - func TestSDKMethodStatus(t *testing.T) { - svc := NewService(sdkrouter.New(config.GetLbrynetServers())) - c := svc.NewCaller("") - request := newRawRequest(t, "status", nil) - callResult := c.Call(request) - + c := NewService(sdkrouter.New(config.GetLbrynetServers())).NewCaller("") + callResult := c.Call(newRawRequest(t, "status", nil)) var rpcResponse jsonrpc.RPCResponse json.Unmarshal(callResult, &rpcResponse) - result := rpcResponse.Result.(map[string]interface{}) assert.Equal(t, "692EAWhtoqDuAfQ6KHMXxFxt8tkhmt7sfprEMHWKjy5hf6PwZcHDV542VHqRnFnTCD", - result["installation_id"].(string)) + rpcResponse.Result.(map[string]interface{})["installation_id"].(string)) } diff --git a/app/proxy/client.go b/app/proxy/client.go deleted file mode 100644 index 35773a3d..00000000 --- a/app/proxy/client.go +++ /dev/null @@ -1,107 +0,0 @@ -package proxy - -import ( - "errors" - "fmt" - "net/http" - "time" - - "github.com/lbryio/lbrytv/internal/lbrynet" - "github.com/lbryio/lbrytv/internal/metrics" - "github.com/lbryio/lbrytv/internal/monitor" - - ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" - "github.com/ybbus/jsonrpc" -) - -const walletLoadRetries = 3 -const walletLoadRetryWait = time.Millisecond * 100 - -var ClientLogger = monitor.NewModuleLogger("proxy_client") - -type Client struct { - rpcClient jsonrpc.RPCClient - endpoint string - walletID string - retries int -} - -func NewClient(endpoint string, walletID string, timeout time.Duration) Client { - return Client{ - endpoint: endpoint, - rpcClient: jsonrpc.NewClientWithOpts(endpoint, &jsonrpc.RPCClientOpts{ - HTTPClient: &http.Client{Timeout: timeout}, - }), - walletID: walletID, - } -} - -func (c Client) Call(q *Query) (*jsonrpc.RPCResponse, error) { - var ( - r *jsonrpc.RPCResponse - err error - duration float64 - ) - - callMetrics := metrics.ProxyCallDurations.WithLabelValues(q.Method(), c.endpoint) - failureMetrics := metrics.ProxyCallFailedDurations.WithLabelValues(q.Method(), c.endpoint) - - for i := 0; i < walletLoadRetries; i++ { - start := time.Now() - - r, err = c.rpcClient.CallRaw(q.Request) - - duration = time.Since(start).Seconds() - callMetrics.Observe(duration) - - // Generally a HTTP transport failure (connect error etc) - if err != nil { - ClientLogger.Log().Errorf("error sending query to %v: %v", c.endpoint, err) - return nil, err - } - - // This checks if LbrynetServer responded with missing wallet error and tries to reload it, - // then repeats the request again. - if c.isWalletNotLoaded(r) { - time.Sleep(walletLoadRetryWait) - // Using LBRY JSON-RPC client here for easier request/response processing - client := ljsonrpc.NewClient(c.endpoint) - _, err := client.WalletAdd(c.walletID) - // Alert sentry on the last failed wallet load attempt - if err != nil && i >= walletLoadRetries-1 { - errMsg := "gave up on manually adding a wallet: %v" - ClientLogger.WithFields(monitor.F{ - "wallet_id": c.walletID, - "endpoint": c.endpoint, - }).Errorf(errMsg, err) - monitor.CaptureException( - fmt.Errorf(errMsg, err), map[string]string{ - "wallet_id": c.walletID, - "endpoint": c.endpoint, - "retries": fmt.Sprintf("%v", i), - }) - } - } else if c.isWalletAlreadyLoaded(r) { - continue - } else { - break - } - } - - if (r != nil && r.Error != nil) || err != nil { - Logger.LogFailedQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r.Error) - failureMetrics.Observe(duration) - } else { - Logger.LogSuccessfulQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r) - } - - return r, err -} - -func (c *Client) isWalletNotLoaded(r *jsonrpc.RPCResponse) bool { - return r.Error != nil && errors.Is(lbrynet.NewWalletError(0, errors.New(r.Error.Message)), lbrynet.ErrWalletNotLoaded) -} - -func (c *Client) isWalletAlreadyLoaded(r *jsonrpc.RPCResponse) bool { - return r.Error != nil && errors.Is(lbrynet.NewWalletError(0, errors.New(r.Error.Message)), lbrynet.ErrWalletAlreadyLoaded) -} diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index 76c1b3cc..2cd37dad 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -12,43 +12,6 @@ import ( "github.com/ybbus/jsonrpc" ) -type MockRPCClient struct { - Delay time.Duration - LastRequest jsonrpc.RPCRequest - NextResponse chan *jsonrpc.RPCResponse -} - -func NewMockRPCClient() *MockRPCClient { - return &MockRPCClient{ - NextResponse: make(chan *jsonrpc.RPCResponse, 100), - } -} - -func (c MockRPCClient) AddNextResponse(r *jsonrpc.RPCResponse) { - c.NextResponse <- r -} - -func (c MockRPCClient) Call(method string, params ...interface{}) (*jsonrpc.RPCResponse, error) { - return <-c.NextResponse, nil -} - -func (c *MockRPCClient) CallRaw(request *jsonrpc.RPCRequest) (*jsonrpc.RPCResponse, error) { - c.LastRequest = *request - return <-c.NextResponse, nil -} - -func (c MockRPCClient) CallFor(out interface{}, method string, params ...interface{}) error { - return nil -} - -func (c MockRPCClient) CallBatch(requests jsonrpc.RPCRequests) (jsonrpc.RPCResponses, error) { - return nil, nil -} - -func (c MockRPCClient) CallBatchRaw(requests jsonrpc.RPCRequests) (jsonrpc.RPCResponses, error) { - return nil, nil -} - func TestClientCallDoesReloadWallet(t *testing.T) { rand.Seed(time.Now().UnixNano()) dummyUserID := rand.Intn(100) @@ -59,15 +22,17 @@ func TestClientCallDoesReloadWallet(t *testing.T) { err = rt.UnloadWallet(dummyUserID) require.NoError(t, err) - c := NewClient(rt.GetServer(dummyUserID).Address, walletID, 1*time.Second) - - q, _ := NewQuery(newRawRequest(t, "wallet_balance", nil)) + q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) + require.NoError(t, err) q.SetWalletID(walletID) - r, err := c.Call(q) + c := NewCaller(rt.GetServer(dummyUserID).Address, walletID) + r, err := c.callQueryWithRetry(q) // err = json.Unmarshal(result, response) require.NoError(t, err) require.Nil(t, r.Error) + + // TODO: check that wallet is actually reloaded? what is this test even testing? } func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { @@ -77,7 +42,7 @@ func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { srv := test.MockHTTPServer(nil) defer srv.Close() - c := NewClient(srv.URL, "", 0) + c := NewCaller(srv.URL, "") q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) require.NoError(t, err) q.SetWalletID(walletID) @@ -96,10 +61,9 @@ func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { Message: "Wallet at path // was not found", }, }) - srv.NoMoreResponses() }() - r, err := c.Call(q) + r, err := c.callQueryWithRetry(q) require.NoError(t, err) require.Equal(t, "Wallet at path // was not found", r.Error.Message) } @@ -108,29 +72,35 @@ func TestClientCallDoesNotReloadWalletIfAlreadyLoaded(t *testing.T) { rand.Seed(time.Now().UnixNano()) walletID := sdkrouter.WalletID(rand.Intn(100)) - mc := NewMockRPCClient() - c := &Client{rpcClient: mc} - q, _ := NewQuery(newRawRequest(t, "wallet_balance", nil)) + srv := test.MockHTTPServer(nil) + defer srv.Close() + + c := NewCaller(srv.URL, "") + q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) + require.NoError(t, err) q.SetWalletID(walletID) - mc.AddNextResponse(&jsonrpc.RPCResponse{ - JSONRPC: "2.0", - Error: &jsonrpc.RPCError{ - Message: "Couldn't find wallet: //", - }, - }) - mc.AddNextResponse(&jsonrpc.RPCResponse{ - JSONRPC: "2.0", - Error: &jsonrpc.RPCError{ - Message: "Wallet at path // is already loaded", - }, - }) - mc.AddNextResponse(&jsonrpc.RPCResponse{ - JSONRPC: "2.0", - Result: `"99999.00"`, - }) - - r, err := c.Call(q) + go func() { + srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ + JSONRPC: "2.0", + Error: &jsonrpc.RPCError{ + Message: "Couldn't find wallet: //", + }, + }) + srv.NextResponse <- "" // for the wallet_add call + srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ + JSONRPC: "2.0", + Error: &jsonrpc.RPCError{ + Message: "Wallet at path // is already loaded", + }, + }) + srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ + JSONRPC: "2.0", + Result: `"99999.00"`, + }) + }() + + r, err := c.callQueryWithRetry(q) require.NoError(t, err) require.Nil(t, r.Error) diff --git a/app/proxy/query.go b/app/proxy/query.go index a6df17b5..21335190 100644 --- a/app/proxy/query.go +++ b/app/proxy/query.go @@ -6,8 +6,11 @@ import ( "fmt" "strings" - ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/internal/responses" + + ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" + "github.com/ybbus/jsonrpc" ) @@ -152,3 +155,32 @@ func (q *Query) predefinedResponse() *jsonrpc.RPCResponse { return nil } } + +func methodInList(method string, checkMethods []string) bool { + for _, m := range checkMethods { + if m == method { + return true + } + } + return false +} + +// getPreconditionedQueryResponse returns true if we got a resolve query with more than `cacheResolveLongerThan` urls in it +func getPreconditionedQueryResponse(method string, params interface{}) *jsonrpc.RPCResponse { + if methodInList(method, forbiddenMethods) { + return responses.NewJSONRPCError(fmt.Sprintf("Forbidden method requested: %v", method), ErrMethodUnavailable) + } + + if paramsMap, ok := params.(map[string]interface{}); ok { + if _, ok := paramsMap[forbiddenParam]; ok { + return responses.NewJSONRPCError(fmt.Sprintf("Forbidden parameter supplied: %v", forbiddenParam), ErrInvalidParams) + } + } + + if method == MethodStatus { + var r jsonrpc.RPCResponse + r.Result = getStatusResponse() + return &r + } + return nil +} diff --git a/app/proxy/query_filters.go b/app/proxy/query_filters.go deleted file mode 100644 index 002db946..00000000 --- a/app/proxy/query_filters.go +++ /dev/null @@ -1,38 +0,0 @@ -package proxy - -import ( - "fmt" - - "github.com/lbryio/lbrytv/internal/responses" - - "github.com/ybbus/jsonrpc" -) - -func methodInList(method string, checkMethods []string) bool { - for _, m := range checkMethods { - if m == method { - return true - } - } - return false -} - -// getPreconditionedQueryResponse returns true if we got a resolve query with more than `cacheResolveLongerThan` urls in it -func getPreconditionedQueryResponse(method string, params interface{}) *jsonrpc.RPCResponse { - if methodInList(method, forbiddenMethods) { - return responses.NewJSONRPCError(fmt.Sprintf("Forbidden method requested: %v", method), ErrMethodUnavailable) - } - - if paramsMap, ok := params.(map[string]interface{}); ok { - if _, ok := paramsMap[forbiddenParam]; ok { - return responses.NewJSONRPCError(fmt.Sprintf("Forbidden parameter supplied: %v", forbiddenParam), ErrInvalidParams) - } - } - - if method == MethodStatus { - var r jsonrpc.RPCResponse - r.Result = getStatusResponse() - return &r - } - return nil -} diff --git a/app/proxy/query_test.go b/app/proxy/query_test.go new file mode 100644 index 00000000..cac0db11 --- /dev/null +++ b/app/proxy/query_test.go @@ -0,0 +1,37 @@ +package proxy + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestQueryParamsAsMap(t *testing.T) { + q, err := NewQuery(newRawRequest(t, "version", nil)) + require.NoError(t, err) + assert.Nil(t, q.ParamsAsMap()) + + q, err = NewQuery(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) + require.NoError(t, err) + assert.Equal(t, map[string]interface{}{"urls": "what"}, q.ParamsAsMap()) + + q, err = NewQuery(newRawRequest(t, "account_balance", nil)) + require.NoError(t, err) + + q.SetWalletID("123") + err = q.validate() + require.NoError(t, err, errors.Unwrap(err)) + assert.Equal(t, map[string]interface{}{"wallet_id": "123"}, q.ParamsAsMap()) + + searchParams := map[string]interface{}{ + "any_tags": []interface{}{ + "art", "automotive", "blockchain", "comedy", "economics", "education", + "gaming", "music", "news", "science", "sports", "technology", + }, + } + q, err = NewQuery(newRawRequest(t, "claim_search", searchParams)) + require.NoError(t, err) + assert.Equal(t, searchParams, q.ParamsAsMap()) +} diff --git a/app/proxy/service.go b/app/proxy/service.go new file mode 100644 index 00000000..ba5543ed --- /dev/null +++ b/app/proxy/service.go @@ -0,0 +1,45 @@ +package proxy + +import ( + "net/http" + "time" + + "github.com/lbryio/lbrytv/app/sdkrouter" + + "github.com/ybbus/jsonrpc" +) + +const defaultRPCTimeout = 30 * time.Second + +// Service generates Caller objects and keeps execution time metrics +// for all calls proxied through those objects. +type Service struct { + SDKRouter *sdkrouter.Router + rpcTimeout time.Duration +} + +// NewService is the entry point to proxy module. +// Normally only one instance of Service should be created per running server. +func NewService(router *sdkrouter.Router) *Service { + return &Service{ + SDKRouter: router, + rpcTimeout: defaultRPCTimeout, + } +} + +func (ps *Service) SetRPCTimeout(timeout time.Duration) { + ps.rpcTimeout = timeout +} + +// NewCaller returns an instance of Caller ready to proxy requests. +// Note that `SetWalletID` needs to be called if an authenticated user is making this call. +func (ps *Service) NewCaller(walletID string) *Caller { + endpoint := ps.SDKRouter.GetServer(sdkrouter.UserID(walletID)).Address + return &Caller{ + endpoint: endpoint, + walletID: walletID, + client: jsonrpc.NewClientWithOpts(endpoint, &jsonrpc.RPCClientOpts{ + HTTPClient: &http.Client{Timeout: ps.rpcTimeout}, + }), + } +} diff --git a/internal/test/test.go b/internal/test/test.go index 3abb0ee9..45acb16d 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -16,33 +16,37 @@ type MockServer struct { NextResponse chan<- string } -func (m *MockServer) NoMoreResponses() { - close(m.NextResponse) -} - -type RequestData struct { +type Request struct { R *http.Request Body string } // MockHTTPServer creates an http server that can be used to test clients // NOTE: if you want to make sure that you get requests in your requestChan one by one, limit the -// channel to a buffer size of 1. then writes to the chan will block until you read it -func MockHTTPServer(requestChan chan *RequestData) *MockServer { +// channel to a buffer size of 1. then writes to the chan will block until you read it. see +// ReqChan() for how to do this +func MockHTTPServer(requestChan chan *Request) *MockServer { next := make(chan string, 1) return &MockServer{ NextResponse: next, Server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - data, _ := ioutil.ReadAll(r.Body) defer r.Body.Close() if requestChan != nil { - requestChan <- &RequestData{r, string(data)} // store the request for inspection + data, _ := ioutil.ReadAll(r.Body) + requestChan <- &Request{r, string(data)} } fmt.Fprintf(w, <-next) })), } } +// ReqChan makes a channel for reading received requests one by one. +// Use it in conjunction with MockHTTPServer +func ReqChan() chan *Request { + return make(chan *Request, 1) +} + +// ReqToStr is a convenience method func ReqToStr(t *testing.T, req jsonrpc.RPCRequest) string { r, err := json.Marshal(req) if err != nil { @@ -51,6 +55,7 @@ func ReqToStr(t *testing.T, req jsonrpc.RPCRequest) string { return string(r) } +// StrToReq is a convenience method func StrToReq(t *testing.T, req string) jsonrpc.RPCRequest { var r jsonrpc.RPCRequest err := json.Unmarshal([]byte(req), &r) @@ -60,6 +65,7 @@ func StrToReq(t *testing.T, req string) jsonrpc.RPCRequest { return r } +// ResToStr is a convenience method func ResToStr(t *testing.T, res jsonrpc.RPCResponse) string { r, err := json.Marshal(res) if err != nil { diff --git a/internal/test/test_test.go b/internal/test/test_test.go index f93152c2..78d35cf6 100644 --- a/internal/test/test_test.go +++ b/internal/test/test_test.go @@ -10,21 +10,21 @@ import ( ) func TestMockRPCServer(t *testing.T) { - reqChan := make(chan *RequestData, 1) + reqChan := ReqChan() rpcServer := MockHTTPServer(reqChan) defer rpcServer.Close() rpcServer.NextResponse <- `{"result": {"items": [], "page": 1, "page_size": 2, "total_pages": 3}}` - rsp, err := ljsonrpc.NewClient(rpcServer.URL).WalletList("", 1, 2) + res, err := ljsonrpc.NewClient(rpcServer.URL).WalletList("", 1, 2) if err != nil { t.Error(err) } - req := <-reqChan // read the request for inspection + req := <-reqChan assert.Equal(t, req.R.Method, http.MethodPost) assert.Equal(t, req.Body, `{"method":"wallet_list","params":{"page":1,"page_size":2},"id":0,"jsonrpc":"2.0"}`) - assert.Equal(t, rsp.Page, uint64(1)) - assert.Equal(t, rsp.PageSize, uint64(2)) - assert.Equal(t, rsp.TotalPages, uint64(3)) + assert.Equal(t, res.Page, uint64(1)) + assert.Equal(t, res.PageSize, uint64(2)) + assert.Equal(t, res.TotalPages, uint64(3)) } From ac319d6861501ffaa4f3b3f13b6a7c5376cb010d Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Mon, 13 Apr 2020 10:52:57 -0400 Subject: [PATCH 04/18] straighten out error codes --- api/routes_test.go | 2 +- app/proxy/caller.go | 53 +++++++++-------- app/proxy/caller_test.go | 6 +- app/proxy/errors.go | 121 ++++++++++----------------------------- app/proxy/handlers.go | 2 +- app/proxy/processors.go | 91 +++++++++++++---------------- app/proxy/query.go | 34 ++--------- app/publish/errors.go | 47 --------------- app/publish/publish.go | 12 ++-- 9 files changed, 112 insertions(+), 256 deletions(-) delete mode 100644 app/publish/errors.go diff --git a/api/routes_test.go b/api/routes_test.go index 63a663a9..3318fe2b 100644 --- a/api/routes_test.go +++ b/api/routes_test.go @@ -44,7 +44,7 @@ func TestRoutesPublish(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) // Authentication Required error here is enough to see that the request // has been dispatched through the publish handler - assert.Contains(t, rr.Body.String(), `"code": -32080`) + assert.Contains(t, rr.Body.String(), `"code": -32084`) } func TestRoutesOptions(t *testing.T) { diff --git a/app/proxy/caller.go b/app/proxy/caller.go index e9158e4a..073b020a 100644 --- a/app/proxy/caller.go +++ b/app/proxy/caller.go @@ -37,10 +37,12 @@ const ( // Caller patches through JSON-RPC requests from clients, doing pre/post-processing, // account processing and validation. type Caller struct { - client jsonrpc.RPCClient - walletID string - endpoint string - preprocessor func(q *Query) + // Preprocessor is applied to query before it's sent to the SDK. + Preprocessor func(q *Query) + + client jsonrpc.RPCClient + walletID string + endpoint string } func NewCaller(endpoint, walletID string) *Caller { @@ -56,25 +58,27 @@ func NewCaller(endpoint, walletID string) *Caller { func (c *Caller) Call(rawQuery []byte) []byte { r, err := c.call(rawQuery) if err != nil { - if !errors.As(err, &InputError{}) { + if !isJSONParseError(err) { monitor.CaptureException(err, map[string]string{"query": string(rawQuery), "response": fmt.Sprintf("%v", r)}) Logger.Errorf("error calling lbrynet: %v, query: %s", err, rawQuery) } - return c.marshalError(err) + return marshalError(err) } - serialized, err := c.marshal(r) + + serialized, err := marshalResponse(r) if err != nil { monitor.CaptureException(err) Logger.Errorf("error marshaling response: %v", err) - return c.marshalError(err) + return marshalError(err) } + return serialized } -func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, CallError) { +func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, error) { q, err := NewQuery(rawQuery) if err != nil { - return nil, NewInputError(err) + return nil, err } if c.walletID != "" { @@ -93,18 +97,18 @@ func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, CallError) { return pr, nil } - if c.preprocessor != nil { - c.preprocessor(q) + if c.Preprocessor != nil { + c.Preprocessor(q) } r, err := c.callQueryWithRetry(q) if err != nil { - return r, NewInternalError(err) + return r, NewSDKError(err) } - r, err = processResponse(q.Request, r) + err = postProcessResponse(r, q.Request) if err != nil { - return r, NewInternalError(err) + return r, NewSDKError(err) } if q.isCacheable() { @@ -175,25 +179,20 @@ func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) { return r, err } -func (c *Caller) marshal(r *jsonrpc.RPCResponse) ([]byte, CallError) { +func marshalResponse(r *jsonrpc.RPCResponse) ([]byte, error) { serialized, err := json.MarshalIndent(r, "", " ") if err != nil { - return nil, NewError(err) + return nil, NewInternalError(err) } return serialized, nil } -func (c *Caller) marshalError(e CallError) []byte { - serialized, err := json.MarshalIndent(e.AsRPCResponse(), "", " ") - if err != nil { - return []byte(err.Error()) +func marshalError(err error) []byte { + var rpcErr RPCError + if errors.As(err, &rpcErr) { + return rpcErr.JSON() } - return serialized -} - -// SetPreprocessor applies provided function to query before it's sent to the LbrynetServer. -func (c *Caller) SetPreprocessor(p func(q *Query)) { - c.preprocessor = p + return []byte(err.Error()) } func isErrWalletNotLoaded(r *jsonrpc.RPCResponse) bool { diff --git a/app/proxy/caller_test.go b/app/proxy/caller_test.go index 1f1b4bcf..af12e8bb 100644 --- a/app/proxy/caller_test.go +++ b/app/proxy/caller_test.go @@ -204,7 +204,7 @@ func TestCallerSetPreprocessor(t *testing.T) { c := NewCaller(srv.URL, "") - c.SetPreprocessor(func(q *Query) { + c.Preprocessor = func(q *Query) { params := q.ParamsAsMap() if params == nil { q.Request.Params = map[string]string{"param": "123"} @@ -212,7 +212,7 @@ func TestCallerSetPreprocessor(t *testing.T) { params["param"] = "123" q.Request.Params = params } - }) + } srv.NextResponse <- "" @@ -276,7 +276,7 @@ func TestCallerCallClientJSONError(t *testing.T) { var rpcResponse jsonrpc.RPCResponse json.Unmarshal(response, &rpcResponse) assert.Equal(t, "2.0", rpcResponse.JSONRPC) - assert.Equal(t, ErrJSONParse, rpcResponse.Error.Code) + assert.Equal(t, rpcErrorCodeJSONParse, rpcResponse.Error.Code) assert.Equal(t, "unexpected end of JSON input", rpcResponse.Error.Message) } diff --git a/app/proxy/errors.go b/app/proxy/errors.go index 940388e9..3686a8e0 100644 --- a/app/proxy/errors.go +++ b/app/proxy/errors.go @@ -1,113 +1,54 @@ package proxy import ( - "fmt" + "encoding/json" + "errors" "github.com/ybbus/jsonrpc" ) -// ErrProxy is for general errors that originate inside the proxy module -const ErrProxy int = -32080 - -// ErrInternal is a general server error code -const ErrInternal int = -32603 - -// ErrAuthFailed is when supplied auth_token / account_id is not present in the database. -const ErrAuthFailed int = -32085 - -// ErrJSONParse means invalid JSON was received by the server. -const ErrJSONParse int = -32700 - -// ErrInvalidParams signifies a client-supplied params error -const ErrInvalidParams int = -32602 - -// ErrInvalidRequest signifies a general client error -const ErrInvalidRequest int = -32600 - -// ErrMethodUnavailable means the client-requested method cannot be found -const ErrMethodUnavailable int = -32601 - -// CallError is for whatever errors might occur when processing or forwarding client JSON-RPC request -type CallError interface { - AsRPCResponse() *jsonrpc.RPCResponse - Code() int - Error() string -} +const ( + rpcErrorCodeInternal int = -32080 // general errors that originate inside the proxy module + rpcErrorCodeSDK int = -32603 // otherwise-unspecified errors from the SDK + rpcErrorCodeAuthRequired int = -32084 // auth info is required but is not provided + rpcErrorCodeUnauthorized int = -32085 // auth info is provided but is not found in the database + rpcErrorCodeJSONParse int = -32700 // invalid JSON was received by the server + rpcErrorCodeInvalidParams int = -32602 // error in params that the client provided + rpcErrorCodeMethodNotAllowed int = -32601 // the requested method is not allowed to be called +) -type GenericError struct { +type RPCError struct { err error code int } -// InputError is a client JSON parsing error -type InputError struct { - GenericError -} - -// AuthFailed is for authentication failures when jsonrpc client has provided a token -type AuthFailed struct { - err error -} +func (e RPCError) Error() string { return e.err.Error() } +func (e RPCError) Code() int { return e.code } +func (e RPCError) Unwrap() error { return e.err } -// AsRPCResponse returns error as jsonrpc.RPCResponse -func (e GenericError) AsRPCResponse() *jsonrpc.RPCResponse { - return &jsonrpc.RPCResponse{ +func (e RPCError) JSON() []byte { + b, err := json.MarshalIndent(jsonrpc.RPCResponse{ Error: &jsonrpc.RPCError{ Code: e.Code(), Message: e.Error(), }, JSONRPC: "2.0", + }, "", " ") + if err != nil { + Logger.Errorf("rpc error to json: %v", err) } + return b } -// NewError is for general internal errors -func NewError(e error) GenericError { - return GenericError{e, ErrInternal} -} - -// NewInputError is for client JSON parsing errors -func NewInputError(e error) InputError { - return InputError{GenericError{e, ErrJSONParse}} -} - -// NewMethodError creates a call method error -func NewMethodError(e error) GenericError { - return GenericError{e, ErrMethodUnavailable} -} - -// NewParamsError signifies an error in method parameters -func NewParamsError(e error) GenericError { - return GenericError{e, ErrInvalidParams} -} - -// NewInternalError is for SDK-related errors (connection problems etc) -func NewInternalError(e error) GenericError { - return GenericError{e, ErrInternal} -} - -func (e GenericError) Error() string { - return e.err.Error() -} - -// Code returns JSRON-RPC error code -func (e GenericError) Code() int { - return e.code -} - -func (e GenericError) Unwrap() error { - return e.err -} - -func (e AuthFailed) Error() string { - return fmt.Sprintf("couldn't find account for in lbrynet") -} - -// Code returns JSRON-RPC error code -func (e AuthFailed) Code() int { - return ErrAuthFailed -} +func NewInternalError(e error) RPCError { return RPCError{e, rpcErrorCodeInternal} } +func NewJSONParseError(e error) RPCError { return RPCError{e, rpcErrorCodeJSONParse} } +func NewMethodNotAllowedError(e error) RPCError { return RPCError{e, rpcErrorCodeMethodNotAllowed} } +func NewInvalidParamsError(e error) RPCError { return RPCError{e, rpcErrorCodeInvalidParams} } +func NewSDKError(e error) RPCError { return RPCError{e, rpcErrorCodeSDK} } +func NewUnauthorizedError(e error) RPCError { return RPCError{e, rpcErrorCodeUnauthorized} } +func NewAuthRequiredError(e error) RPCError { return RPCError{e, rpcErrorCodeAuthRequired} } -// Code returns JSRON-RPC error code -func (e InputError) Code() int { - return ErrJSONParse +func isJSONParseError(err error) bool { + var e RPCError + return err != nil && errors.As(err, &e) && e.code == rpcErrorCodeJSONParse } diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index 6df640a0..d024f6f3 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -46,7 +46,7 @@ func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { walletID, err = auth.GetWalletID(r) if err != nil { - responses.JSONRPCError(w, err.Error(), ErrAuthFailed) + responses.JSONRPCError(w, err.Error(), rpcErrorCodeUnauthorized) monitor.CaptureRequestError(err, r, w) return } diff --git a/app/proxy/processors.go b/app/proxy/processors.go index d306de9e..8abc7370 100644 --- a/app/proxy/processors.go +++ b/app/proxy/processors.go @@ -7,66 +7,56 @@ import ( "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/monitor" - log "github.com/sirupsen/logrus" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" + + log "github.com/sirupsen/logrus" "github.com/ybbus/jsonrpc" ) -func processQuery(query *jsonrpc.RPCRequest) (processedQuery *jsonrpc.RPCRequest, err error) { - processedQuery = query +func postProcessResponse(response *jsonrpc.RPCResponse, query *jsonrpc.RPCRequest) error { switch query.Method { case MethodGet: - processedQuery, err = queryProcessorGet(query) - } - return processedQuery, err -} - -func processResponse(query *jsonrpc.RPCRequest, response *jsonrpc.RPCResponse) (processedResponse *jsonrpc.RPCResponse, err error) { - processedResponse = response - switch query.Method { - case MethodGet: - processedResponse, err = responseProcessorGet(query, response) + return responseProcessorGet(response, query) case MethodFileList: - processedResponse, err = responseProcessorFileList(query, response) + return responseProcessorFileList(response) case MethodAccountList: - processedResponse, err = responseProcessorAccountList(query, response) + return responseProcessorAccountList(response, query) + default: + return nil } - return processedResponse, err -} - -func queryProcessorGet(query *jsonrpc.RPCRequest) (*jsonrpc.RPCRequest, error) { - return query, nil } -func responseProcessorGet(query *jsonrpc.RPCRequest, response *jsonrpc.RPCResponse) (*jsonrpc.RPCResponse, error) { - var err error - result := map[string]interface{}{} - response.GetObject(&result) +func responseProcessorGet(response *jsonrpc.RPCResponse, query *jsonrpc.RPCRequest) error { + var result map[string]interface{} + err := response.GetObject(&result) + if err != nil { + return err + } stringifiedParams, err := json.Marshal(query.Params) if err != nil { - return response, err + return err } - queryParams := map[string]interface{}{} + var queryParams map[string]interface{} err = json.Unmarshal(stringifiedParams, &queryParams) if err != nil { - return response, err + return err } + result["download_path"] = fmt.Sprintf( "%s%s/%s", config.GetConfig().Viper.GetString("BaseContentURL"), queryParams["uri"], result["outpoint"]) + response.Result = result - return response, nil + return nil } -func responseProcessorFileList(query *jsonrpc.RPCRequest, response *jsonrpc.RPCResponse) (*jsonrpc.RPCResponse, error) { - var err error +func responseProcessorFileList(response *jsonrpc.RPCResponse) error { var resultArray []map[string]interface{} - response.GetObject(&resultArray) - + err := response.GetObject(&resultArray) if err != nil { - return response, err + return err } if len(resultArray) != 0 { @@ -76,36 +66,33 @@ func responseProcessorFileList(query *jsonrpc.RPCRequest, response *jsonrpc.RPCR resultArray[0]["claim_name"], resultArray[0]["claim_id"], resultArray[0]["file_name"]) } - response.Result = resultArray - return response, nil -} -func getDefaultAccount(accounts *ljsonrpc.AccountListResponse) *ljsonrpc.Account { - for _, account := range accounts.Items { - if account.IsDefault { - return &account - } - } + response.Result = resultArray return nil } -func responseProcessorAccountList(query *jsonrpc.RPCRequest, response *jsonrpc.RPCResponse) (*jsonrpc.RPCResponse, error) { - accounts := new(ljsonrpc.AccountListResponse) - // result := map[string]interface{}{} - // response.GetObject(&result) +func responseProcessorAccountList(response *jsonrpc.RPCResponse, query *jsonrpc.RPCRequest) error { + monitor.Logger.WithFields(log.Fields{"params": query.Params}).Info("got account_list query") - monitor.Logger.WithFields(log.Fields{ - "params": query.Params, - }).Info("got account_list query") if query.Params == nil { + accounts := new(ljsonrpc.AccountListResponse) // No account_id is supplied, get the default account and return it ljsonrpc.Decode(response.Result, accounts) account := getDefaultAccount(accounts) if account == nil { - return nil, errors.New("fatal error: no default account found") + return errors.New("fatal error: no default account found") } response.Result = account } - // response.Result = result - return response, nil + + return nil +} + +func getDefaultAccount(accounts *ljsonrpc.AccountListResponse) *ljsonrpc.Account { + for _, account := range accounts.Items { + if account.IsDefault { + return &account + } + } + return nil } diff --git a/app/proxy/query.go b/app/proxy/query.go index 21335190..de26d16b 100644 --- a/app/proxy/query.go +++ b/app/proxy/query.go @@ -6,10 +6,8 @@ import ( "fmt" "strings" - "github.com/lbryio/lbrytv/internal/monitor" - "github.com/lbryio/lbrytv/internal/responses" - ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" + "github.com/lbryio/lbrytv/internal/monitor" "github.com/ybbus/jsonrpc" ) @@ -27,7 +25,7 @@ func NewQuery(r []byte) (*Query, error) { q := &Query{rawRequest: r, Request: &jsonrpc.RPCRequest{}} err := q.unmarshal() if err != nil { - return nil, err + return nil, NewJSONParseError(err) } return q, nil } @@ -43,20 +41,20 @@ func (q *Query) unmarshal() error { return nil } -func (q *Query) validate() CallError { +func (q *Query) validate() error { if !methodInList(q.Method(), relaxedMethods) && !methodInList(q.Method(), walletSpecificMethods) { - return NewMethodError(errors.New("forbidden method")) + return NewMethodNotAllowedError(errors.New("forbidden method")) } if q.ParamsAsMap() != nil { if _, ok := q.ParamsAsMap()[forbiddenParam]; ok { - return NewParamsError(fmt.Errorf("forbidden parameter supplied: %v", forbiddenParam)) + return NewInvalidParamsError(fmt.Errorf("forbidden parameter supplied: %v", forbiddenParam)) } } if !methodInList(q.Method(), relaxedMethods) { if q.walletID == "" { - return NewParamsError(errors.New("account identifier required")) + return NewInvalidParamsError(errors.New("account identifier required")) } if p := q.ParamsAsMap(); p != nil { p[paramWalletID] = q.walletID @@ -164,23 +162,3 @@ func methodInList(method string, checkMethods []string) bool { } return false } - -// getPreconditionedQueryResponse returns true if we got a resolve query with more than `cacheResolveLongerThan` urls in it -func getPreconditionedQueryResponse(method string, params interface{}) *jsonrpc.RPCResponse { - if methodInList(method, forbiddenMethods) { - return responses.NewJSONRPCError(fmt.Sprintf("Forbidden method requested: %v", method), ErrMethodUnavailable) - } - - if paramsMap, ok := params.(map[string]interface{}); ok { - if _, ok := paramsMap[forbiddenParam]; ok { - return responses.NewJSONRPCError(fmt.Sprintf("Forbidden parameter supplied: %v", forbiddenParam), ErrInvalidParams) - } - } - - if method == MethodStatus { - var r jsonrpc.RPCResponse - r.Result = getStatusResponse() - return &r - } - return nil -} diff --git a/app/publish/errors.go b/app/publish/errors.go deleted file mode 100644 index 6ebc8452..00000000 --- a/app/publish/errors.go +++ /dev/null @@ -1,47 +0,0 @@ -package publish - -import ( - "encoding/json" - - "github.com/lbryio/lbrytv/app/proxy" - - "github.com/ybbus/jsonrpc" -) - -type Error struct { - code int - message string -} - -func (e Error) AsRPCResponse() *jsonrpc.RPCResponse { - return &jsonrpc.RPCResponse{ - Error: &jsonrpc.RPCError{ - Code: e.Code(), - Message: e.Message(), - }, - JSONRPC: "2.0", - } -} - -func (e Error) AsBytes() []byte { - b, _ := json.MarshalIndent(e.AsRPCResponse(), "", " ") - return b -} - -func (e Error) Code() int { - return e.code -} - -func (e Error) Message() string { - return e.message -} - -var ErrUnauthorized = Error{code: proxy.ErrProxy, message: "authentication required"} - -func NewAuthError(err error) Error { - return Error{code: proxy.ErrAuthFailed, message: err.Error()} -} - -func NewInternalError(err error) Error { - return Error{code: proxy.ErrInternal, message: err.Error()} -} diff --git a/app/publish/publish.go b/app/publish/publish.go index 6956eda7..31433663 100644 --- a/app/publish/publish.go +++ b/app/publish/publish.go @@ -80,11 +80,11 @@ func NewUploadHandler(opts UploadOpts) (*UploadHandler, error) { // Resulting response is then returned back as a slice of bytes. func (p *LbrynetPublisher) Publish(filePath, walletID string, rawQuery []byte) []byte { c := p.Service.NewCaller(walletID) - c.SetPreprocessor(func(q *proxy.Query) { + c.Preprocessor = func(q *proxy.Query) { params := q.ParamsAsMap() params[fileNameParam] = filePath q.Request.Params = params - }) + } r := c.Call(rawQuery) return r } @@ -95,13 +95,11 @@ func (p *LbrynetPublisher) Publish(filePath, walletID string, rawQuery []byte) [ func (h UploadHandler) Handle(w http.ResponseWriter, r *users.AuthenticatedRequest) { w.WriteHeader(http.StatusOK) if !r.IsAuthenticated() { - var authErr Error if r.AuthFailed() { - authErr = NewAuthError(r.AuthError) + w.Write(proxy.NewUnauthorizedError(r.AuthError).JSON()) } else { - authErr = ErrUnauthorized + w.Write(proxy.NewAuthRequiredError(errors.New("authentication required")).JSON()) } - w.Write(authErr.AsBytes()) return } @@ -109,7 +107,7 @@ func (h UploadHandler) Handle(w http.ResponseWriter, r *users.AuthenticatedReque if err != nil { logger.Log().Error(err) monitor.CaptureException(err) - w.Write(NewInternalError(err).AsBytes()) + w.Write(proxy.NewInternalError(err).JSON()) return } From 79a4d07d0a2855dfc8214cfc0833eed3dd36b977 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Mon, 13 Apr 2020 18:06:32 -0400 Subject: [PATCH 05/18] slow progress --- api/benchmarks_test.go | 9 +- api/routes.go | 2 +- app/proxy/caller.go | 2 +- app/proxy/caller_test.go | 32 +-- app/proxy/client_test.go | 11 +- app/proxy/const.go | 9 - app/proxy/handlers.go | 16 +- app/proxy/handlers_test.go | 10 +- app/proxy/main_test.go | 4 +- app/publish/handler_test.go | 15 +- app/publish/publish.go | 17 +- app/publish/publish_test.go | 23 +- app/publish/testing.go | 4 +- app/sdkrouter/sdkrouter.go | 81 ------- app/sdkrouter/sdkrouter_test.go | 43 ---- app/users/authenticator.go | 61 +++--- app/users/authenticator_test.go | 76 +++---- app/users/testing.go | 62 +----- app/users/testing_test.go | 26 +-- app/users/users.go | 130 ----------- app/{users => wallet}/remote.go | 11 +- app/wallet/wallet.go | 207 ++++++++++++++++++ .../users_test.go => wallet/wallet_test.go} | 167 ++++++++++---- internal/responses/responses.go | 26 +-- internal/responses/responses_test.go | 34 --- internal/test/test.go | 12 +- internal/test/test_test.go | 23 +- 27 files changed, 540 insertions(+), 573 deletions(-) delete mode 100644 app/users/users.go rename app/{users => wallet}/remote.go (77%) create mode 100644 app/wallet/wallet.go rename app/{users/users_test.go => wallet/wallet_test.go} (50%) delete mode 100644 internal/responses/responses_test.go diff --git a/api/benchmarks_test.go b/api/benchmarks_test.go index aaa418ea..c5bb75d8 100644 --- a/api/benchmarks_test.go +++ b/api/benchmarks_test.go @@ -16,7 +16,7 @@ import ( "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/storage" @@ -32,7 +32,7 @@ func launchAuthenticatingAPIServer() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t := r.PostFormValue("auth_token") - responses.PrepareJSONWriter(w) + responses.AddJSONContentType(w) reply := fmt.Sprintf(` { @@ -94,9 +94,8 @@ func BenchmarkWalletCommands(b *testing.B) { walletsNum := 30 wallets := make([]*models.User, walletsNum) rt := sdkrouter.New(config.GetLbrynetServers()) - svc := users.NewWalletService(rt) - svc.Logger.Disable() + wallet.DisableLogger() sdkrouter.DisableLogger() log.SetOutput(ioutil.Discard) @@ -104,7 +103,7 @@ func BenchmarkWalletCommands(b *testing.B) { for i := 0; i < walletsNum; i++ { uid := int(rand.Int31()) - u, err := svc.Retrieve(users.Query{Token: fmt.Sprintf("%v", uid)}) + u, err := wallet.GetUserWithWallet(rt, fmt.Sprintf("%v", uid), "") require.NoError(b, err, errors.Unwrap(err)) require.NotNil(b, u) wallets[i] = u diff --git a/api/routes.go b/api/routes.go index e0becb03..e6f3316e 100644 --- a/api/routes.go +++ b/api/routes.go @@ -34,7 +34,7 @@ func InstallRoutes(proxyService *proxy.Service, r *mux.Router) { }) v1Router := r.PathPrefix("/api/v1").Subrouter() - v1Router.HandleFunc("/proxy", proxyHandler.HandleOptions).Methods(http.MethodOptions) + v1Router.HandleFunc("/proxy", proxy.HandleCORS).Methods(http.MethodOptions) v1Router.HandleFunc("/proxy", authenticator.Wrap(upHandler.Handle)).MatcherFunc(upHandler.CanHandle) v1Router.HandleFunc("/proxy", proxyHandler.Handle) v1Router.HandleFunc("/metric/ui", metrics.TrackUIMetric).Methods(http.MethodPost) diff --git a/app/proxy/caller.go b/app/proxy/caller.go index 073b020a..f37ee1da 100644 --- a/app/proxy/caller.go +++ b/app/proxy/caller.go @@ -192,7 +192,7 @@ func marshalError(err error) []byte { if errors.As(err, &rpcErr) { return rpcErr.JSON() } - return []byte(err.Error()) + return NewInternalError(err).JSON() } func isErrWalletNotLoaded(r *jsonrpc.RPCResponse) bool { diff --git a/app/proxy/caller_test.go b/app/proxy/caller_test.go index af12e8bb..4e1ff1b2 100644 --- a/app/proxy/caller_test.go +++ b/app/proxy/caller_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/lbryio/lbrytv/app/sdkrouter" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/test" @@ -38,11 +39,13 @@ func newRawRequest(t *testing.T, method string, params interface{}) []byte { return body } -func parseRawResponse(t *testing.T, rawCallReponse []byte, destinationVar interface{}) { - var rpcResponse jsonrpc.RPCResponse +func parseRawResponse(t *testing.T, rawCallReponse []byte, v interface{}) { assert.NotNil(t, rawCallReponse) - json.Unmarshal(rawCallReponse, &rpcResponse) - rpcResponse.GetObject(destinationVar) + var res jsonrpc.RPCResponse + err := json.Unmarshal(rawCallReponse, &res) + require.NoError(t, err) + err = res.GetObject(v) + require.NoError(t, err) } func TestNewQuery(t *testing.T) { @@ -90,11 +93,6 @@ func TestCallerSetWalletID(t *testing.T) { } func TestCallerCallResolve(t *testing.T) { - var ( - errorResponse jsonrpc.RPCResponse - resolveResponse ljsonrpc.ResolveResponse - ) - svc := NewService(sdkrouter.New(config.GetLbrynetServers())) resolvedURL := "what#6769855a9aa43b67086f9ff3c1a5bacb5698a27a" @@ -102,33 +100,35 @@ func TestCallerCallResolve(t *testing.T) { request := newRawRequest(t, "resolve", map[string]string{"urls": resolvedURL}) rawCallReponse := svc.NewCaller("").Call(request) + + var errorResponse jsonrpc.RPCResponse err := json.Unmarshal(rawCallReponse, &errorResponse) require.NoError(t, err) require.Nil(t, errorResponse.Error) + var resolveResponse ljsonrpc.ResolveResponse parseRawResponse(t, rawCallReponse, &resolveResponse) assert.Equal(t, resolvedClaimID, resolveResponse[resolvedURL].ClaimID) } func TestCallerCallWalletBalance(t *testing.T) { - var accountBalanceResponse ljsonrpc.AccountBalanceResponse - rand.Seed(time.Now().UnixNano()) dummyUserID := rand.Intn(10^6-10^3) + 10 ^ 3 - rt := sdkrouter.New(config.GetLbrynetServers()) - walletID, err := rt.InitializeWallet(dummyUserID) - require.NoError(t, err) - svc := NewService(rt) + request := newRawRequest(t, "wallet_balance", nil) result := svc.NewCaller("").Call(request) assert.Contains(t, string(result), `"message": "account identifier required"`) + walletID, err := wallet.Create(test.RandServerAddress(t), dummyUserID) + require.NoError(t, err) + hook := logrusTest.NewLocal(Logger.Logger()) result = svc.NewCaller(walletID).Call(request) + var accountBalanceResponse ljsonrpc.AccountBalanceResponse parseRawResponse(t, result, &accountBalanceResponse) assert.EqualValues(t, "0", fmt.Sprintf("%v", accountBalanceResponse.Available)) assert.Equal(t, map[string]interface{}{"wallet_id": fmt.Sprintf("%v", walletID)}, hook.LastEntry().Data["params"]) @@ -268,7 +268,7 @@ func TestCallerCallSDKError(t *testing.T) { func TestCallerCallClientJSONError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - responses.PrepareJSONWriter(w) + responses.AddJSONContentType(w) w.Write([]byte(`{"method":"version}`)) })) c := NewCaller(ts.URL, "") diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index 2cd37dad..6301c1bb 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -6,8 +6,9 @@ import ( "time" "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/internal/test" + "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" ) @@ -15,18 +16,18 @@ import ( func TestClientCallDoesReloadWallet(t *testing.T) { rand.Seed(time.Now().UnixNano()) dummyUserID := rand.Intn(100) - rt := sdkrouter.New(config.GetLbrynetServers()) + addr := test.RandServerAddress(t) - walletID, err := rt.InitializeWallet(dummyUserID) + walletID, err := wallet.Create(addr, dummyUserID) require.NoError(t, err) - err = rt.UnloadWallet(dummyUserID) + err = wallet.UnloadWallet(addr, dummyUserID) require.NoError(t, err) q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) require.NoError(t, err) q.SetWalletID(walletID) - c := NewCaller(rt.GetServer(dummyUserID).Address, walletID) + c := NewCaller(addr, walletID) r, err := c.callQueryWithRetry(q) // err = json.Unmarshal(result, response) require.NoError(t, err) diff --git a/app/proxy/const.go b/app/proxy/const.go index 0d1aca48..6782afd2 100644 --- a/app/proxy/const.go +++ b/app/proxy/const.go @@ -140,12 +140,3 @@ var ignoreLog = []string{ MethodAccountBalance, MethodStatus, } - -func shouldLog(method string) bool { - for _, m := range ignoreLog { - if m == method { - return false - } - } - return true -} diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index d024f6f3..f15b7128 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/internal/responses" ) @@ -29,6 +30,7 @@ func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { proxyHandlerLogger.Log().Errorf("empty request body") return } + body, err := ioutil.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -41,12 +43,12 @@ func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { q, err := NewQuery(body) if err != nil || !methodInList(q.Method(), relaxedMethods) { - retriever := users.NewWalletService(rh.SDKRouter) - auth := users.NewAuthenticator(retriever) + auth := users.NewAuthenticator(rh.SDKRouter) walletID, err = auth.GetWalletID(r) if err != nil { - responses.JSONRPCError(w, err.Error(), rpcErrorCodeUnauthorized) + responses.AddJSONContentType(w) + w.Write(marshalError(err)) monitor.CaptureRequestError(err, r, w) return } @@ -55,15 +57,15 @@ func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { c := rh.NewCaller(walletID) rawCallReponse := c.Call(body) - responses.PrepareJSONWriter(w) + responses.AddJSONContentType(w) w.Write(rawCallReponse) } -// HandleOptions returns necessary CORS headers for pre-flight requests to proxy API -func (rh *RequestHandler) HandleOptions(w http.ResponseWriter, r *http.Request) { +// HandleCORS returns necessary CORS headers for pre-flight requests to proxy API +func HandleCORS(w http.ResponseWriter, r *http.Request) { hs := w.Header() hs.Set("Access-Control-Max-Age", "7200") hs.Set("Access-Control-Allow-Origin", "*") - hs.Set("Access-Control-Allow-Headers", "X-Lbry-Auth-Token, Origin, X-Requested-With, Content-Type, Accept") + hs.Set("Access-Control-Allow-Headers", wallet.TokenHeader+", Origin, X-Requested-With, Content-Type, Accept") w.WriteHeader(http.StatusOK) } diff --git a/app/proxy/handlers_test.go b/app/proxy/handlers_test.go index aa3cb08c..a01f1c9e 100644 --- a/app/proxy/handlers_test.go +++ b/app/proxy/handlers_test.go @@ -7,7 +7,7 @@ import ( "net/http/httptest" "testing" - "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" "github.com/stretchr/testify/assert" @@ -16,11 +16,11 @@ import ( ) func TestProxyOptions(t *testing.T) { - r, _ := http.NewRequest("OPTIONS", "/api/proxy", nil) + r, err := http.NewRequest("OPTIONS", "/api/proxy", nil) + require.NoError(t, err) rr := httptest.NewRecorder() - handler := NewRequestHandler(svc) - handler.HandleOptions(rr, r) + HandleCORS(rr, r) response := rr.Result() assert.Equal(t, http.StatusOK, response.StatusCode) @@ -62,7 +62,7 @@ func TestProxyDontAuthRelaxedMethods(t *testing.T) { config.Override("InternalAPIHost", ts.URL) r, _ := http.NewRequest("POST", "", bytes.NewBuffer([]byte(newRawRequest(t, "resolve", map[string]string{"urls": "what"})))) - r.Header.Set(users.TokenHeader, "abc") + r.Header.Set(wallet.TokenHeader, "abc") rr := httptest.NewRecorder() handler := NewRequestHandler(svc) diff --git a/app/proxy/main_test.go b/app/proxy/main_test.go index 06f82b20..941b13e3 100644 --- a/app/proxy/main_test.go +++ b/app/proxy/main_test.go @@ -53,7 +53,7 @@ func testFuncTeardown() { func launchDummyAPIServer(response []byte) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - responses.PrepareJSONWriter(w) + responses.AddJSONContentType(w) w.Write(response) })) } @@ -61,7 +61,7 @@ func launchDummyAPIServer(response []byte) *httptest.Server { func launchDummyAPIServerDelayed(response []byte, delayMsec time.Duration) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(delayMsec * time.Millisecond) - responses.PrepareJSONWriter(w) + responses.AddJSONContentType(w) w.Write(response) })) } diff --git a/app/publish/handler_test.go b/app/publish/handler_test.go index a39c7f06..a92df10f 100644 --- a/app/publish/handler_test.go +++ b/app/publish/handler_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/app/wallet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" @@ -35,10 +36,10 @@ func (p *DummyPublisher) Publish(filePath, accountID string, rawQuery []byte) [] func TestUploadHandler(t *testing.T) { req := CreatePublishRequest(t, []byte("test file")) - req.Header.Set(users.TokenHeader, "uPldrToken") + req.Header.Set(wallet.TokenHeader, "uPldrToken") rr := httptest.NewRecorder() - authenticator := users.NewAuthenticator(&users.TestUserRetriever{WalletID: "UPldrAcc", Token: "uPldrToken"}) + authenticator := &users.Authenticator{Retriever: users.DummyRetriever("uPldrToken", "UPldrAcc")} publisher := &DummyPublisher{} pubHandler, err := NewUploadHandler(UploadOpts{Path: os.TempDir(), Publisher: publisher}) assert.NoError(t, err) @@ -65,7 +66,7 @@ func TestUploadHandlerAuthRequired(t *testing.T) { req := CreatePublishRequest(t, []byte("test file")) rr := httptest.NewRecorder() - authenticator := users.NewAuthenticator(&users.TestUserRetriever{}) + authenticator := &users.Authenticator{Retriever: users.DummyRetriever("", "")} publisher := &DummyPublisher{} pubHandler, err := NewUploadHandler(UploadOpts{Path: os.TempDir(), Publisher: publisher}) assert.NoError(t, err) @@ -90,12 +91,12 @@ func TestUploadHandlerSystemError(t *testing.T) { writer := multipart.NewWriter(body) - fileBody, err := writer.CreateFormFile(FileFieldName, "lbry_auto_test_file") + fileBody, err := writer.CreateFormFile(fileFieldName, "lbry_auto_test_file") require.NoError(t, err) _, err = io.Copy(fileBody, readSeeker) require.NoError(t, err) - jsonPayload, err := writer.CreateFormField(JSONRPCFieldName) + jsonPayload, err := writer.CreateFormField(jsonRPCFieldName) require.NoError(t, err) jsonPayload.Write([]byte(expectedStreamCreateRequest)) @@ -104,11 +105,11 @@ func TestUploadHandlerSystemError(t *testing.T) { req, err := http.NewRequest("POST", "/", bytes.NewReader(body.Bytes())) require.NoError(t, err) - req.Header.Set(users.TokenHeader, "uPldrToken") + req.Header.Set(wallet.TokenHeader, "uPldrToken") req.Header.Set("Content-Type", writer.FormDataContentType()) rr := httptest.NewRecorder() - authenticator := users.NewAuthenticator(&users.TestUserRetriever{WalletID: "UPldrAcc", Token: "uPldrToken"}) + authenticator := &users.Authenticator{Retriever: users.DummyRetriever("uPldrToken", "UPldrAcc")} publisher := &DummyPublisher{} pubHandler, err := NewUploadHandler(UploadOpts{Path: os.TempDir(), Publisher: publisher}) assert.NoError(t, err) diff --git a/app/publish/publish.go b/app/publish/publish.go index 31433663..fedc12ce 100644 --- a/app/publish/publish.go +++ b/app/publish/publish.go @@ -17,11 +17,11 @@ import ( "github.com/gorilla/mux" ) -// FileFieldName refers to the POST field containing file upload -const FileFieldName = "file" +// fileFieldName refers to the POST field containing file upload +const fileFieldName = "file" -// JSONRPCFieldName is a name of the POST field containing JSONRPC request accompanying the uploaded file -const JSONRPCFieldName = "json_payload" +// jsonRPCFieldName is a name of the POST field containing JSONRPC request accompanying the uploaded file +const jsonRPCFieldName = "json_payload" const fileNameParam = "file_path" @@ -56,6 +56,7 @@ func NewUploadHandler(opts UploadOpts) (*UploadHandler, error) { publisher Publisher uploadPath string ) + if opts.ProxyService != nil { publisher = &LbrynetPublisher{Service: opts.ProxyService} } else if opts.Publisher != nil { @@ -111,7 +112,7 @@ func (h UploadHandler) Handle(w http.ResponseWriter, r *users.AuthenticatedReque return } - response := h.Publisher.Publish(f.Name(), r.WalletID, []byte(r.FormValue(JSONRPCFieldName))) + response := h.Publisher.Publish(f.Name(), r.WalletID, []byte(r.FormValue(jsonRPCFieldName))) if err := os.Remove(f.Name()); err != nil { monitor.CaptureException(err, map[string]string{"file_path": f.Name()}) @@ -123,8 +124,8 @@ func (h UploadHandler) Handle(w http.ResponseWriter, r *users.AuthenticatedReque // CanHandle checks if http.Request contains POSTed data in an accepted format. // Supposed to be used in gorilla mux router MatcherFunc. func (h UploadHandler) CanHandle(r *http.Request, _ *mux.RouteMatch) bool { - _, _, err := r.FormFile(FileFieldName) - payload := r.FormValue(JSONRPCFieldName) + _, _, err := r.FormFile(fileFieldName) + payload := r.FormValue(jsonRPCFieldName) return err != http.ErrMissingFile && payload != "" } @@ -147,7 +148,7 @@ func (h UploadHandler) preparePath(walletID string) (string, error) { func (h UploadHandler) saveFile(r *users.AuthenticatedRequest) (*os.File, error) { log := logger.LogF(monitor.F{"account_id": r.WalletID}) - file, header, err := r.FormFile(FileFieldName) + file, header, err := r.FormFile(fileFieldName) if err != nil { return nil, err } diff --git a/app/publish/publish_test.go b/app/publish/publish_test.go index 36fc11e7..0eeef47c 100644 --- a/app/publish/publish_test.go +++ b/app/publish/publish_test.go @@ -10,9 +10,11 @@ import ( "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/storage" + "github.com/lbryio/lbrytv/internal/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,15 +41,28 @@ func TestLbrynetPublisher(t *testing.T) { c.SetDefaultConnection() defer connCleanup() - ts := users.StartAuthenticatingAPIServer(751365) + reqChan := test.ReqChan() + ts := test.MockHTTPServer(reqChan) defer ts.Close() + go func() { + req := <-reqChan + responses.AddJSONContentType(req.W) + ts.NextResponse <- fmt.Sprintf(`{ + "success": true, + "error": null, + "data": { + "user_id": %v, + "has_verified_email": true + } + }`, 751365) + }() + config.Override("InternalAPIHost", ts.URL) defer config.RestoreOverridden() rt := sdkrouter.New(config.GetLbrynetServers()) p := &LbrynetPublisher{proxy.NewService(rt)} - walletSvc := users.NewWalletService(rt) - u, err := walletSvc.Retrieve(users.Query{Token: authToken}) + u, err := wallet.GetUserWithWallet(rt, authToken, "") require.NoError(t, err) data := []byte("test file") diff --git a/app/publish/testing.go b/app/publish/testing.go index f927c3e2..02aa5ce3 100644 --- a/app/publish/testing.go +++ b/app/publish/testing.go @@ -17,12 +17,12 @@ func CreatePublishRequest(t *testing.T, data []byte) *http.Request { writer := multipart.NewWriter(body) - fileBody, err := writer.CreateFormFile(FileFieldName, "lbry_auto_test_file") + fileBody, err := writer.CreateFormFile(fileFieldName, "lbry_auto_test_file") require.NoError(t, err) _, err = io.Copy(fileBody, readSeeker) require.NoError(t, err) - jsonPayload, err := writer.CreateFormField(JSONRPCFieldName) + jsonPayload, err := writer.CreateFormField(jsonRPCFieldName) require.NoError(t, err) jsonPayload.Write([]byte(expectedStreamCreateRequest)) diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index 4b1ede9b..032e7a09 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -11,7 +11,6 @@ import ( "sync" "time" - "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" @@ -217,86 +216,6 @@ func (r *Router) Client(userID int) *ljsonrpc.Client { return c } -// InitializeWallet creates a wallet that can be immediately used in subsequent commands. -// It can recover from errors like existing wallets, but if a wallet is known to exist -// (eg. a wallet ID stored in the database already), loadWallet() should be called instead. -func (r *Router) InitializeWallet(userID int) (string, error) { - wallet, err := r.createWallet(userID) - if err == nil { - return wallet.ID, nil - } - - walletID := WalletID(userID) - log := logger.LogF(monitor.F{"user_id": userID}) - - if errors.Is(err, lbrynet.ErrWalletExists) { - log.Warn(err.Error()) - return walletID, nil - } - - if errors.Is(err, lbrynet.ErrWalletNeedsLoading) { - log.Info(err.Error()) - wallet, err = r.loadWallet(userID) - if err != nil { - if errors.Is(err, lbrynet.ErrWalletAlreadyLoaded) { - log.Info(err.Error()) - return walletID, nil - } - return "", err - } - return wallet.ID, nil - } - - log.Errorf("don't know how to recover from error: %v", err) - return "", err -} - -// createWallet creates a new wallet on the LbrynetServer. -// Returned error doesn't necessarily mean that the wallet is not operational: -// -// if errors.Is(err, lbrynet.WalletExists) { -// // Okay to proceed with the account -// } -// -// if errors.Is(err, lbrynet.WalletNeedsLoading) { -// // loadWallet() needs to be called before the wallet can be used -// } -func (r *Router) createWallet(userID int) (*ljsonrpc.Wallet, error) { - wallet, err := r.Client(userID).WalletCreate(WalletID(userID), &ljsonrpc.WalletCreateOpts{ - SkipOnStartup: true, CreateAccount: true, SingleKey: true}) - if err != nil { - return nil, lbrynet.NewWalletError(userID, err) - } - logger.LogF(monitor.F{"user_id": userID}).Info("wallet created") - return wallet, nil -} - -// loadWallet loads an existing wallet in the LbrynetServer. -// May return errors: -// WalletAlreadyLoaded - wallet is already loaded and operational -// WalletNotFound - wallet file does not exist and won't be loaded. -func (r *Router) loadWallet(userID int) (*ljsonrpc.Wallet, error) { - wallet, err := r.Client(userID).WalletAdd(WalletID(userID)) - if err != nil { - return nil, lbrynet.NewWalletError(userID, err) - } - logger.LogF(monitor.F{"user_id": userID}).Info("wallet loaded") - return wallet, nil -} - -// UnloadWallet unloads an existing wallet from the LbrynetServer. -// May return errors: -// WalletAlreadyLoaded - wallet is already loaded and operational -// WalletNotFound - wallet file does not exist and won't be loaded. -func (r *Router) UnloadWallet(userID int) error { - _, err := r.Client(userID).WalletRemove(WalletID(userID)) - if err != nil { - return lbrynet.NewWalletError(userID, err) - } - logger.LogF(monitor.F{"user_id": userID}).Info("wallet unloaded") - return nil -} - // WalletID formats user ID to use as an LbrynetServer wallet ID. func WalletID(userID int) string { return fmt.Sprintf("lbrytv-id.%d.wallet", userID) diff --git a/app/sdkrouter/sdkrouter_test.go b/app/sdkrouter/sdkrouter_test.go index da4968d0..94c079b1 100644 --- a/app/sdkrouter/sdkrouter_test.go +++ b/app/sdkrouter/sdkrouter_test.go @@ -1,20 +1,15 @@ package sdkrouter import ( - "errors" "fmt" - "math/rand" "os" "testing" - "time" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/internal/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -102,41 +97,3 @@ func TestLeastLoaded(t *testing.T) { assert.Equal(t, "srv3", r.LeastLoaded().Name) } - -func TestInitializeWallet(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - userID := rand.Int() - r := New(config.GetLbrynetServers()) - - walletID, err := r.InitializeWallet(userID) - require.NoError(t, err) - assert.Equal(t, walletID, WalletID(userID)) - - err = r.UnloadWallet(userID) - require.NoError(t, err) - - walletID, err = r.InitializeWallet(userID) - require.NoError(t, err) - assert.Equal(t, walletID, WalletID(userID)) -} - -func TestCreateWalletLoadWallet(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - userID := rand.Int() - r := New(config.GetLbrynetServers()) - - wallet, err := r.createWallet(userID) - require.NoError(t, err) - assert.Equal(t, wallet.ID, WalletID(userID)) - - wallet, err = r.createWallet(userID) - require.NotNil(t, err) - assert.True(t, errors.Is(err, lbrynet.ErrWalletExists)) - - err = r.UnloadWallet(userID) - require.NoError(t, err) - - wallet, err = r.loadWallet(userID) - require.NoError(t, err) - assert.Equal(t, wallet.ID, WalletID(userID)) -} diff --git a/app/users/authenticator.go b/app/users/authenticator.go index 93085b81..d3f236bf 100644 --- a/app/users/authenticator.go +++ b/app/users/authenticator.go @@ -3,43 +3,62 @@ package users import ( "net/http" + "github.com/lbryio/lbrytv/app/sdkrouter" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/models" ) const GenericRetrievalErr = "unable to retrieve user" var logger = monitor.NewModuleLogger("auth") -type Authenticator struct { - retriever Retriever -} - type AuthenticatedRequest struct { *http.Request WalletID string AuthError error } -type AuthenticatedFunc func(http.ResponseWriter, *AuthenticatedRequest) +// AuthFailed is a helper to see if there was an error authenticating user. +func (r *AuthenticatedRequest) AuthFailed() bool { + return r.AuthError != nil +} + +// IsAuthenticated is a helper to see if a user was authenticated. +// If it is false, AuthError might be provided (in case user retriever has errored) +// or be nil if no auth token was present in headers. +func (r *AuthenticatedRequest) IsAuthenticated() bool { + return r.WalletID != "" +} + +// Retriever is an interface for user retrieval by internal-apis auth token +type UserRetriever func(token, metaRemoteIP string) (*models.User, error) + +type Authenticator struct { + Retriever UserRetriever +} // NewAuthenticator provides HTTP handler wrapping methods // and should be initialized with an object that allows user retrieval. -func NewAuthenticator(retriever Retriever) *Authenticator { - return &Authenticator{retriever} +func NewAuthenticator(rt *sdkrouter.Router) *Authenticator { + return &Authenticator{ + Retriever: func(token, metaRemoteIP string) (user *models.User, err error) { + return wallet.GetUserWithWallet(rt, token, metaRemoteIP) + }, + } } // GetWalletID retrieves user token from HTTP headers and subsequently // an SDK account ID from Retriever. func (a *Authenticator) GetWalletID(r *http.Request) (string, error) { - if token, ok := r.Header[TokenHeader]; ok { + if token, ok := r.Header[wallet.TokenHeader]; ok { ip := GetIPAddressForRequest(r) - u, err := a.retriever.Retrieve(Query{Token: token[0], MetaRemoteIP: ip}) - log := logger.LogF(monitor.F{"ip": ip}) + user, err := a.Retriever(token[0], ip) if err != nil { - log.Debugf("failed to authenticate user") + logger.LogF(monitor.F{"ip": ip}).Debugf("failed to authenticate user") return "", err - } else if u != nil { - return u.WalletID, nil + } else if user != nil { + return user.WalletID, nil } } return "", nil @@ -47,10 +66,10 @@ func (a *Authenticator) GetWalletID(r *http.Request) (string, error) { // Wrap result can be supplied to all functions that accept http.HandleFunc, // supplied function will be wrapped and called with AuthenticatedRequest instead of http.Request. -func (a *Authenticator) Wrap(wrapped AuthenticatedFunc) http.HandlerFunc { +func (a *Authenticator) Wrap(wrapped func(http.ResponseWriter, *AuthenticatedRequest)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - WalletID, err := a.GetWalletID(r) ar := &AuthenticatedRequest{Request: r} + WalletID, err := a.GetWalletID(r) if err != nil { ar.AuthError = err } else { @@ -59,15 +78,3 @@ func (a *Authenticator) Wrap(wrapped AuthenticatedFunc) http.HandlerFunc { wrapped(w, ar) } } - -// AuthFailed is a helper to see if there was an error authenticating user. -func (r *AuthenticatedRequest) AuthFailed() bool { - return r.AuthError != nil -} - -// IsAuthenticated is a helper to see if a user was authenticated. -// If it is false, AuthError might be provided (in case user retriever has errored) -// or be nil if no auth token was present in headers. -func (r *AuthenticatedRequest) IsAuthenticated() bool { - return r.WalletID != "" -} diff --git a/app/users/authenticator_test.go b/app/users/authenticator_test.go index a543e401..f97381ae 100644 --- a/app/users/authenticator_test.go +++ b/app/users/authenticator_test.go @@ -7,32 +7,14 @@ import ( "net/http/httptest" "testing" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/models" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -type DummyRetriever struct { - remoteIP string -} - -func (r *DummyRetriever) Retrieve(q Query) (*models.User, error) { - r.remoteIP = q.MetaRemoteIP - if q.Token == "XyZ" { - return &models.User{WalletID: "aBc"}, nil - } - return nil, errors.New("cannot authenticate") -} - -type UnverifiedRetriever struct { - remoteIP string -} - -func (r *UnverifiedRetriever) Retrieve(q Query) (*models.User, error) { - return nil, nil -} - -func AuthenticatedHandler(w http.ResponseWriter, r *AuthenticatedRequest) { +func authedHandler(w http.ResponseWriter, r *AuthenticatedRequest) { if r.IsAuthenticated() { w.WriteHeader(http.StatusAccepted) w.Write([]byte(r.WalletID)) @@ -43,43 +25,55 @@ func AuthenticatedHandler(w http.ResponseWriter, r *AuthenticatedRequest) { } func TestAuthenticator(t *testing.T) { - retriever := &DummyRetriever{} - r, _ := http.NewRequest("GET", "/api/proxy", nil) - r.Header.Set(TokenHeader, "XyZ") + r, err := http.NewRequest("GET", "/api/proxy", nil) + require.NoError(t, err) + r.Header.Set(wallet.TokenHeader, "XyZ") r.Header.Set("X-Forwarded-For", "8.8.8.8") - rr := httptest.NewRecorder() - authenticator := NewAuthenticator(retriever) - - http.HandlerFunc(authenticator.Wrap(AuthenticatedHandler)).ServeHTTP(rr, r) + var receivedRemoteIP string + authenticator := &Authenticator{ + Retriever: func(token, ip string) (*models.User, error) { + receivedRemoteIP = ip + if token == "XyZ" { + return &models.User{WalletID: "aBc"}, nil + } + return nil, errors.New(GenericRetrievalErr) + }, + } + rr := httptest.NewRecorder() + authenticator.Wrap(authedHandler).ServeHTTP(rr, r) response := rr.Result() - body, _ := ioutil.ReadAll(response.Body) + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) assert.Equal(t, "aBc", string(body)) - assert.Equal(t, "8.8.8.8", retriever.remoteIP) + assert.Equal(t, "8.8.8.8", receivedRemoteIP) } func TestAuthenticatorFailure(t *testing.T) { - r, _ := http.NewRequest("GET", "/api/proxy", nil) - r.Header.Set(TokenHeader, "ALSDJ") + r, err := http.NewRequest("GET", "/api/proxy", nil) + require.NoError(t, err) + r.Header.Set(wallet.TokenHeader, "ALSDJ") rr := httptest.NewRecorder() - authenticator := NewAuthenticator(&DummyRetriever{}) + authenticator := &Authenticator{Retriever: DummyRetriever("XyZ", "")} - http.HandlerFunc(authenticator.Wrap(AuthenticatedHandler)).ServeHTTP(rr, r) + authenticator.Wrap(authedHandler).ServeHTTP(rr, r) response := rr.Result() - body, _ := ioutil.ReadAll(response.Body) - assert.Equal(t, "cannot authenticate", string(body)) + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + assert.Equal(t, GenericRetrievalErr, string(body)) assert.Equal(t, http.StatusForbidden, response.StatusCode) } func TestAuthenticatorGetWalletIDUnverifiedUser(t *testing.T) { - r, _ := http.NewRequest("GET", "/api/proxy", nil) - r.Header.Set(TokenHeader, "zzz") + r, err := http.NewRequest("GET", "/api/proxy", nil) + require.NoError(t, err) + r.Header.Set(wallet.TokenHeader, "zzz") - a := NewAuthenticator(&UnverifiedRetriever{}) + a := &Authenticator{Retriever: func(token, ip string) (*models.User, error) { return nil, nil }} - wid, err := a.GetWalletID(r) + walletID, err := a.GetWalletID(r) assert.NoError(t, err) - assert.Equal(t, "", wid) + assert.Equal(t, "", walletID) } diff --git a/app/users/testing.go b/app/users/testing.go index 32b36379..3d047082 100644 --- a/app/users/testing.go +++ b/app/users/testing.go @@ -2,65 +2,15 @@ package users import ( "errors" - "fmt" - "net/http" - "net/http/httptest" - "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/models" ) -const userHasVerifiedEmailResponse = `{ - "success": true, - "error": null, - "data": { - "user_id": %v, - "has_verified_email": true +func DummyRetriever(userToken, walletID string) UserRetriever { + return func(token, ip string) (*models.User, error) { + if userToken == "" || userToken == token { + return &models.User{WalletID: walletID}, nil + } + return nil, errors.New(GenericRetrievalErr) } -}` - -const userDoesntHaveVerifiedEmailResponse = `{ - "success": true, - "error": null, - "data": { - "user_id": %v, - "has_verified_email": false - } -}` - -// TestUserRetriever is a helper allowing to test API endpoints that require authentication -// without actually creating DB records. -type TestUserRetriever struct { - WalletID string - Token string -} - -// Retrieve returns WalletID set during TestUserRetriever creation, -// checking it against TestUserRetriever's Token field if one was supplied. -func (r *TestUserRetriever) Retrieve(q Query) (*models.User, error) { - if r.Token == "" || r.Token == q.Token { - return &models.User{WalletID: r.WalletID}, nil - } - return nil, errors.New(GenericRetrievalErr) -} - -func StartDummyAPIServer(response []byte) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - responses.PrepareJSONWriter(w) - w.Write(response) - })) -} - -func StartAuthenticatingAPIServer(userID int) *httptest.Server { - response := fmt.Sprintf(userHasVerifiedEmailResponse, userID) - return StartDummyAPIServer([]byte(response)) -} - -func StartEasyAPIServer() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t := r.PostFormValue("auth_token") - reply := fmt.Sprintf(userHasVerifiedEmailResponse, t) - responses.PrepareJSONWriter(w) - w.Write([]byte(reply)) - })) } diff --git a/app/users/testing_test.go b/app/users/testing_test.go index df9b692d..5ebc199b 100644 --- a/app/users/testing_test.go +++ b/app/users/testing_test.go @@ -5,39 +5,35 @@ import ( "net/http" "testing" + "github.com/lbryio/lbrytv/app/wallet" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTestUserRetrieverGetWalletID(t *testing.T) { - var ( - testAuth *Authenticator - r *http.Request - err error - a string - ) - - testAuth = NewAuthenticator(&TestUserRetriever{WalletID: "123"}) - r, _ = http.NewRequest("GET", "/", nil) - r.Header.Set(TokenHeader, "XyZ") - a, err = testAuth.GetWalletID(r) + testAuth := &Authenticator{Retriever: DummyRetriever("", "123")} + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + r.Header.Set(wallet.TokenHeader, "XyZ") + a, err := testAuth.GetWalletID(r) assert.NoError(t, err) assert.Equal(t, "123", a) r, _ = http.NewRequest("GET", "/", nil) - r.Header.Set(TokenHeader, "aBc") + r.Header.Set(wallet.TokenHeader, "aBc") a, err = testAuth.GetWalletID(r) assert.NoError(t, err) assert.Equal(t, "123", a) - testAuth = NewAuthenticator(&TestUserRetriever{WalletID: "123", Token: "XyZ"}) + testAuth = &Authenticator{Retriever: DummyRetriever("XyZ", "123")} r, _ = http.NewRequest("GET", "/", nil) - r.Header.Set(TokenHeader, "XyZ") + r.Header.Set(wallet.TokenHeader, "XyZ") a, err = testAuth.GetWalletID(r) assert.NoError(t, err) assert.Equal(t, "123", a) r, _ = http.NewRequest("GET", "/", nil) - r.Header.Set(TokenHeader, "aBc") + r.Header.Set(wallet.TokenHeader, "aBc") a, err = testAuth.GetWalletID(r) assert.Equal(t, errors.New(GenericRetrievalErr), err) assert.Equal(t, "", a) diff --git a/app/users/users.go b/app/users/users.go deleted file mode 100644 index 4238237a..00000000 --- a/app/users/users.go +++ /dev/null @@ -1,130 +0,0 @@ -package users - -import ( - "database/sql" - "fmt" - - "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/internal/monitor" - "github.com/lbryio/lbrytv/models" - - "github.com/lib/pq" - xerrors "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/volatiletech/sqlboiler/boil" -) - -// WalletService retrieves user wallet data. -type WalletService struct { - Logger monitor.ModuleLogger - Router *sdkrouter.Router -} - -// TokenHeader is the name of HTTP header which is supplied by client and should contain internal-api auth_token. -const TokenHeader string = "X-Lbry-Auth-Token" -const errUniqueViolation = "23505" - -// Retriever is an interface for user retrieval by internal-apis auth token -type Retriever interface { - Retrieve(query Query) (*models.User, error) -} - -// Query contains queried user details and optional metadata about the request -type Query struct { - Token string - MetaRemoteIP string -} - -// NewWalletService returns WalletService instance for retrieving or creating wallet-based user records and accounts. -func NewWalletService(r *sdkrouter.Router) *WalletService { - return &WalletService{Logger: monitor.NewModuleLogger("users"), Router: r} -} - -func (s *WalletService) createDBUser(id int) (*models.User, error) { - log := s.Logger.LogF(monitor.F{"id": id}) - - u := &models.User{ID: id} - err := u.InsertG(boil.Infer()) - if err != nil { - // Check if we encountered a primary key violation, it would mean another routine - // fired from another request has managed to create a user before us so we should try retrieving it again. - switch baseErr := xerrors.Cause(err).(type) { - case *pq.Error: - if baseErr.Code == errUniqueViolation && baseErr.Column == "users_pkey" { - log.Debug("user creation conflict, trying to retrieve the local user again") - return getDBUser(id) - } - default: - log.Error("unknown error encountered while creating user: ", err) - return nil, err - } - } - return u, nil -} - -// Retrieve gets user by internal-apis auth token provided in the supplied Query. -func (s *WalletService) Retrieve(q Query) (*models.User, error) { - token := q.Token - log := s.Logger.LogF(monitor.F{monitor.TokenF: token}) - - remoteUser, err := getRemoteUser(token, q.MetaRemoteIP) - if err != nil { - msg := "cannot authenticate user with internal-apis: %v" - log.Errorf(msg, err) - return nil, fmt.Errorf(msg, err) - } - if !remoteUser.HasVerifiedEmail { - return nil, nil - } - - log.Data["id"] = remoteUser.ID - log.Data["has_email"] = remoteUser.HasVerifiedEmail - - localUser, err := getDBUser(remoteUser.ID) - if err != nil && err != sql.ErrNoRows { - return nil, err - } else if err == sql.ErrNoRows { - log.Infof("user not found in the database, creating") - localUser, err = s.createDBUser(remoteUser.ID) - if err != nil { - return nil, err - } - } else if localUser.WalletID == "" { - // This scenario may happen for legacy users who are present in the database but don't have a wallet yet - log.Warnf("user %d doesn't have wallet ID set", localUser.ID) - } - - if localUser.WalletID == "" { - err := createWalletForUser(localUser, s.Router, log) - if err != nil { - return nil, err - } - } - - return localUser, nil -} - -func createWalletForUser(user *models.User, router *sdkrouter.Router, log *logrus.Entry) error { - // either a new user or a legacy user without a wallet - walletID, err := router.InitializeWallet(user.ID) - if err != nil { - return err - } - - log.Data["wallet_id"] = walletID - log.Info("saving wallet ID to user record") - - user.WalletID = walletID - - server := router.GetServer(user.ID) - if server.ID > 0 { // Ensure server is from DB - user.LbrynetServerID.SetValid(server.ID) - } - - _, err = user.UpdateG(boil.Infer()) - return err -} - -func getDBUser(id int) (*models.User, error) { - return models.Users(models.UserWhere.ID.EQ(id)).OneG() -} diff --git a/app/users/remote.go b/app/wallet/remote.go similarity index 77% rename from app/users/remote.go rename to app/wallet/remote.go index ce797326..b6a6ccbe 100644 --- a/app/users/remote.go +++ b/app/wallet/remote.go @@ -1,9 +1,8 @@ -package users +package wallet import ( "time" - "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbry.go/v2/extras/lbryinc" @@ -15,9 +14,9 @@ type remoteUser struct { HasVerifiedEmail bool } -func getRemoteUser(token string, remoteIP string) (*remoteUser, error) { +func getRemoteUser(url, token string, remoteIP string) (remoteUser, error) { c := lbryinc.NewClient(token, &lbryinc.ClientOpts{ - ServerAddress: config.GetInternalAPIHost(), + ServerAddress: url, RemoteIP: remoteIP, }) @@ -28,12 +27,12 @@ func getRemoteUser(token string, remoteIP string) (*remoteUser, error) { if err != nil { // No user found in internal-apis database, give up at this point metrics.IAPIAuthFailedDurations.Observe(duration) - return nil, err + return remoteUser{}, err } metrics.IAPIAuthSuccessDurations.Observe(duration) - return &remoteUser{ + return remoteUser{ ID: int(r["user_id"].(float64)), HasVerifiedEmail: r["has_verified_email"].(bool), }, nil diff --git a/app/wallet/wallet.go b/app/wallet/wallet.go new file mode 100644 index 00000000..2103e5c8 --- /dev/null +++ b/app/wallet/wallet.go @@ -0,0 +1,207 @@ +package wallet + +import ( + "database/sql" + "errors" + "fmt" + + ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" + "github.com/lbryio/lbrytv/app/sdkrouter" + "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/internal/lbrynet" + "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/models" + + "github.com/lib/pq" + pkgerrors "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/volatiletech/sqlboiler/boil" +) + +var logger = monitor.NewModuleLogger("wallet") + +func DisableLogger() { logger.Disable() } // for testing + +// TokenHeader is the name of HTTP header which is supplied by client and should contain internal-api auth_token. +const TokenHeader = "X-Lbry-Auth-Token" +const pgUniqueConstraintViolation = "23505" + +// Retrieve gets user by internal-apis auth token. If the user does not have a wallet yet, they +// are assigned an SDK and a wallet is created for them on that SDK. +func GetUserWithWallet(rt *sdkrouter.Router, token, metaRemoteIP string) (*models.User, error) { + log := logger.LogF(monitor.F{monitor.TokenF: token}) + + remoteUser, err := getRemoteUser(config.GetInternalAPIHost(), token, metaRemoteIP) + if err != nil { + msg := "cannot authenticate user with internal-apis: %v" + log.Errorf(msg, err) + return nil, fmt.Errorf(msg, err) + } + if !remoteUser.HasVerifiedEmail { + return nil, nil + } + + log.Data["remote_user_id"] = remoteUser.ID + log.Data["has_email"] = remoteUser.HasVerifiedEmail + + localUser, err := getOrCreateLocalUser(remoteUser.ID, log) + if err != nil { + return nil, err + } + + if localUser.WalletID == "" { + log := logger.LogF(monitor.F{monitor.TokenF: token}) + err := assignSDKServerToUser(localUser, rt, log) + if err != nil { + return nil, err + } + } + + return localUser, nil +} + +func getOrCreateLocalUser(remoteUserID int, log *logrus.Entry) (*models.User, error) { + localUser, err := getDBUser(remoteUserID) + if err != nil && err != sql.ErrNoRows { + return nil, err + } else if err == sql.ErrNoRows { + log.Infof("user not found in the database, creating") + localUser, err = createDBUser(remoteUserID) + if err != nil { + return nil, err + } + } else if localUser.WalletID == "" { + // This scenario may happen for legacy users who are present in the database but don't have a wallet yet + log.Warnf("user %d doesn't have wallet ID set", localUser.ID) + } + + return localUser, nil +} + +func assignSDKServerToUser(user *models.User, router *sdkrouter.Router, log *logrus.Entry) error { + server := router.LeastLoaded() + if server.ID > 0 { // Ensure server is from DB + user.LbrynetServerID.SetValid(server.ID) + } else { + // THIS SERVER CAME FROM A CONFIG FILE (prolly for testing) + // TODO: handle this case better + //return fmt.Errorf("user %d is getting a wallet server with no ID", user.ID) + } + + walletID, err := Create(server.Address, user.ID) + if err != nil { + return err + } + + log.Infof("assigning sdk %s to user %d", server.Address, user.ID) + user.WalletID = walletID + _, err = user.UpdateG(boil.Infer()) + return err +} + +func createDBUser(id int) (*models.User, error) { + log := logger.LogF(monitor.F{"id": id}) + + u := &models.User{ID: id} + err := u.InsertG(boil.Infer()) + if err == nil { + return u, nil + } + + // Check if we encountered a primary key violation, it would mean another routine + // fired from another request has managed to create a user before us so we should try retrieving it again. + switch baseErr := pkgerrors.Cause(err).(type) { + case *pq.Error: + if baseErr.Code == pgUniqueConstraintViolation && baseErr.Column == "users_pkey" { + log.Debug("user creation conflict, trying to retrieve the local user again") + return getDBUser(id) + } + } + + log.Error("unknown error encountered while creating user: ", err) + return nil, err +} + +func getDBUser(id int) (*models.User, error) { + return models.Users(models.UserWhere.ID.EQ(id)).OneG() +} + +// Create creates a wallet on an sdk that can be immediately used in subsequent commands. +// It can recover from errors like existing wallets, but if a wallet is known to exist +// (eg. a wallet ID stored in the database already), loadWallet() should be called instead. +func Create(serverAddress string, userID int) (string, error) { + wallet, err := createWallet(serverAddress, userID) + if err == nil { + return wallet.ID, nil + } + + walletID := sdkrouter.WalletID(userID) + log := logger.LogF(monitor.F{"user_id": userID, "sdk": serverAddress}) + + if errors.Is(err, lbrynet.ErrWalletExists) { + log.Warn(err.Error()) + return walletID, nil + } + + if errors.Is(err, lbrynet.ErrWalletNeedsLoading) { + log.Info(err.Error()) + wallet, err = loadWallet(serverAddress, userID) + if err != nil { + if errors.Is(err, lbrynet.ErrWalletAlreadyLoaded) { + log.Info(err.Error()) + return walletID, nil + } + return "", err + } + return wallet.ID, nil + } + + log.Errorf("don't know how to recover from error: %v", err) + return "", err +} + +// createWallet creates a new wallet on the LbrynetServer. +// Returned error doesn't necessarily mean that the wallet is not operational: +// +// if errors.Is(err, lbrynet.WalletExists) { +// // Okay to proceed with the account +// } +// +// if errors.Is(err, lbrynet.WalletNeedsLoading) { +// // loadWallet() needs to be called before the wallet can be used +// } +func createWallet(addr string, userID int) (*ljsonrpc.Wallet, error) { + wallet, err := ljsonrpc.NewClient(addr).WalletCreate(sdkrouter.WalletID(userID), &ljsonrpc.WalletCreateOpts{ + SkipOnStartup: true, CreateAccount: true, SingleKey: true}) + if err != nil { + return nil, lbrynet.NewWalletError(userID, err) + } + logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet created") + return wallet, nil +} + +// loadWallet loads an existing wallet in the LbrynetServer. +// May return errors: +// WalletAlreadyLoaded - wallet is already loaded and operational +// WalletNotFound - wallet file does not exist and won't be loaded. +func loadWallet(addr string, userID int) (*ljsonrpc.Wallet, error) { + wallet, err := ljsonrpc.NewClient(addr).WalletAdd(sdkrouter.WalletID(userID)) + if err != nil { + return nil, lbrynet.NewWalletError(userID, err) + } + logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet loaded") + return wallet, nil +} + +// UnloadWallet unloads an existing wallet from the LbrynetServer. +// May return errors: +// WalletAlreadyLoaded - wallet is already loaded and operational +// WalletNotFound - wallet file does not exist and won't be loaded. +func UnloadWallet(addr string, userID int) error { + _, err := ljsonrpc.NewClient(addr).WalletRemove(sdkrouter.WalletID(userID)) + if err != nil { + return lbrynet.NewWalletError(userID, err) + } + logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet unloaded") + return nil +} diff --git a/app/users/users_test.go b/app/wallet/wallet_test.go similarity index 50% rename from app/users/users_test.go rename to app/wallet/wallet_test.go index bdbf27af..ca96cecc 100644 --- a/app/users/users_test.go +++ b/app/wallet/wallet_test.go @@ -1,4 +1,4 @@ -package users +package wallet import ( "errors" @@ -11,9 +11,14 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/internal/lbrynet" + "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/storage" + "github.com/lbryio/lbrytv/internal/test" "github.com/lbryio/lbrytv/models" - log "github.com/sirupsen/logrus" + + jsonrpc2 "github.com/lbryio/lbry.go/v2/extras/jsonrpc" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" @@ -33,7 +38,6 @@ func TestMain(m *testing.M) { defer connCleanup() code := m.Run() - os.Exit(code) } @@ -41,21 +45,30 @@ func setupDBTables() { storage.Conn.Truncate([]string{"users"}) } -func setupCleanupDummyUser(rt *sdkrouter.Router, uidParam ...int) func() { - var uid int - if len(uidParam) > 0 { - uid = uidParam[0] - } else { - uid = dummyUserID - } +func setupCleanupDummyUser(rt *sdkrouter.Router) func() { + reqChan := test.ReqChan() + ts := test.MockHTTPServer(reqChan) + go func() { + for { + req := <-reqChan + responses.AddJSONContentType(req.W) + ts.NextResponse <- fmt.Sprintf(`{ + "success": true, + "error": null, + "data": { + "user_id": %v, + "has_verified_email": true + } + }`, dummyUserID) + } + }() - ts := StartAuthenticatingAPIServer(uid) config.Override("InternalAPIHost", ts.URL) return func() { ts.Close() config.RestoreOverridden() - rt.UnloadWallet(uid) + UnloadWallet(rt.GetServer(dummyUserID).Address, dummyUserID) } } @@ -65,8 +78,7 @@ func TestWalletServiceRetrieveNewUser(t *testing.T) { defer setupCleanupDummyUser(rt)() wid := sdkrouter.WalletID(dummyUserID) - svc := NewWalletService(rt) - u, err := svc.Retrieve(Query{Token: "abc"}) + u, err := GetUserWithWallet(rt, "abc", "") require.NoError(t, err, errors.Unwrap(err)) require.NotNil(t, u) require.Equal(t, wid, u.WalletID) @@ -75,7 +87,7 @@ func TestWalletServiceRetrieveNewUser(t *testing.T) { require.NoError(t, err) assert.EqualValues(t, 1, count) - u, err = svc.Retrieve(Query{Token: "abc"}) + u, err = GetUserWithWallet(rt, "abc", "") require.NoError(t, err, errors.Unwrap(err)) require.Equal(t, wid, u.WalletID) } @@ -83,17 +95,19 @@ func TestWalletServiceRetrieveNewUser(t *testing.T) { func TestWalletServiceRetrieveNonexistentUser(t *testing.T) { setupDBTables() - ts := StartDummyAPIServer([]byte(`{ + ts := test.MockHTTPServer(nil) + defer ts.Close() + ts.NextResponse <- `{ "success": false, "error": "could not authenticate user", "data": null - }`)) - defer ts.Close() + }` + config.Override("InternalAPIHost", ts.URL) defer config.RestoreOverridden() - svc := NewWalletService(sdkrouter.New(config.GetLbrynetServers())) - u, err := svc.Retrieve(Query{Token: "non-existent-token"}) + rt := sdkrouter.New(config.GetLbrynetServers()) + u, err := GetUserWithWallet(rt, "non-existent-token", "") require.Error(t, err) require.Nil(t, u) assert.Equal(t, "cannot authenticate user with internal-apis: could not authenticate user", err.Error()) @@ -104,12 +118,11 @@ func TestWalletServiceRetrieveExistingUser(t *testing.T) { setupDBTables() defer setupCleanupDummyUser(rt)() - s := NewWalletService(rt) - u, err := s.Retrieve(Query{Token: "abc"}) + u, err := GetUserWithWallet(rt, "abc", "") require.NoError(t, err) require.NotNil(t, u) - u, err = s.Retrieve(Query{Token: "abc"}) + u, err = GetUserWithWallet(rt, "abc", "") require.NoError(t, err) assert.EqualValues(t, dummyUserID, u.ID) @@ -121,18 +134,33 @@ func TestWalletServiceRetrieveExistingUser(t *testing.T) { func TestWalletServiceRetrieveExistingUserMissingWalletID(t *testing.T) { setupDBTables() - uid := int(rand.Int31()) - ts := StartAuthenticatingAPIServer(uid) + userID := int(rand.Int31()) + + reqChan := test.ReqChan() + ts := test.MockHTTPServer(reqChan) defer ts.Close() + go func() { + req := <-reqChan + responses.AddJSONContentType(req.W) + ts.NextResponse <- fmt.Sprintf(`{ + "success": true, + "error": null, + "data": { + "user_id": %v, + "has_verified_email": true + } + }`, userID) + }() + config.Override("InternalAPIHost", ts.URL) defer config.RestoreOverridden() - s := NewWalletService(sdkrouter.New(config.GetLbrynetServers())) - u, err := s.createDBUser(uid) + rt := sdkrouter.New(config.GetLbrynetServers()) + u, err := createDBUser(userID) require.NoError(t, err) require.NotNil(t, u) - u, err = s.Retrieve(Query{Token: "abc"}) + u, err = GetUserWithWallet(rt, "abc", "") require.NoError(t, err) assert.NotEqual(t, "", u.WalletID) } @@ -140,40 +168,62 @@ func TestWalletServiceRetrieveExistingUserMissingWalletID(t *testing.T) { func TestWalletServiceRetrieveNoVerifiedEmail(t *testing.T) { setupDBTables() - ts := StartDummyAPIServer([]byte(fmt.Sprintf(userDoesntHaveVerifiedEmailResponse, 111))) + ts := test.MockHTTPServer(nil) defer ts.Close() + ts.NextResponse <- `{ + "success": true, + "error": null, + "data": { + "user_id": 111, + "has_verified_email": false + } + }` + config.Override("InternalAPIHost", ts.URL) defer config.RestoreOverridden() - svc := NewWalletService(sdkrouter.New(config.GetLbrynetServers())) - u, err := svc.Retrieve(Query{Token: "abc"}) - assert.Nil(t, u) + rt := sdkrouter.New(config.GetLbrynetServers()) + u, err := GetUserWithWallet(rt, "abc", "") assert.NoError(t, err) + assert.Nil(t, u) } func BenchmarkWalletCommands(b *testing.B) { setupDBTables() - ts := StartEasyAPIServer() + reqChan := test.ReqChan() + ts := test.MockHTTPServer(reqChan) defer ts.Close() + go func() { + req := <-reqChan + responses.AddJSONContentType(req.W) + ts.NextResponse <- fmt.Sprintf(`{ + "success": true, + "error": null, + "data": { + "user_id": %v, + "has_verified_email": true + } + }`, req.R.PostFormValue("auth_token")) + }() + config.Override("InternalAPIHost", ts.URL) defer config.RestoreOverridden() walletsNum := 60 users := make([]*models.User, walletsNum) - svc := NewWalletService(sdkrouter.New(config.GetLbrynetServers())) - sdkRouter := sdkrouter.New(config.GetLbrynetServers()) - cl := jsonrpc.NewClient(sdkRouter.RandomServer().Address) + rt := sdkrouter.New(config.GetLbrynetServers()) + cl := jsonrpc.NewClient(rt.RandomServer().Address) - svc.Logger.Disable() + logger.Disable() sdkrouter.DisableLogger() - log.SetOutput(ioutil.Discard) + logrus.SetOutput(ioutil.Discard) rand.Seed(time.Now().UnixNano()) for i := 0; i < walletsNum; i++ { uid := int(rand.Int31()) - u, err := svc.Retrieve(Query{Token: fmt.Sprintf("%v", uid)}) + u, err := GetUserWithWallet(rt, fmt.Sprintf("%v", uid), "") require.NoError(b, err, errors.Unwrap(err)) require.NotNil(b, u) users[i] = u @@ -193,3 +243,42 @@ func BenchmarkWalletCommands(b *testing.B) { b.StopTimer() } + +func TestInitializeWallet(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + userID := rand.Int() + addr := test.RandServerAddress(t) + + walletID, err := Create(addr, userID) + require.NoError(t, err) + assert.Equal(t, walletID, sdkrouter.WalletID(userID)) + + err = UnloadWallet(addr, userID) + require.NoError(t, err) + + walletID, err = Create(addr, userID) + require.NoError(t, err) + assert.Equal(t, walletID, sdkrouter.WalletID(userID)) +} + +func TestCreateWalletLoadWallet(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + userID := rand.Int() + addr := test.RandServerAddress(t) + client := jsonrpc2.NewClient(addr) + + wallet, err := createWallet(client, userID) + require.NoError(t, err) + assert.Equal(t, wallet.ID, sdkrouter.WalletID(userID)) + + wallet, err = createWallet(client, userID) + require.NotNil(t, err) + assert.True(t, errors.Is(err, lbrynet.ErrWalletExists)) + + err = UnloadWallet(addr, userID) + require.NoError(t, err) + + wallet, err = loadWallet(client, userID) + require.NoError(t, err) + assert.Equal(t, wallet.ID, sdkrouter.WalletID(userID)) +} diff --git a/internal/responses/responses.go b/internal/responses/responses.go index 38dbf0aa..218f7ada 100644 --- a/internal/responses/responses.go +++ b/internal/responses/responses.go @@ -1,32 +1,10 @@ package responses import ( - "encoding/json" "net/http" - - "github.com/ybbus/jsonrpc" ) -// PrepareJSONWriter prepares HTTP response writer for JSON content-type. -func PrepareJSONWriter(w http.ResponseWriter) { +// AddJSONContentType prepares HTTP response writer for JSON content-type. +func AddJSONContentType(w http.ResponseWriter) { w.Header().Add("content-type", "application/json; charset=utf-8") } - -// JSON is a shorthand for serializing provided structure and writing it into the provided HTTP writer as JSON. -func JSON(w http.ResponseWriter, v interface{}) { - r, _ := json.Marshal(v) - PrepareJSONWriter(w) - w.Write(r) -} - -// JSONRPCError is a shorthand for creating an RPCResponse instance with specified error message and code. -func JSONRPCError(w http.ResponseWriter, message string, code int) { - JSON(w, NewJSONRPCError(message, code)) -} - -func NewJSONRPCError(message string, code int) *jsonrpc.RPCResponse { - return &jsonrpc.RPCResponse{JSONRPC: "2.0", Error: &jsonrpc.RPCError{ - Code: code, - Message: message, - }} -} diff --git a/internal/responses/responses_test.go b/internal/responses/responses_test.go deleted file mode 100644 index 6ca7b0ee..00000000 --- a/internal/responses/responses_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package responses - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/ybbus/jsonrpc" -) - -func TestJSON(t *testing.T) { - rr := httptest.NewRecorder() - JSON(rr, map[string]int{"error_code": 625}) - assert.Equal(t, `{"error_code":625}`, rr.Body.String()) - assert.Equal(t, "application/json; charset=utf-8", rr.Header().Get("content-type")) - assert.Equal(t, http.StatusOK, rr.Result().StatusCode) -} - -func TestJSONRPCError(t *testing.T) { - var jResp jsonrpc.RPCResponse - rr := httptest.NewRecorder() - JSONRPCError(rr, "invalid input", 12345) - - err := json.Unmarshal(rr.Body.Bytes(), &jResp) - require.NoError(t, err) - - assert.Equal(t, "invalid input", jResp.Error.Message) - assert.Equal(t, 12345, jResp.Error.Code) - assert.Equal(t, "application/json; charset=utf-8", rr.Header().Get("content-type")) - assert.Equal(t, http.StatusOK, rr.Result().StatusCode) -} diff --git a/internal/test/test.go b/internal/test/test.go index 45acb16d..88b2276b 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "testing" + "github.com/lbryio/lbrytv/config" "github.com/ybbus/jsonrpc" ) @@ -18,6 +19,7 @@ type MockServer struct { type Request struct { R *http.Request + W http.ResponseWriter Body string } @@ -33,7 +35,7 @@ func MockHTTPServer(requestChan chan *Request) *MockServer { defer r.Body.Close() if requestChan != nil { data, _ := ioutil.ReadAll(r.Body) - requestChan <- &Request{r, string(data)} + requestChan <- &Request{r, w, string(data)} } fmt.Fprintf(w, <-next) })), @@ -73,3 +75,11 @@ func ResToStr(t *testing.T, res jsonrpc.RPCResponse) string { } return string(r) } + +func RandServerAddress(t *testing.T) string { + for _, addr := range config.GetLbrynetServers() { + return addr + } + t.Fatal("no lbrynet servers configured") + return "" +} diff --git a/internal/test/test_test.go b/internal/test/test_test.go index 78d35cf6..f8e441e2 100644 --- a/internal/test/test_test.go +++ b/internal/test/test_test.go @@ -1,24 +1,25 @@ package test import ( + "bytes" + "io/ioutil" "net/http" "testing" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMockRPCServer(t *testing.T) { reqChan := ReqChan() rpcServer := MockHTTPServer(reqChan) defer rpcServer.Close() - rpcServer.NextResponse <- `{"result": {"items": [], "page": 1, "page_size": 2, "total_pages": 3}}` + rpcServer.NextResponse <- `{"result": {"items": [], "page": 1, "page_size": 2, "total_pages": 3}}` res, err := ljsonrpc.NewClient(rpcServer.URL).WalletList("", 1, 2) - if err != nil { - t.Error(err) - } + require.NoError(t, err) req := <-reqChan assert.Equal(t, req.R.Method, http.MethodPost) @@ -27,4 +28,18 @@ func TestMockRPCServer(t *testing.T) { assert.Equal(t, res.Page, uint64(1)) assert.Equal(t, res.PageSize, uint64(2)) assert.Equal(t, res.TotalPages, uint64(3)) + + rpcServer.NextResponse <- `ok` + c := &http.Client{} + r, err := http.NewRequest(http.MethodPost, rpcServer.URL, bytes.NewBuffer([]byte("hello"))) + require.NoError(t, err) + res2, err := c.Do(r) + require.NoError(t, err) + + req2 := <-reqChan + assert.Equal(t, req2.R.Method, http.MethodPost) + assert.Equal(t, req2.Body, `hello`) + body, err := ioutil.ReadAll(res2.Body) + require.NoError(t, err) + assert.Equal(t, string(body), "ok") } From 8cefa94c612b3166d849806c079abf09fee8b651 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Tue, 14 Apr 2020 17:13:50 -0400 Subject: [PATCH 06/18] tests pass now --- api/benchmarks_test.go | 36 +++--- api/routes.go | 30 ++--- api/routes_test.go | 16 ++- app/auth/auth.go | 65 +++++++++++ app/auth/auth_test.go | 117 +++++++++++++++++++ app/proxy/accounts_test.go | 100 ++++++++-------- app/proxy/caller.go | 70 +++++------ app/proxy/caller_test.go | 131 +++++++++------------ app/proxy/client_test.go | 18 +-- app/proxy/errors.go | 4 +- app/proxy/handlers.go | 56 ++++----- app/proxy/handlers_test.go | 39 ++++--- app/proxy/main_test.go | 34 +----- app/proxy/query.go | 50 +++----- app/proxy/query_test.go | 11 +- app/proxy/service.go | 45 ------- app/publish/handler_test.go | 117 ++++++++++++------- app/publish/publish.go | 136 +++++++++------------- app/publish/publish_test.go | 10 +- app/sdkrouter/middleware.go | 30 +++++ app/sdkrouter/sdkrouter.go | 2 +- app/users/authenticator.go | 80 ------------- app/users/authenticator_test.go | 79 ------------- app/users/testing.go | 16 --- app/users/testing_test.go | 46 -------- app/wallet/wallet.go | 5 +- app/wallet/wallet_test.go | 49 +++----- cmd/serve.go | 6 +- go.mod | 1 + internal/environment/environment.go | 26 ----- app/users/helpers.go => internal/ip/ip.go | 16 +-- internal/metrics/routes_test.go | 6 +- internal/monitor/monitor.go | 8 +- internal/monitor/monitor_test.go | 8 +- internal/monitor/sentry.go | 3 +- internal/responses/responses.go | 2 + internal/status/status.go | 3 +- server/server.go | 85 ++++++-------- server/server_test.go | 11 +- 39 files changed, 688 insertions(+), 879 deletions(-) create mode 100644 app/auth/auth.go create mode 100644 app/auth/auth_test.go delete mode 100644 app/proxy/service.go create mode 100644 app/sdkrouter/middleware.go delete mode 100644 app/users/authenticator.go delete mode 100644 app/users/authenticator_test.go delete mode 100644 app/users/testing.go delete mode 100644 app/users/testing_test.go delete mode 100644 internal/environment/environment.go rename app/users/helpers.go => internal/ip/ip.go (89%) diff --git a/api/benchmarks_test.go b/api/benchmarks_test.go index c5bb75d8..43409f05 100644 --- a/api/benchmarks_test.go +++ b/api/benchmarks_test.go @@ -39,7 +39,7 @@ func launchAuthenticatingAPIServer() *httptest.Server { "success": true, "error": null, "data": { - "id": %v, + "id": %s, "language": "en", "given_name": null, "family_name": null, @@ -86,30 +86,29 @@ func setupDBTables() { func BenchmarkWalletCommands(b *testing.B) { setupDBTables() - ts := launchAuthenticatingAPIServer() - defer ts.Close() - config.Override("InternalAPIHost", ts.URL) - defer config.RestoreOverridden() - - walletsNum := 30 - wallets := make([]*models.User, walletsNum) - rt := sdkrouter.New(config.GetLbrynetServers()) - wallet.DisableLogger() sdkrouter.DisableLogger() log.SetOutput(ioutil.Discard) rand.Seed(time.Now().UnixNano()) + rt := sdkrouter.New(config.GetLbrynetServers()) + + ts := launchAuthenticatingAPIServer() + defer ts.Close() + + walletsNum := 30 + wallets := make([]*models.User, walletsNum) + for i := 0; i < walletsNum; i++ { uid := int(rand.Int31()) - u, err := wallet.GetUserWithWallet(rt, fmt.Sprintf("%v", uid), "") + u, err := wallet.GetUserWithWallet(rt, ts.URL, fmt.Sprintf("%d", uid), "") require.NoError(b, err, errors.Unwrap(err)) require.NotNil(b, u) wallets[i] = u } - handler := proxy.NewRequestHandler(proxy.NewService(rt)) + handler := sdkrouter.Middleware(rt)(http.HandlerFunc(proxy.Handle)) b.SetParallelism(30) b.ResetTimer() @@ -118,17 +117,20 @@ func BenchmarkWalletCommands(b *testing.B) { for pb.Next() { u := wallets[rand.Intn(len(wallets))] - var response jsonrpc.RPCResponse q := jsonrpc.NewRequest("wallet_balance", map[string]string{"wallet_id": u.WalletID}) - qBody, _ := json.Marshal(q) - r, _ := http.NewRequest("POST", proxySuffix, bytes.NewBuffer(qBody)) - r.Header.Add("X-Lbry-Auth-Token", fmt.Sprintf("%v", u.ID)) + qBody, err := json.Marshal(q) + require.NoError(b, err) + r, err := http.NewRequest("POST", proxySuffix, bytes.NewBuffer(qBody)) + require.NoError(b, err) + r.Header.Add("X-Lbry-Auth-Token", fmt.Sprintf("%d", u.ID)) rr := httptest.NewRecorder() - handler.Handle(rr, r) + handler.ServeHTTP(rr, r) require.Equal(b, http.StatusOK, rr.Code) + + var response jsonrpc.RPCResponse json.Unmarshal(rr.Body.Bytes(), &response) require.Nil(b, response.Error) } diff --git a/api/routes.go b/api/routes.go index e6f3316e..338c9c9d 100644 --- a/api/routes.go +++ b/api/routes.go @@ -1,15 +1,14 @@ package api import ( - "context" "net/http" "strings" "time" + "github.com/lbryio/lbrytv/app/auth" "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/app/publish" "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/app/users" "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/status" @@ -19,12 +18,11 @@ import ( ) // InstallRoutes sets up global API handlers -func InstallRoutes(proxyService *proxy.Service, r *mux.Router) { - authenticator := users.NewAuthenticator(users.NewWalletService(proxyService.SDKRouter)) - proxyHandler := proxy.NewRequestHandler(proxyService) - upHandler, err := publish.NewUploadHandler(publish.UploadOpts{ProxyService: proxyService}) - if err != nil { - panic(err) +func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { + upHandler := &publish.Handler{ + Publisher: &publish.LbrynetPublisher{Router: sdkRouter}, + UploadPath: config.GetPublishSourceDir(), + InternalAPIHost: config.GetInternalAPIHost(), } r.Use(methodTimer) @@ -34,24 +32,20 @@ func InstallRoutes(proxyService *proxy.Service, r *mux.Router) { }) v1Router := r.PathPrefix("/api/v1").Subrouter() + v1Router.Use(sdkrouter.Middleware(sdkRouter)) + retriever := auth.AllInOneRetrieverThatNeedsRefactoring(sdkRouter, config.GetInternalAPIHost()) + v1Router.Use(auth.Middleware(retriever)) v1Router.HandleFunc("/proxy", proxy.HandleCORS).Methods(http.MethodOptions) - v1Router.HandleFunc("/proxy", authenticator.Wrap(upHandler.Handle)).MatcherFunc(upHandler.CanHandle) - v1Router.HandleFunc("/proxy", proxyHandler.Handle) + v1Router.HandleFunc("/proxy", upHandler.Handle).MatcherFunc(upHandler.CanHandle) + v1Router.HandleFunc("/proxy", proxy.Handle) v1Router.HandleFunc("/metric/ui", metrics.TrackUIMetric).Methods(http.MethodPost) internalRouter := r.PathPrefix("/internal").Subrouter() internalRouter.Handle("/metrics", promhttp.Handler()) - internalRouter.HandleFunc("/status", injectSDKRouter(proxyService.SDKRouter, status.GetStatus)) + internalRouter.HandleFunc("/status", sdkrouter.AddToRequest(sdkRouter, status.GetStatus)) internalRouter.HandleFunc("/whoami", status.WhoAMI) } -// i can't tell if this is really a best practice or a hack -func injectSDKRouter(rt *sdkrouter.Router, fn http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - fn(w, r.Clone(context.WithValue(r.Context(), status.SDKRouterContextKey, rt))) - } -} - func methodTimer(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/api/routes_test.go b/api/routes_test.go index 3318fe2b..8f52e7d7 100644 --- a/api/routes_test.go +++ b/api/routes_test.go @@ -6,25 +6,23 @@ import ( "net/http/httptest" "testing" - "github.com/lbryio/lbrytv/app/proxy" + "github.com/gorilla/mux" "github.com/lbryio/lbrytv/app/publish" "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" - - "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRoutesProxy(t *testing.T) { r := mux.NewRouter() - proxy := proxy.NewService(sdkrouter.New(config.GetLbrynetServers())) + rt := sdkrouter.New(config.GetLbrynetServers()) req, err := http.NewRequest("POST", "/api/v1/proxy", bytes.NewBuffer([]byte(`{"method": "status"}`))) require.NoError(t, err) rr := httptest.NewRecorder() - InstallRoutes(proxy, r) + InstallRoutes(r, rt) r.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) @@ -33,12 +31,12 @@ func TestRoutesProxy(t *testing.T) { func TestRoutesPublish(t *testing.T) { r := mux.NewRouter() - proxy := proxy.NewService(sdkrouter.New(config.GetLbrynetServers())) + rt := sdkrouter.New(config.GetLbrynetServers()) req := publish.CreatePublishRequest(t, []byte("test file")) rr := httptest.NewRecorder() - InstallRoutes(proxy, r) + InstallRoutes(r, rt) r.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) @@ -49,13 +47,13 @@ func TestRoutesPublish(t *testing.T) { func TestRoutesOptions(t *testing.T) { r := mux.NewRouter() - proxy := proxy.NewService(sdkrouter.New(config.GetLbrynetServers())) + rt := sdkrouter.New(config.GetLbrynetServers()) req, err := http.NewRequest("OPTIONS", "/api/v1/proxy", nil) require.NoError(t, err) rr := httptest.NewRecorder() - InstallRoutes(proxy, r) + InstallRoutes(r, rt) r.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) assert.Equal(t, "7200", rr.Result().Header.Get("Access-Control-Max-Age")) diff --git a/app/auth/auth.go b/app/auth/auth.go new file mode 100644 index 00000000..955deef2 --- /dev/null +++ b/app/auth/auth.go @@ -0,0 +1,65 @@ +package auth + +import ( + "context" + "net/http" + + "github.com/lbryio/lbrytv/app/sdkrouter" + "github.com/lbryio/lbrytv/app/wallet" + "github.com/lbryio/lbrytv/internal/ip" + "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/models" + + "github.com/gorilla/mux" +) + +var logger = monitor.NewModuleLogger("auth") + +const ContextKey = "user" + +type Result struct { + User *models.User + Err error +} + +func (r *Result) AuthAttempted() bool { return r.User != nil || r.Err != nil } +func (r *Result) AuthFailed() bool { return r.Err != nil } +func (r *Result) Authenticated() bool { return r.User != nil } + +func FromRequest(r *http.Request) *Result { + v := r.Context().Value(ContextKey) + if v == nil { + panic("Auth middleware was not applied") + } + return v.(*Result) +} + +// Retriever gets a user by hitting internal-api with the provided auth token +// and matching the response to a local user. +// NOTE: The retrieved user must come with a wallet that's created and ready to use. +type Retriever func(token, metaRemoteIP string) (*models.User, error) + +func AllInOneRetrieverThatNeedsRefactoring(rt *sdkrouter.Router, internalAPIHost string) Retriever { + return func(token, metaRemoteIP string) (user *models.User, err error) { + return wallet.GetUserWithWallet(rt, internalAPIHost, token, metaRemoteIP) + } +} + +func Middleware(retriever Retriever) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ar := &Result{} + if token, ok := r.Header[wallet.TokenHeader]; ok { + addr := ip.AddressForRequest(r) + user, err := retriever(token[0], addr) + if err != nil { + logger.LogF(monitor.F{"ip": addr}).Debugf("failed to authenticate user") + ar.Err = err + } else { + ar.User = user + } + } + next.ServeHTTP(w, r.Clone(context.WithValue(r.Context(), ContextKey, ar))) + }) + } +} diff --git a/app/auth/auth_test.go b/app/auth/auth_test.go new file mode 100644 index 00000000..c8da91e6 --- /dev/null +++ b/app/auth/auth_test.go @@ -0,0 +1,117 @@ +package auth + +import ( + "bytes" + "context" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/lbryio/lbrytv/app/wallet" + "github.com/lbryio/lbrytv/models" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMiddleware(t *testing.T) { + r, err := http.NewRequest("GET", "/api/proxy", nil) + require.NoError(t, err) + r.Header.Set(wallet.TokenHeader, "secret-token") + r.Header.Set("X-Forwarded-For", "8.8.8.8") + + var receivedRemoteIP string + retriever := func(token, ip string) (*models.User, error) { + receivedRemoteIP = ip + if token == "secret-token" { + return &models.User{ID: 16595}, nil + } + return nil, errors.New("error") + } + + rr := httptest.NewRecorder() + Middleware(retriever)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) + + response := rr.Result() + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + assert.Equal(t, "16595", string(body)) + assert.Equal(t, "8.8.8.8", receivedRemoteIP) +} + +func TestMiddlewareAuthFailure(t *testing.T) { + r, err := http.NewRequest("GET", "/api/proxy", nil) + require.NoError(t, err) + r.Header.Set(wallet.TokenHeader, "wrong-token") + rr := httptest.NewRecorder() + + retriever := func(token, ip string) (*models.User, error) { + if token == "good-token" { + return &models.User{ID: 1}, nil + } + return nil, errors.New("incorrect token") + } + Middleware(retriever)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) + + response := rr.Result() + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + assert.Equal(t, "incorrect token", string(body)) + assert.Equal(t, http.StatusForbidden, response.StatusCode) +} + +func TestMiddlewareNoAuth(t *testing.T) { + r, err := http.NewRequest("GET", "/api/proxy", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + + retriever := func(token, ip string) (*models.User, error) { + if token == "good-token" { + return &models.User{ID: 1}, nil + } + return nil, errors.New("incorrect token") + } + Middleware(retriever)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) + + response := rr.Result() + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, response.StatusCode) + assert.Equal(t, "no auth info", string(body)) +} + +func TestFromRequestSuccess(t *testing.T) { + expected := &Result{Err: errors.New("a test")} + ctx := context.WithValue(context.Background(), ContextKey, expected) + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, "", &bytes.Buffer{}) + require.NoError(t, err) + + results := FromRequest(r) + assert.NotNil(t, results) + assert.Equal(t, expected.User, results.User) + assert.Equal(t, expected.Err.Error(), results.Err.Error()) +} + +func TestFromRequestFail(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "", &bytes.Buffer{}) + require.NoError(t, err) + assert.Panics(t, func() { FromRequest(r) }) +} + +func authChecker(w http.ResponseWriter, r *http.Request) { + result := FromRequest(r) + if result.Authenticated() { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte(fmt.Sprintf("%d", result.User.ID))) + } else if result.AuthFailed() { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(result.Err.Error())) + } else { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("no auth info")) + } +} diff --git a/app/proxy/accounts_test.go b/app/proxy/accounts_test.go index ee60b6a3..fa1efc53 100644 --- a/app/proxy/accounts_test.go +++ b/app/proxy/accounts_test.go @@ -7,7 +7,12 @@ import ( "net/http/httptest" "testing" + "github.com/lbryio/lbrytv/app/auth" + "github.com/lbryio/lbrytv/app/sdkrouter" + "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/internal/test" + "github.com/lbryio/lbrytv/models" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" @@ -18,62 +23,60 @@ import ( func TestWithWrongAuthToken(t *testing.T) { testFuncSetup() - defer testFuncTeardown() - var ( - q *jsonrpc.RPCRequest - qBody []byte - response jsonrpc.RPCResponse - ) - - ts := launchDummyAPIServer([]byte(`{ + ts := test.MockHTTPServer(nil) + defer ts.Close() + ts.NextResponse <- `{ "success": false, "error": "could not authenticate user", "data": null - }`)) - defer ts.Close() - config.Override("InternalAPIHost", ts.URL) - defer config.RestoreOverridden() + }` - q = jsonrpc.NewRequest("account_list") - qBody, _ = json.Marshal(q) - r, _ := http.NewRequest("POST", proxySuffix, bytes.NewBuffer(qBody)) + q := jsonrpc.NewRequest("account_list") + qBody, err := json.Marshal(q) + require.NoError(t, err) + r, err := http.NewRequest("POST", "/api/v1/proxy", bytes.NewBuffer(qBody)) + require.NoError(t, err) r.Header.Add("X-Lbry-Auth-Token", "xXxXxXx") rr := httptest.NewRecorder() - handler := NewRequestHandler(svc) - handler.Handle(rr, r) + + rt := sdkrouter.New(config.GetLbrynetServers()) + retriever := func(token, ip string) (*models.User, error) { + return wallet.GetUserWithWallet(rt, ts.URL, token, "") + } + + handler := sdkrouter.Middleware(rt)(auth.Middleware(retriever)(http.HandlerFunc(Handle))) + handler.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) - err := json.Unmarshal(rr.Body.Bytes(), &response) + var response jsonrpc.RPCResponse + err = json.Unmarshal(rr.Body.Bytes(), &response) require.NoError(t, err) assert.Equal(t, "cannot authenticate user with internal-apis: could not authenticate user", response.Error.Message) } func TestWithoutToken(t *testing.T) { testFuncSetup() - defer testFuncTeardown() - - var ( - q *jsonrpc.RPCRequest - qBody []byte - response jsonrpc.RPCResponse - statusResponse ljsonrpc.StatusResponse - ) - q = jsonrpc.NewRequest("status") - qBody, _ = json.Marshal(q) - r, _ := http.NewRequest("POST", proxySuffix, bytes.NewBuffer(qBody)) + q, err := json.Marshal(jsonrpc.NewRequest("status")) + require.NoError(t, err) + r, err := http.NewRequest("POST", "/api/v1/proxy", bytes.NewBuffer(q)) + require.NoError(t, err) rr := httptest.NewRecorder() - handler := NewRequestHandler(svc) - handler.Handle(rr, r) - require.Equal(t, http.StatusOK, rr.Code) - err := json.Unmarshal(rr.Body.Bytes(), &response) + rt := sdkrouter.New(config.GetLbrynetServers()) + handler := sdkrouter.Middleware(rt)(http.HandlerFunc(Handle)) + handler.ServeHTTP(rr, r) + + require.Equal(t, http.StatusOK, rr.Code) + var response jsonrpc.RPCResponse + err = json.Unmarshal(rr.Body.Bytes(), &response) require.NoError(t, err) require.Nil(t, response.Error) + var statusResponse ljsonrpc.StatusResponse err = ljsonrpc.Decode(response.Result, &statusResponse) require.NoError(t, err) assert.True(t, statusResponse.IsRunning) @@ -81,25 +84,26 @@ func TestWithoutToken(t *testing.T) { func TestAccountSpecificWithoutToken(t *testing.T) { testFuncSetup() - defer testFuncTeardown() - var ( - q *jsonrpc.RPCRequest - qBody []byte - response jsonrpc.RPCResponse - ) - - q = jsonrpc.NewRequest("account_list") - qBody, _ = json.Marshal(q) - r, _ := http.NewRequest("POST", proxySuffix, bytes.NewBuffer(qBody)) + q := jsonrpc.NewRequest("account_list") + qBody, err := json.Marshal(q) + require.NoError(t, err) + r, err := http.NewRequest("POST", "/api/v1/proxy", bytes.NewBuffer(qBody)) + require.NoError(t, err) rr := httptest.NewRecorder() - handler := NewRequestHandler(svc) - handler.Handle(rr, r) - require.Equal(t, http.StatusOK, rr.Code) - err := json.Unmarshal(rr.Body.Bytes(), &response) + rt := sdkrouter.New(config.GetLbrynetServers()) + retriever := func(token, ip string) (*models.User, error) { + return nil, nil + } + handler := sdkrouter.Middleware(rt)(auth.Middleware(retriever)(http.HandlerFunc(Handle))) + handler.ServeHTTP(rr, r) + + require.Equal(t, http.StatusOK, rr.Code) + var response jsonrpc.RPCResponse + err = json.Unmarshal(rr.Body.Bytes(), &response) require.NoError(t, err) require.NotNil(t, response.Error) - require.Equal(t, "account identifier required", response.Error.Message) + require.Equal(t, "authentication required", response.Error.Message) } diff --git a/app/proxy/caller.go b/app/proxy/caller.go index f37ee1da..5cb84ba5 100644 --- a/app/proxy/caller.go +++ b/app/proxy/caller.go @@ -15,8 +15,11 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "time" + "github.com/davecgh/go-spew/spew" + "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/monitor" @@ -41,51 +44,60 @@ type Caller struct { Preprocessor func(q *Query) client jsonrpc.RPCClient - walletID string + userID int endpoint string } -func NewCaller(endpoint, walletID string) *Caller { +func NewCaller(endpoint string, userID int) *Caller { return &Caller{ - client: jsonrpc.NewClient(endpoint), + client: jsonrpc.NewClientWithOpts(endpoint, &jsonrpc.RPCClientOpts{ + HTTPClient: &http.Client{Timeout: 30 * time.Second}, + }), endpoint: endpoint, - walletID: walletID, + userID: userID, } } +func (c *Caller) CallRaw(rawQuery []byte) []byte { + var req jsonrpc.RPCRequest + err := json.Unmarshal(rawQuery, &req) + if err != nil { + return marshalError(NewJSONParseError(err)) + } + return c.Call(&req) +} + // Call method processes a raw query received from JSON-RPC client and forwards it to LbrynetServer. // It returns a response that is ready to be sent back to the JSON-RPC client as is. -func (c *Caller) Call(rawQuery []byte) []byte { - r, err := c.call(rawQuery) +func (c *Caller) Call(req *jsonrpc.RPCRequest) []byte { + r, err := c.call(req) if err != nil { - if !isJSONParseError(err) { - monitor.CaptureException(err, map[string]string{"query": string(rawQuery), "response": fmt.Sprintf("%v", r)}) - Logger.Errorf("error calling lbrynet: %v, query: %s", err, rawQuery) - } + monitor.CaptureException(err, map[string]string{"req": spew.Sdump(req), "response": fmt.Sprintf("%v", r)}) + Logger.Errorf("error calling lbrynet: %v, request: %s", err, spew.Sdump(req)) return marshalError(err) } - serialized, err := marshalResponse(r) + serialized, err := json.MarshalIndent(r, "", " ") if err != nil { monitor.CaptureException(err) Logger.Errorf("error marshaling response: %v", err) - return marshalError(err) + return marshalError(NewInternalError(err)) } return serialized } -func (c *Caller) call(rawQuery []byte) (*jsonrpc.RPCResponse, error) { - q, err := NewQuery(rawQuery) +func (c *Caller) call(req *jsonrpc.RPCRequest) (*jsonrpc.RPCResponse, error) { + q, err := NewQuery(req) if err != nil { return nil, err } - if c.walletID != "" { - q.SetWalletID(c.walletID) + if c.userID != 0 { + q.WalletID = sdkrouter.WalletID(c.userID) } - // Check for account identifier (wallet ID) for account-specific methods happens here + // Check for auth for account-specific methods happens here if err := q.validate(); err != nil { return nil, err } @@ -147,19 +159,19 @@ func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) { time.Sleep(walletLoadRetryWait) // Using LBRY JSON-RPC client here for easier request/response processing client := ljsonrpc.NewClient(c.endpoint) - _, err := client.WalletAdd(c.walletID) + _, err := client.WalletAdd(sdkrouter.WalletID(c.userID)) // Alert sentry on the last failed wallet load attempt if err != nil && i >= walletLoadRetries-1 { errMsg := "gave up on manually adding a wallet: %v" Logger.Logger().WithFields(logrus.Fields{ - "wallet_id": c.walletID, - "endpoint": c.endpoint, + "user_id": c.userID, + "endpoint": c.endpoint, }).Errorf(errMsg, err) monitor.CaptureException( fmt.Errorf(errMsg, err), map[string]string{ - "wallet_id": c.walletID, - "endpoint": c.endpoint, - "retries": fmt.Sprintf("%v", i), + "user_id": fmt.Sprintf("%d", c.userID), + "endpoint": c.endpoint, + "retries": fmt.Sprintf("%d", i), }) } } else if isErrWalletAlreadyLoaded(r) { @@ -170,23 +182,15 @@ func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) { } if (r != nil && r.Error != nil) || err != nil { - Logger.LogFailedQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r.Error) + Logger.LogFailedQuery(q.Method(), c.endpoint, c.userID, duration, q.Params(), r.Error) failureMetrics.Observe(duration) } else { - Logger.LogSuccessfulQuery(q.Method(), c.endpoint, c.walletID, duration, q.Params(), r) + Logger.LogSuccessfulQuery(q.Method(), c.endpoint, c.userID, duration, q.Params(), r) } return r, err } -func marshalResponse(r *jsonrpc.RPCResponse) ([]byte, error) { - serialized, err := json.MarshalIndent(r, "", " ") - if err != nil { - return nil, NewInternalError(err) - } - return serialized, nil -} - func marshalError(err error) []byte { var rpcErr RPCError if errors.As(err, &rpcErr) { diff --git a/app/proxy/caller_test.go b/app/proxy/caller_test.go index 4e1ff1b2..fb2e5ad1 100644 --- a/app/proxy/caller_test.go +++ b/app/proxy/caller_test.go @@ -2,7 +2,6 @@ package proxy import ( "encoding/json" - "fmt" "math/rand" "net/http" "net/http/httptest" @@ -23,91 +22,63 @@ import ( "github.com/ybbus/jsonrpc" ) -func newRawRequest(t *testing.T, method string, params interface{}) []byte { - var ( - body []byte - err error - ) - if params != nil { - body, err = json.Marshal(jsonrpc.NewRequest(method, params)) - } else { - body, err = json.Marshal(jsonrpc.NewRequest(method)) - } - if err != nil { - t.Fatal(err) - } - return body -} - -func parseRawResponse(t *testing.T, rawCallReponse []byte, v interface{}) { - assert.NotNil(t, rawCallReponse) +func parseRawResponse(t *testing.T, rawCallResponse []byte, v interface{}) { + assert.NotNil(t, rawCallResponse) var res jsonrpc.RPCResponse - err := json.Unmarshal(rawCallReponse, &res) + err := json.Unmarshal(rawCallResponse, &res) require.NoError(t, err) err = res.GetObject(v) require.NoError(t, err) } -func TestNewQuery(t *testing.T) { - for _, rawQ := range []string{``, ` `, `{}`, `[]`, `[{}]`, `[""]`, `""`, `" "`, `{"method": " "}`} { - t.Run(rawQ, func(t *testing.T) { - q, err := NewQuery([]byte(rawQ)) - assert.Nil(t, q) - assert.Error(t, err) - }) - } - -} - func TestNewCaller(t *testing.T) { servers := map[string]string{ "first": "http://lbrynet1", "second": "http://lbrynet2", } - svc := NewService(sdkrouter.New(servers)) - sList := svc.SDKRouter.GetAll() + rt := sdkrouter.New(servers) + sList := rt.GetAll() rand.Seed(time.Now().UnixNano()) for i := 1; i <= 100; i++ { id := rand.Intn(10^6-10^3) + 10 ^ 3 - wc := svc.NewCaller(fmt.Sprintf("wallet.%v", id)) + wc := NewCaller(rt.GetServer(id).Address, id) lastDigit := id % 10 assert.Equal(t, sList[lastDigit%len(sList)].Address, wc.endpoint) } } -func TestCallerСall(t *testing.T) { - c := NewService(sdkrouter.New(config.GetLbrynetServers())).NewCaller("abc") - for _, rawQ := range []string{``, ` `, `{}`, `[]`, `[{}]`, `[""]`, `""`, `" "`, `{"method": " "}`} { +func TestCallerCallRaw(t *testing.T) { + c := NewCaller(test.RandServerAddress(t), 0) + for _, rawQ := range []string{``, ` `, `[]`, `[{}]`, `[""]`, `""`, `" "`} { t.Run(rawQ, func(t *testing.T) { - r := c.Call([]byte(rawQ)) - assert.Contains(t, string(r), `"code": -32700`) + r := c.CallRaw([]byte(rawQ)) + assert.Contains(t, string(r), `"code": -32700`, `raw query: `+rawQ) + }) + } + for _, rawQ := range []string{`{}`, `{"method": " "}`} { + t.Run(rawQ, func(t *testing.T) { + r := c.CallRaw([]byte(rawQ)) + assert.Contains(t, string(r), `"code": -32080`, `raw query: `+rawQ) }) } - -} - -func TestCallerSetWalletID(t *testing.T) { - svc := NewService(sdkrouter.New(config.GetLbrynetServers())) - c := svc.NewCaller("abc") - assert.Equal(t, "abc", c.walletID) } func TestCallerCallResolve(t *testing.T) { - svc := NewService(sdkrouter.New(config.GetLbrynetServers())) + rt := sdkrouter.New(config.GetLbrynetServers()) resolvedURL := "what#6769855a9aa43b67086f9ff3c1a5bacb5698a27a" resolvedClaimID := "6769855a9aa43b67086f9ff3c1a5bacb5698a27a" - request := newRawRequest(t, "resolve", map[string]string{"urls": resolvedURL}) - rawCallReponse := svc.NewCaller("").Call(request) + request := jsonrpc.NewRequest("resolve", map[string]interface{}{"urls": resolvedURL}) + rawCallResponse := NewCaller(rt.RandomServer().Address, 0).Call(request) var errorResponse jsonrpc.RPCResponse - err := json.Unmarshal(rawCallReponse, &errorResponse) + err := json.Unmarshal(rawCallResponse, &errorResponse) require.NoError(t, err) require.Nil(t, errorResponse.Error) var resolveResponse ljsonrpc.ResolveResponse - parseRawResponse(t, rawCallReponse, &resolveResponse) + parseRawResponse(t, rawCallResponse, &resolveResponse) assert.Equal(t, resolvedClaimID, resolveResponse[resolvedURL].ClaimID) } @@ -115,23 +86,25 @@ func TestCallerCallWalletBalance(t *testing.T) { rand.Seed(time.Now().UnixNano()) dummyUserID := rand.Intn(10^6-10^3) + 10 ^ 3 rt := sdkrouter.New(config.GetLbrynetServers()) - svc := NewService(rt) - request := newRawRequest(t, "wallet_balance", nil) + request := jsonrpc.NewRequest("wallet_balance") - result := svc.NewCaller("").Call(request) - assert.Contains(t, string(result), `"message": "account identifier required"`) + result := NewCaller(rt.RandomServer().Address, 0).Call(request) + assert.Contains(t, string(result), `"message": "authentication required"`) - walletID, err := wallet.Create(test.RandServerAddress(t), dummyUserID) + addr := test.RandServerAddress(t) + walletID, err := wallet.Create(addr, dummyUserID) require.NoError(t, err) hook := logrusTest.NewLocal(Logger.Logger()) - result = svc.NewCaller(walletID).Call(request) + result = NewCaller(addr, dummyUserID).Call(request) - var accountBalanceResponse ljsonrpc.AccountBalanceResponse + var accountBalanceResponse struct { + Available string `json:"available"` + } parseRawResponse(t, result, &accountBalanceResponse) - assert.EqualValues(t, "0", fmt.Sprintf("%v", accountBalanceResponse.Available)) - assert.Equal(t, map[string]interface{}{"wallet_id": fmt.Sprintf("%v", walletID)}, hook.LastEntry().Data["params"]) + assert.EqualValues(t, "0.0", accountBalanceResponse.Available) + assert.Equal(t, map[string]interface{}{"wallet_id": walletID}, hook.LastEntry().Data["params"]) assert.Equal(t, "wallet_balance", hook.LastEntry().Data["method"]) } @@ -139,7 +112,7 @@ func TestCallerCallRelaxedMethods(t *testing.T) { reqChan := test.ReqChan() srv := test.MockHTTPServer(reqChan) defer srv.Close() - caller := NewCaller(srv.URL, "") + caller := NewCaller(srv.URL, 0) for _, m := range relaxedMethods { t.Run(m, func(t *testing.T) { @@ -147,7 +120,7 @@ func TestCallerCallRelaxedMethods(t *testing.T) { return } srv.NextResponse <- "" - caller.Call(newRawRequest(t, m, nil)) + caller.Call(jsonrpc.NewRequest(m)) receivedRequest := <-reqChan expectedRequest := test.ReqToStr(t, jsonrpc.RPCRequest{ Method: m, @@ -160,29 +133,29 @@ func TestCallerCallRelaxedMethods(t *testing.T) { } func TestCallerCallNonRelaxedMethods(t *testing.T) { - caller := NewCaller("", "") + caller := NewCaller("whatever", 0) for _, m := range walletSpecificMethods { - result := caller.Call(newRawRequest(t, m, nil)) - assert.Contains(t, string(result), `"message": "account identifier required"`) + result := caller.Call(jsonrpc.NewRequest(m)) + assert.Contains(t, string(result), `"message": "authentication required"`) } } func TestCallerCallForbiddenMethod(t *testing.T) { - caller := NewCaller("", "") - result := caller.Call(newRawRequest(t, "stop", nil)) + caller := NewCaller(test.RandServerAddress(t), 0) + result := caller.Call(jsonrpc.NewRequest("stop")) assert.Contains(t, string(result), `"message": "forbidden method"`) } func TestCallerCallAttachesWalletID(t *testing.T) { rand.Seed(time.Now().UnixNano()) - dummyWalletID := "abc123321" + dummyUserID := 123321 reqChan := test.ReqChan() srv := test.MockHTTPServer(reqChan) defer srv.Close() srv.NextResponse <- "" - caller := NewCaller(srv.URL, dummyWalletID) - caller.Call(newRawRequest(t, "channel_create", map[string]string{"name": "test", "bid": "0.1"})) + caller := NewCaller(srv.URL, dummyUserID) + caller.Call(jsonrpc.NewRequest("channel_create", map[string]interface{}{"name": "test", "bid": "0.1"})) receivedRequest := <-reqChan expectedRequest := test.ReqToStr(t, jsonrpc.RPCRequest{ @@ -190,7 +163,7 @@ func TestCallerCallAttachesWalletID(t *testing.T) { Params: map[string]interface{}{ "name": "test", "bid": "0.1", - "wallet_id": dummyWalletID, + "wallet_id": sdkrouter.WalletID(dummyUserID), }, JSONRPC: "2.0", }) @@ -202,7 +175,7 @@ func TestCallerSetPreprocessor(t *testing.T) { srv := test.MockHTTPServer(reqChan) defer srv.Close() - c := NewCaller(srv.URL, "") + c := NewCaller(srv.URL, 0) c.Preprocessor = func(q *Query) { params := q.ParamsAsMap() @@ -216,7 +189,7 @@ func TestCallerSetPreprocessor(t *testing.T) { srv.NextResponse <- "" - c.Call(newRawRequest(t, relaxedMethods[0], nil)) + c.Call(jsonrpc.NewRequest(relaxedMethods[0])) req := <-reqChan lastRequest := test.StrToReq(t, req.Body) @@ -256,9 +229,9 @@ func TestCallerCallSDKError(t *testing.T) { "id": 0 }` - c := NewCaller(srv.URL, "") + c := NewCaller(srv.URL, 0) hook := logrusTest.NewLocal(Logger.Logger()) - response := c.Call(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) + response := c.Call(jsonrpc.NewRequest("resolve", map[string]interface{}{"urls": "what"})) var rpcResponse jsonrpc.RPCResponse json.Unmarshal(response, &rpcResponse) assert.Equal(t, rpcResponse.Error.Code, -32500) @@ -271,8 +244,8 @@ func TestCallerCallClientJSONError(t *testing.T) { responses.AddJSONContentType(w) w.Write([]byte(`{"method":"version}`)) })) - c := NewCaller(ts.URL, "") - response := c.Call([]byte(`{"method":"version}`)) + c := NewCaller(ts.URL, 0) + response := c.CallRaw([]byte(`{"method":"version}`)) var rpcResponse jsonrpc.RPCResponse json.Unmarshal(response, &rpcResponse) assert.Equal(t, "2.0", rpcResponse.JSONRPC) @@ -281,8 +254,8 @@ func TestCallerCallClientJSONError(t *testing.T) { } func TestSDKMethodStatus(t *testing.T) { - c := NewService(sdkrouter.New(config.GetLbrynetServers())).NewCaller("") - callResult := c.Call(newRawRequest(t, "status", nil)) + c := NewCaller(test.RandServerAddress(t), 0) + callResult := c.Call(jsonrpc.NewRequest("status")) var rpcResponse jsonrpc.RPCResponse json.Unmarshal(callResult, &rpcResponse) assert.Equal(t, diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index 6301c1bb..3de5a722 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -23,11 +23,11 @@ func TestClientCallDoesReloadWallet(t *testing.T) { err = wallet.UnloadWallet(addr, dummyUserID) require.NoError(t, err) - q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) + q, err := NewQuery(jsonrpc.NewRequest("wallet_balance")) require.NoError(t, err) - q.SetWalletID(walletID) + q.WalletID = walletID - c := NewCaller(addr, walletID) + c := NewCaller(addr, dummyUserID) r, err := c.callQueryWithRetry(q) // err = json.Unmarshal(result, response) require.NoError(t, err) @@ -43,10 +43,10 @@ func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { srv := test.MockHTTPServer(nil) defer srv.Close() - c := NewCaller(srv.URL, "") - q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) + c := NewCaller(srv.URL, 0) + q, err := NewQuery(jsonrpc.NewRequest("wallet_balance")) require.NoError(t, err) - q.SetWalletID(walletID) + q.WalletID = walletID go func() { srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ @@ -76,10 +76,10 @@ func TestClientCallDoesNotReloadWalletIfAlreadyLoaded(t *testing.T) { srv := test.MockHTTPServer(nil) defer srv.Close() - c := NewCaller(srv.URL, "") - q, err := NewQuery(newRawRequest(t, "wallet_balance", nil)) + c := NewCaller(srv.URL, 0) + q, err := NewQuery(jsonrpc.NewRequest("wallet_balance")) require.NoError(t, err) - q.SetWalletID(walletID) + q.WalletID = walletID go func() { srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ diff --git a/app/proxy/errors.go b/app/proxy/errors.go index 3686a8e0..65ef5a95 100644 --- a/app/proxy/errors.go +++ b/app/proxy/errors.go @@ -11,7 +11,7 @@ const ( rpcErrorCodeInternal int = -32080 // general errors that originate inside the proxy module rpcErrorCodeSDK int = -32603 // otherwise-unspecified errors from the SDK rpcErrorCodeAuthRequired int = -32084 // auth info is required but is not provided - rpcErrorCodeUnauthorized int = -32085 // auth info is provided but is not found in the database + rpcErrorCodeForbidden int = -32085 // auth info is provided but is not found in the database rpcErrorCodeJSONParse int = -32700 // invalid JSON was received by the server rpcErrorCodeInvalidParams int = -32602 // error in params that the client provided rpcErrorCodeMethodNotAllowed int = -32601 // the requested method is not allowed to be called @@ -45,7 +45,7 @@ func NewJSONParseError(e error) RPCError { return RPCError{e, rpcErrorCod func NewMethodNotAllowedError(e error) RPCError { return RPCError{e, rpcErrorCodeMethodNotAllowed} } func NewInvalidParamsError(e error) RPCError { return RPCError{e, rpcErrorCodeInvalidParams} } func NewSDKError(e error) RPCError { return RPCError{e, rpcErrorCodeSDK} } -func NewUnauthorizedError(e error) RPCError { return RPCError{e, rpcErrorCodeUnauthorized} } +func NewForbiddenError(e error) RPCError { return RPCError{e, rpcErrorCodeForbidden} } func NewAuthRequiredError(e error) RPCError { return RPCError{e, rpcErrorCodeAuthRequired} } func isJSONParseError(err error) bool { diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index f15b7128..58e2c076 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -1,29 +1,22 @@ package proxy import ( + "encoding/json" + "errors" "io/ioutil" "net/http" - "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/app/auth" + "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/internal/responses" + "github.com/ybbus/jsonrpc" ) -var proxyHandlerLogger = monitor.NewModuleLogger("proxy_handlers") - -// RequestHandler is a wrapper for passing proxy.Service instance to proxy HTTP handler. -type RequestHandler struct { - *Service -} - -// NewRequestHandler initializes request handler with a provided Proxy Service instance -func NewRequestHandler(svc *Service) *RequestHandler { - return &RequestHandler{Service: svc} -} - // Handle forwards client JSON-RPC request to proxy. -func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { +func Handle(w http.ResponseWriter, r *http.Request) { + var proxyHandlerLogger = monitor.NewModuleLogger("proxy_handlers") if r.Body == nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("empty request body")) @@ -39,26 +32,33 @@ func (rh *RequestHandler) Handle(w http.ResponseWriter, r *http.Request) { return } - var walletID string + // We're in RPC-response-land from here on down + responses.AddJSONContentType(w) - q, err := NewQuery(body) - if err != nil || !methodInList(q.Method(), relaxedMethods) { - auth := users.NewAuthenticator(rh.SDKRouter) + var req jsonrpc.RPCRequest + err = json.Unmarshal(body, &req) + if err != nil { + w.Write(NewJSONParseError(err).JSON()) + return + } - walletID, err = auth.GetWalletID(r) - if err != nil { - responses.AddJSONContentType(w) - w.Write(marshalError(err)) - monitor.CaptureRequestError(err, r, w) + var userID int + if MethodNeedsAuth(req.Method) { + authResult := auth.FromRequest(r) + if !authResult.AuthAttempted() { + w.Write(NewAuthRequiredError(errors.New(responses.AuthRequiredErrorMessage)).JSON()) + return + } + if !authResult.Authenticated() { + w.Write(NewForbiddenError(authResult.Err).JSON()) return } + userID = authResult.User.ID } - c := rh.NewCaller(walletID) - - rawCallReponse := c.Call(body) - responses.AddJSONContentType(w) - w.Write(rawCallReponse) + rt := sdkrouter.FromRequest(r) + c := NewCaller(rt.GetServer(userID).Address, userID) + w.Write(c.Call(&req)) } // HandleCORS returns necessary CORS headers for pre-flight requests to proxy API diff --git a/app/proxy/handlers_test.go b/app/proxy/handlers_test.go index a01f1c9e..40b1455b 100644 --- a/app/proxy/handlers_test.go +++ b/app/proxy/handlers_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "testing" + "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" @@ -27,50 +28,58 @@ func TestProxyOptions(t *testing.T) { } func TestProxyNilQuery(t *testing.T) { - r, _ := http.NewRequest("POST", "", nil) + r, err := http.NewRequest("POST", "", nil) + require.NoError(t, err) rr := httptest.NewRecorder() - handler := NewRequestHandler(svc) - handler.Handle(rr, r) + rt := sdkrouter.New(config.GetLbrynetServers()) + handler := sdkrouter.Middleware(rt)(http.HandlerFunc(Handle)) + handler.ServeHTTP(rr, r) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Equal(t, "empty request body", rr.Body.String()) } func TestProxyInvalidQuery(t *testing.T) { - var parsedResponse jsonrpc.RPCResponse - r, _ := http.NewRequest("POST", "", bytes.NewBuffer([]byte("yo"))) + r, err := http.NewRequest("POST", "", bytes.NewBuffer([]byte("yo"))) + require.NoError(t, err) rr := httptest.NewRecorder() - handler := NewRequestHandler(svc) - handler.Handle(rr, r) + rt := sdkrouter.New(config.GetLbrynetServers()) + handler := sdkrouter.Middleware(rt)(http.HandlerFunc(Handle)) + handler.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) - err := json.Unmarshal(rr.Body.Bytes(), &parsedResponse) + var parsedResponse jsonrpc.RPCResponse + err = json.Unmarshal(rr.Body.Bytes(), &parsedResponse) require.NoError(t, err) assert.Contains(t, parsedResponse.Error.Message, "invalid character 'y' looking for beginning of value") } func TestProxyDontAuthRelaxedMethods(t *testing.T) { - var parsedResponse jsonrpc.RPCResponse var apiCalls int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiCalls++ })) config.Override("InternalAPIHost", ts.URL) - r, _ := http.NewRequest("POST", "", bytes.NewBuffer([]byte(newRawRequest(t, "resolve", map[string]string{"urls": "what"})))) + rawReq := jsonrpc.NewRequest("resolve", map[string]string{"urls": "what"}) + raw, err := json.Marshal(rawReq) + require.NoError(t, err) + + r, err := http.NewRequest("POST", "", bytes.NewBuffer(raw)) + require.NoError(t, err) r.Header.Set(wallet.TokenHeader, "abc") rr := httptest.NewRecorder() - handler := NewRequestHandler(svc) - handler.Handle(rr, r) + rt := sdkrouter.New(config.GetLbrynetServers()) + handler := sdkrouter.Middleware(rt)(http.HandlerFunc(Handle)) + handler.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) - err := json.Unmarshal(rr.Body.Bytes(), &parsedResponse) + var parsedResponse jsonrpc.RPCResponse + err = json.Unmarshal(rr.Body.Bytes(), &parsedResponse) require.NoError(t, err) - assert.Equal(t, 0, apiCalls) } diff --git a/app/proxy/main_test.go b/app/proxy/main_test.go index 941b13e3..e24ac1b5 100644 --- a/app/proxy/main_test.go +++ b/app/proxy/main_test.go @@ -2,30 +2,19 @@ package proxy import ( "math/rand" - "net/http" - "net/http/httptest" "os" "testing" "time" - "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/storage" ) -const dummyUserID = 751365 -const dummyServerURL = "http://127.0.0.1:59999" -const proxySuffix = "/api/v1/proxy" const testSetupWait = 200 * time.Millisecond -var svc *Service - func TestMain(m *testing.M) { rand.Seed(time.Now().UnixNano()) - svc = NewService(sdkrouter.New(config.GetLbrynetServers())) - dbConfig := config.GetDatabase() params := storage.ConnParams{ Connection: dbConfig.Connection, @@ -37,31 +26,10 @@ func TestMain(m *testing.M) { defer connCleanup() - code := m.Run() - - os.Exit(code) + os.Exit(m.Run()) } func testFuncSetup() { storage.Conn.Truncate([]string{"users"}) time.Sleep(testSetupWait) } - -func testFuncTeardown() { - -} - -func launchDummyAPIServer(response []byte) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - responses.AddJSONContentType(w) - w.Write(response) - })) -} - -func launchDummyAPIServerDelayed(response []byte, delayMsec time.Duration) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(delayMsec * time.Millisecond) - responses.AddJSONContentType(w) - w.Write(response) - })) -} diff --git a/app/proxy/query.go b/app/proxy/query.go index de26d16b..ac92432b 100644 --- a/app/proxy/query.go +++ b/app/proxy/query.go @@ -6,39 +6,26 @@ import ( "fmt" "strings" - ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/internal/responses" "github.com/ybbus/jsonrpc" ) // Query is a wrapper around client JSON-RPC query for easier (un)marshaling and processing. type Query struct { - Request *jsonrpc.RPCRequest - rawRequest []byte - walletID string + Request *jsonrpc.RPCRequest + WalletID string } // NewQuery initializes Query object with JSON-RPC request supplied as bytes. // The object is immediately usable and returns an error in case request parsing fails. -func NewQuery(r []byte) (*Query, error) { - q := &Query{rawRequest: r, Request: &jsonrpc.RPCRequest{}} - err := q.unmarshal() - if err != nil { - return nil, NewJSONParseError(err) +func NewQuery(req *jsonrpc.RPCRequest) (*Query, error) { + if strings.TrimSpace(req.Method) == "" { + return nil, errors.New("no method in request") } - return q, nil -} -func (q *Query) unmarshal() error { - err := json.Unmarshal(q.rawRequest, q.Request) - if err != nil { - return err - } - if strings.TrimSpace(q.Request.Method) == "" { - return errors.New("invalid JSON-RPC request") - } - return nil + return &Query{Request: req}, nil } func (q *Query) validate() error { @@ -52,15 +39,15 @@ func (q *Query) validate() error { } } - if !methodInList(q.Method(), relaxedMethods) { - if q.walletID == "" { - return NewInvalidParamsError(errors.New("account identifier required")) + if MethodNeedsAuth(q.Method()) { + if q.WalletID == "" { + return NewAuthRequiredError(errors.New(responses.AuthRequiredErrorMessage)) } if p := q.ParamsAsMap(); p != nil { - p[paramWalletID] = q.walletID + p[paramWalletID] = q.WalletID q.Request.Params = p } else { - q.Request.Params = map[string]interface{}{paramWalletID: q.walletID} + q.Request.Params = map[string]interface{}{paramWalletID: q.WalletID} } } @@ -85,11 +72,6 @@ func (q *Query) ParamsAsMap() map[string]interface{} { return nil } -// ParamsToStruct returns query params parsed into a supplied structure. -func (q *Query) ParamsToStruct(targetStruct interface{}) error { - return ljsonrpc.Decode(q.Params(), targetStruct) -} - // cacheHit returns true if we got a resolve query with more than `cacheResolveLongerThan` urls in it. func (q *Query) isCacheable() bool { if q.Method() == MethodResolve && q.Params() != nil { @@ -112,10 +94,6 @@ func (q *Query) newResponse() *jsonrpc.RPCResponse { } } -func (q *Query) SetWalletID(id string) { - q.walletID = id -} - // cacheHit returns cached response or nil in case it's a miss or query shouldn't be cacheable. func (q *Query) cacheHit() *jsonrpc.RPCResponse { if !q.isCacheable() { @@ -154,6 +132,10 @@ func (q *Query) predefinedResponse() *jsonrpc.RPCResponse { } } +func MethodNeedsAuth(method string) bool { + return !methodInList(method, relaxedMethods) +} + func methodInList(method string, checkMethods []string) bool { for _, m := range checkMethods { if m == method { diff --git a/app/proxy/query_test.go b/app/proxy/query_test.go index cac0db11..5470512f 100644 --- a/app/proxy/query_test.go +++ b/app/proxy/query_test.go @@ -6,21 +6,22 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ybbus/jsonrpc" ) func TestQueryParamsAsMap(t *testing.T) { - q, err := NewQuery(newRawRequest(t, "version", nil)) + q, err := NewQuery(jsonrpc.NewRequest("version")) require.NoError(t, err) assert.Nil(t, q.ParamsAsMap()) - q, err = NewQuery(newRawRequest(t, "resolve", map[string]string{"urls": "what"})) + q, err = NewQuery(jsonrpc.NewRequest("resolve", map[string]interface{}{"urls": "what"})) require.NoError(t, err) assert.Equal(t, map[string]interface{}{"urls": "what"}, q.ParamsAsMap()) - q, err = NewQuery(newRawRequest(t, "account_balance", nil)) + q, err = NewQuery(jsonrpc.NewRequest("account_balance")) require.NoError(t, err) - q.SetWalletID("123") + q.WalletID = "123" err = q.validate() require.NoError(t, err, errors.Unwrap(err)) assert.Equal(t, map[string]interface{}{"wallet_id": "123"}, q.ParamsAsMap()) @@ -31,7 +32,7 @@ func TestQueryParamsAsMap(t *testing.T) { "gaming", "music", "news", "science", "sports", "technology", }, } - q, err = NewQuery(newRawRequest(t, "claim_search", searchParams)) + q, err = NewQuery(jsonrpc.NewRequest("claim_search", searchParams)) require.NoError(t, err) assert.Equal(t, searchParams, q.ParamsAsMap()) } diff --git a/app/proxy/service.go b/app/proxy/service.go deleted file mode 100644 index ba5543ed..00000000 --- a/app/proxy/service.go +++ /dev/null @@ -1,45 +0,0 @@ -package proxy - -import ( - "net/http" - "time" - - "github.com/lbryio/lbrytv/app/sdkrouter" - - "github.com/ybbus/jsonrpc" -) - -const defaultRPCTimeout = 30 * time.Second - -// Service generates Caller objects and keeps execution time metrics -// for all calls proxied through those objects. -type Service struct { - SDKRouter *sdkrouter.Router - rpcTimeout time.Duration -} - -// NewService is the entry point to proxy module. -// Normally only one instance of Service should be created per running server. -func NewService(router *sdkrouter.Router) *Service { - return &Service{ - SDKRouter: router, - rpcTimeout: defaultRPCTimeout, - } -} - -func (ps *Service) SetRPCTimeout(timeout time.Duration) { - ps.rpcTimeout = timeout -} - -// NewCaller returns an instance of Caller ready to proxy requests. -// Note that `SetWalletID` needs to be called if an authenticated user is making this call. -func (ps *Service) NewCaller(walletID string) *Caller { - endpoint := ps.SDKRouter.GetServer(sdkrouter.UserID(walletID)).Address - return &Caller{ - endpoint: endpoint, - walletID: walletID, - client: jsonrpc.NewClientWithOpts(endpoint, &jsonrpc.RPCClientOpts{ - HTTPClient: &http.Client{Timeout: ps.rpcTimeout}, - }), - } -} diff --git a/app/publish/handler_test.go b/app/publish/handler_test.go index a92df10f..06fa98c1 100644 --- a/app/publish/handler_test.go +++ b/app/publish/handler_test.go @@ -3,6 +3,7 @@ package publish import ( "bytes" "encoding/json" + "errors" "io" "io/ioutil" "mime/multipart" @@ -12,39 +13,48 @@ import ( "path" "testing" - "github.com/lbryio/lbrytv/app/users" + "github.com/lbryio/lbrytv/app/auth" "github.com/lbryio/lbrytv/app/wallet" + "github.com/lbryio/lbrytv/models" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" ) type DummyPublisher struct { - called bool - filePath string - accountID string - rawQuery []byte + called bool + filePath string + userID int + rawQuery []byte } -func (p *DummyPublisher) Publish(filePath, accountID string, rawQuery []byte) []byte { +func (p *DummyPublisher) Publish(filePath string, userID int, rawQuery []byte) []byte { p.called = true p.filePath = filePath - p.accountID = accountID + p.userID = userID p.rawQuery = rawQuery return []byte(expectedStreamCreateResponse) } func TestUploadHandler(t *testing.T) { - req := CreatePublishRequest(t, []byte("test file")) - req.Header.Set(wallet.TokenHeader, "uPldrToken") + r := CreatePublishRequest(t, []byte("test file")) + r.Header.Set(wallet.TokenHeader, "uPldrToken") - rr := httptest.NewRecorder() - authenticator := &users.Authenticator{Retriever: users.DummyRetriever("uPldrToken", "UPldrAcc")} publisher := &DummyPublisher{} - pubHandler, err := NewUploadHandler(UploadOpts{Path: os.TempDir(), Publisher: publisher}) - assert.NoError(t, err) + handler := &Handler{ + Publisher: publisher, + UploadPath: os.TempDir(), + } + + retriever := func(token, ip string) (*models.User, error) { + if token == "uPldrToken" { + return &models.User{ID: 20404}, nil + } + return nil, errors.New("error") + } - authenticator.Wrap(pubHandler.Handle).ServeHTTP(rr, req) + rr := httptest.NewRecorder() + auth.Middleware(retriever)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, r) response := rr.Result() respBody, _ := ioutil.ReadAll(response.Body) @@ -52,48 +62,68 @@ func TestUploadHandler(t *testing.T) { assert.Equal(t, expectedStreamCreateResponse, string(respBody)) require.True(t, publisher.called) - expectedPath := path.Join(os.TempDir(), "UPldrAcc", ".*_lbry_auto_test_file") + expectedPath := path.Join(os.TempDir(), "20404", ".*_lbry_auto_test_file") assert.Regexp(t, expectedPath, publisher.filePath) - assert.Equal(t, "UPldrAcc", publisher.accountID) + assert.Equal(t, 20404, publisher.userID) assert.Equal(t, expectedStreamCreateRequest, string(publisher.rawQuery)) - _, err = os.Stat(publisher.filePath) + _, err := os.Stat(publisher.filePath) assert.True(t, os.IsNotExist(err)) } -func TestUploadHandlerAuthRequired(t *testing.T) { - var rpcResponse jsonrpc.RPCResponse - req := CreatePublishRequest(t, []byte("test file")) +func TestUploadHandlerNoAuthMiddleware(t *testing.T) { + r := CreatePublishRequest(t, []byte("test file")) + r.Header.Set(wallet.TokenHeader, "uPldrToken") + + publisher := &DummyPublisher{} + handler := &Handler{ + Publisher: publisher, + UploadPath: os.TempDir(), + } rr := httptest.NewRecorder() - authenticator := &users.Authenticator{Retriever: users.DummyRetriever("", "")} + assert.Panics(t, func() { + handler.Handle(rr, r) + }) +} + +func TestUploadHandlerAuthRequired(t *testing.T) { + r := CreatePublishRequest(t, []byte("test file")) + publisher := &DummyPublisher{} - pubHandler, err := NewUploadHandler(UploadOpts{Path: os.TempDir(), Publisher: publisher}) - assert.NoError(t, err) + handler := &Handler{ + Publisher: publisher, + UploadPath: os.TempDir(), + } + + retriever := func(token, ip string) (*models.User, error) { + if token == "uPldrToken" { + return &models.User{ID: 20404}, nil + } + return nil, errors.New("error") + } - authenticator.Wrap(pubHandler.Handle).ServeHTTP(rr, req) + rr := httptest.NewRecorder() + auth.Middleware(retriever)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, r) response := rr.Result() assert.Equal(t, http.StatusOK, response.StatusCode) - err = json.Unmarshal(rr.Body.Bytes(), &rpcResponse) + var rpcResponse jsonrpc.RPCResponse + err := json.Unmarshal(rr.Body.Bytes(), &rpcResponse) require.NoError(t, err) assert.Equal(t, "authentication required", rpcResponse.Error.Message) require.False(t, publisher.called) } func TestUploadHandlerSystemError(t *testing.T) { - var rpcResponse jsonrpc.RPCResponse - // Creating POST data manually here because we need to avoid writer.Close() - data := []byte("test file") - readSeeker := bytes.NewReader(data) + reader := bytes.NewReader([]byte("test file")) body := &bytes.Buffer{} - writer := multipart.NewWriter(body) fileBody, err := writer.CreateFormFile(fileFieldName, "lbry_auto_test_file") require.NoError(t, err) - _, err = io.Copy(fileBody, readSeeker) + _, err = io.Copy(fileBody, reader) require.NoError(t, err) jsonPayload, err := writer.CreateFormField(jsonRPCFieldName) @@ -108,25 +138,28 @@ func TestUploadHandlerSystemError(t *testing.T) { req.Header.Set(wallet.TokenHeader, "uPldrToken") req.Header.Set("Content-Type", writer.FormDataContentType()) - rr := httptest.NewRecorder() - authenticator := &users.Authenticator{Retriever: users.DummyRetriever("uPldrToken", "UPldrAcc")} publisher := &DummyPublisher{} - pubHandler, err := NewUploadHandler(UploadOpts{Path: os.TempDir(), Publisher: publisher}) - assert.NoError(t, err) + handler := &Handler{ + Publisher: publisher, + UploadPath: os.TempDir(), + } + + retriever := func(token, ip string) (*models.User, error) { + if token == "uPldrToken" { + return &models.User{ID: 20404}, nil + } + return nil, errors.New("error") + } - authenticator.Wrap(pubHandler.Handle).ServeHTTP(rr, req) + rr := httptest.NewRecorder() + auth.Middleware(retriever)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, req) response := rr.Result() require.False(t, publisher.called) assert.Equal(t, http.StatusOK, response.StatusCode) + var rpcResponse jsonrpc.RPCResponse err = json.Unmarshal(rr.Body.Bytes(), &rpcResponse) require.NoError(t, err) assert.Equal(t, "unexpected EOF", rpcResponse.Error.Message) require.False(t, publisher.called) } - -func TestNewUploadHandler(t *testing.T) { - h, err := NewUploadHandler(UploadOpts{}) - assert.Error(t, err, "need either a Service or a Publisher instance") - assert.Nil(t, h) -} diff --git a/app/publish/publish.go b/app/publish/publish.go index fedc12ce..20114a22 100644 --- a/app/publish/publish.go +++ b/app/publish/publish.go @@ -9,102 +9,75 @@ import ( "os" "path" + "github.com/lbryio/lbrytv/app/auth" "github.com/lbryio/lbrytv/app/proxy" - "github.com/lbryio/lbrytv/app/users" - "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/internal/responses" "github.com/gorilla/mux" ) -// fileFieldName refers to the POST field containing file upload -const fileFieldName = "file" - -// jsonRPCFieldName is a name of the POST field containing JSONRPC request accompanying the uploaded file -const jsonRPCFieldName = "json_payload" +var logger = monitor.NewModuleLogger("publish") -const fileNameParam = "file_path" +const ( + // fileFieldName refers to the POST field containing file upload + fileFieldName = "file" + // jsonRPCFieldName is a name of the POST field containing JSONRPC request accompanying the uploaded file + jsonRPCFieldName = "json_payload" -var logger = monitor.NewModuleLogger("publish") + fileNameParam = "file_path" +) // Publisher is responsible for sending data to lbrynet // and should take file path, account ID and client query as a slice of bytes. type Publisher interface { - Publish(string, string, []byte) []byte + Publish(filePath string, userID int, query []byte) []byte } // LbrynetPublisher is an implementation of SDK publisher. type LbrynetPublisher struct { - *proxy.Service -} - -// UploadHandler glues HTTP uploads to the Publisher. -type UploadHandler struct { - Publisher Publisher - UploadPath string -} - -type UploadOpts struct { - Path string - Publisher Publisher - ProxyService *proxy.Service -} - -// NewUploadHandler returns a HTTP upload handler object. -func NewUploadHandler(opts UploadOpts) (*UploadHandler, error) { - var ( - publisher Publisher - uploadPath string - ) - - if opts.ProxyService != nil { - publisher = &LbrynetPublisher{Service: opts.ProxyService} - } else if opts.Publisher != nil { - publisher = opts.Publisher - } else { - return nil, errors.New("need either a Service or a Publisher instance") - } - - if opts.Path == "" { - uploadPath = config.GetPublishSourceDir() - } else { - uploadPath = opts.Path - } - return &UploadHandler{ - Publisher: publisher, - UploadPath: uploadPath, - }, nil + Router *sdkrouter.Router } // Publish takes a file path, account ID and client JSON-RPC query, // patches the query and sends it to the SDK for processing. // Resulting response is then returned back as a slice of bytes. -func (p *LbrynetPublisher) Publish(filePath, walletID string, rawQuery []byte) []byte { - c := p.Service.NewCaller(walletID) +func (p *LbrynetPublisher) Publish(filePath string, userID int, rawQuery []byte) []byte { + c := proxy.NewCaller(p.Router.GetServer(userID).Address, userID) c.Preprocessor = func(q *proxy.Query) { params := q.ParamsAsMap() params[fileNameParam] = filePath q.Request.Params = params } - r := c.Call(rawQuery) - return r + return c.CallRaw(rawQuery) +} + +// Handler glues HTTP uploads to the Publisher. +type Handler struct { + Publisher Publisher + UploadPath string + InternalAPIHost string } // Handle is where HTTP upload is handled and passed on to Publisher. // It should be wrapped with users.Authenticator.Wrap before it can be used // in a mux.Router. -func (h UploadHandler) Handle(w http.ResponseWriter, r *users.AuthenticatedRequest) { +func (h Handler) Handle(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - if !r.IsAuthenticated() { - if r.AuthFailed() { - w.Write(proxy.NewUnauthorizedError(r.AuthError).JSON()) - } else { - w.Write(proxy.NewAuthRequiredError(errors.New("authentication required")).JSON()) - } + + authResult := auth.FromRequest(r) + + if !authResult.AuthAttempted() { + w.Write(proxy.NewAuthRequiredError(errors.New(responses.AuthRequiredErrorMessage)).JSON()) + return + } + if !authResult.Authenticated() { + w.Write(proxy.NewForbiddenError(authResult.Err).JSON()) return } - f, err := h.saveFile(r) + f, err := h.saveFile(r, authResult.User.ID) if err != nil { logger.Log().Error(err) monitor.CaptureException(err) @@ -112,7 +85,7 @@ func (h UploadHandler) Handle(w http.ResponseWriter, r *users.AuthenticatedReque return } - response := h.Publisher.Publish(f.Name(), r.WalletID, []byte(r.FormValue(jsonRPCFieldName))) + response := h.Publisher.Publish(f.Name(), authResult.User.ID, []byte(r.FormValue(jsonRPCFieldName))) if err := os.Remove(f.Name()); err != nil { monitor.CaptureException(err, map[string]string{"file_path": f.Name()}) @@ -123,38 +96,21 @@ func (h UploadHandler) Handle(w http.ResponseWriter, r *users.AuthenticatedReque // CanHandle checks if http.Request contains POSTed data in an accepted format. // Supposed to be used in gorilla mux router MatcherFunc. -func (h UploadHandler) CanHandle(r *http.Request, _ *mux.RouteMatch) bool { +func (h Handler) CanHandle(r *http.Request, _ *mux.RouteMatch) bool { _, _, err := r.FormFile(fileFieldName) - payload := r.FormValue(jsonRPCFieldName) - return err != http.ErrMissingFile && payload != "" + return err != http.ErrMissingFile && r.FormValue(jsonRPCFieldName) != "" } -// createFile opens an empty file for writing inside the account's designated folder. -// The final file path looks like `/upload_path/{wallet_id}/{random}_filename.ext`, -// where `wallet_id` is local SDK wallet ID and `random` is a random string generated by ioutil. -func (h UploadHandler) createFile(walletID string, origFilename string) (*os.File, error) { - path, err := h.preparePath(walletID) - if err != nil { - return nil, err - } - return ioutil.TempFile(path, fmt.Sprintf("*_%v", origFilename)) -} +func (h Handler) saveFile(r *http.Request, userID int) (*os.File, error) { + log := logger.LogF(monitor.F{"user_id": userID}) -func (h UploadHandler) preparePath(walletID string) (string, error) { - path := path.Join(h.UploadPath, walletID) - err := os.MkdirAll(path, os.ModePerm) - return path, err -} - -func (h UploadHandler) saveFile(r *users.AuthenticatedRequest) (*os.File, error) { - log := logger.LogF(monitor.F{"account_id": r.WalletID}) file, header, err := r.FormFile(fileFieldName) if err != nil { return nil, err } defer file.Close() - f, err := h.createFile(r.WalletID, header.Filename) + f, err := h.createFile(userID, header.Filename) if err != nil { return nil, err } @@ -171,3 +127,15 @@ func (h UploadHandler) saveFile(r *users.AuthenticatedRequest) (*os.File, error) } return f, nil } + +// createFile opens an empty file for writing inside the account's designated folder. +// The final file path looks like `/upload_path/{user_id}/{random}_filename.ext`, +// where `user_id` is user's ID and `random` is a random string generated by ioutil. +func (h Handler) createFile(userID int, origFilename string) (*os.File, error) { + path := path.Join(h.UploadPath, fmt.Sprintf("%d", userID)) + err := os.MkdirAll(path, os.ModePerm) + if err != nil { + return nil, err + } + return ioutil.TempFile(path, fmt.Sprintf("*_%s", origFilename)) +} diff --git a/app/publish/publish_test.go b/app/publish/publish_test.go index 0eeef47c..13446767 100644 --- a/app/publish/publish_test.go +++ b/app/publish/publish_test.go @@ -8,7 +8,6 @@ import ( "path" "testing" - "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" @@ -57,12 +56,9 @@ func TestLbrynetPublisher(t *testing.T) { }`, 751365) }() - config.Override("InternalAPIHost", ts.URL) - defer config.RestoreOverridden() - rt := sdkrouter.New(config.GetLbrynetServers()) - p := &LbrynetPublisher{proxy.NewService(rt)} - u, err := wallet.GetUserWithWallet(rt, authToken, "") + p := &LbrynetPublisher{rt} + u, err := wallet.GetUserWithWallet(rt, ts.URL, authToken, "") require.NoError(t, err) data := []byte("test file") @@ -96,7 +92,7 @@ func TestLbrynetPublisher(t *testing.T) { "id": 1567580184168 }`) - rawResp := p.Publish(path.Join("/storage", path.Base(f.Name())), u.WalletID, query) + rawResp := p.Publish(path.Join("/storage", path.Base(f.Name())), u.ID, query) // This is all we can check for now without running on testnet or crediting some funds to the test account assert.Regexp(t, "Not enough funds to cover this transaction", string(rawResp)) diff --git a/app/sdkrouter/middleware.go b/app/sdkrouter/middleware.go new file mode 100644 index 00000000..8ebb2c6b --- /dev/null +++ b/app/sdkrouter/middleware.go @@ -0,0 +1,30 @@ +package sdkrouter + +import ( + "context" + "net/http" + + "github.com/gorilla/mux" +) + +const ContextKey = "sdkrouter" + +func FromRequest(r *http.Request) *Router { + v := r.Context().Value(ContextKey) + if v == nil { + panic("sdkrouter middleware was not applied") + } + return v.(*Router) +} + +func AddToRequest(rt *Router, fn http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + fn(w, r.Clone(context.WithValue(r.Context(), ContextKey, rt))) + } +} + +func Middleware(rt *Router) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return AddToRequest(rt, next.ServeHTTP) + } +} diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index 032e7a09..a367dcaf 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -68,7 +68,7 @@ func (r *Router) GetServer(userID int) *models.LbrynetServer { var sdk *models.LbrynetServer if userID == 0 { - sdk = r.LeastLoaded() + sdk = r.RandomServer() } else { sdk = r.serverForUser(userID) if sdk.Address == "" { diff --git a/app/users/authenticator.go b/app/users/authenticator.go deleted file mode 100644 index d3f236bf..00000000 --- a/app/users/authenticator.go +++ /dev/null @@ -1,80 +0,0 @@ -package users - -import ( - "net/http" - - "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/app/wallet" - "github.com/lbryio/lbrytv/internal/monitor" - "github.com/lbryio/lbrytv/models" -) - -const GenericRetrievalErr = "unable to retrieve user" - -var logger = monitor.NewModuleLogger("auth") - -type AuthenticatedRequest struct { - *http.Request - WalletID string - AuthError error -} - -// AuthFailed is a helper to see if there was an error authenticating user. -func (r *AuthenticatedRequest) AuthFailed() bool { - return r.AuthError != nil -} - -// IsAuthenticated is a helper to see if a user was authenticated. -// If it is false, AuthError might be provided (in case user retriever has errored) -// or be nil if no auth token was present in headers. -func (r *AuthenticatedRequest) IsAuthenticated() bool { - return r.WalletID != "" -} - -// Retriever is an interface for user retrieval by internal-apis auth token -type UserRetriever func(token, metaRemoteIP string) (*models.User, error) - -type Authenticator struct { - Retriever UserRetriever -} - -// NewAuthenticator provides HTTP handler wrapping methods -// and should be initialized with an object that allows user retrieval. -func NewAuthenticator(rt *sdkrouter.Router) *Authenticator { - return &Authenticator{ - Retriever: func(token, metaRemoteIP string) (user *models.User, err error) { - return wallet.GetUserWithWallet(rt, token, metaRemoteIP) - }, - } -} - -// GetWalletID retrieves user token from HTTP headers and subsequently -// an SDK account ID from Retriever. -func (a *Authenticator) GetWalletID(r *http.Request) (string, error) { - if token, ok := r.Header[wallet.TokenHeader]; ok { - ip := GetIPAddressForRequest(r) - user, err := a.Retriever(token[0], ip) - if err != nil { - logger.LogF(monitor.F{"ip": ip}).Debugf("failed to authenticate user") - return "", err - } else if user != nil { - return user.WalletID, nil - } - } - return "", nil -} - -// Wrap result can be supplied to all functions that accept http.HandleFunc, -// supplied function will be wrapped and called with AuthenticatedRequest instead of http.Request. -func (a *Authenticator) Wrap(wrapped func(http.ResponseWriter, *AuthenticatedRequest)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ar := &AuthenticatedRequest{Request: r} - WalletID, err := a.GetWalletID(r) - if err != nil { - ar.AuthError = err - } else { - ar.WalletID = WalletID - } - wrapped(w, ar) - } -} diff --git a/app/users/authenticator_test.go b/app/users/authenticator_test.go deleted file mode 100644 index f97381ae..00000000 --- a/app/users/authenticator_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package users - -import ( - "errors" - "io/ioutil" - "net/http" - "net/http/httptest" - "testing" - - "github.com/lbryio/lbrytv/app/wallet" - "github.com/lbryio/lbrytv/models" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func authedHandler(w http.ResponseWriter, r *AuthenticatedRequest) { - if r.IsAuthenticated() { - w.WriteHeader(http.StatusAccepted) - w.Write([]byte(r.WalletID)) - } else { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte(r.AuthError.Error())) - } -} - -func TestAuthenticator(t *testing.T) { - r, err := http.NewRequest("GET", "/api/proxy", nil) - require.NoError(t, err) - r.Header.Set(wallet.TokenHeader, "XyZ") - r.Header.Set("X-Forwarded-For", "8.8.8.8") - - var receivedRemoteIP string - authenticator := &Authenticator{ - Retriever: func(token, ip string) (*models.User, error) { - receivedRemoteIP = ip - if token == "XyZ" { - return &models.User{WalletID: "aBc"}, nil - } - return nil, errors.New(GenericRetrievalErr) - }, - } - - rr := httptest.NewRecorder() - authenticator.Wrap(authedHandler).ServeHTTP(rr, r) - response := rr.Result() - body, err := ioutil.ReadAll(response.Body) - require.NoError(t, err) - assert.Equal(t, "aBc", string(body)) - assert.Equal(t, "8.8.8.8", receivedRemoteIP) -} - -func TestAuthenticatorFailure(t *testing.T) { - r, err := http.NewRequest("GET", "/api/proxy", nil) - require.NoError(t, err) - r.Header.Set(wallet.TokenHeader, "ALSDJ") - rr := httptest.NewRecorder() - - authenticator := &Authenticator{Retriever: DummyRetriever("XyZ", "")} - - authenticator.Wrap(authedHandler).ServeHTTP(rr, r) - response := rr.Result() - body, err := ioutil.ReadAll(response.Body) - require.NoError(t, err) - assert.Equal(t, GenericRetrievalErr, string(body)) - assert.Equal(t, http.StatusForbidden, response.StatusCode) -} - -func TestAuthenticatorGetWalletIDUnverifiedUser(t *testing.T) { - r, err := http.NewRequest("GET", "/api/proxy", nil) - require.NoError(t, err) - r.Header.Set(wallet.TokenHeader, "zzz") - - a := &Authenticator{Retriever: func(token, ip string) (*models.User, error) { return nil, nil }} - - walletID, err := a.GetWalletID(r) - assert.NoError(t, err) - assert.Equal(t, "", walletID) -} diff --git a/app/users/testing.go b/app/users/testing.go deleted file mode 100644 index 3d047082..00000000 --- a/app/users/testing.go +++ /dev/null @@ -1,16 +0,0 @@ -package users - -import ( - "errors" - - "github.com/lbryio/lbrytv/models" -) - -func DummyRetriever(userToken, walletID string) UserRetriever { - return func(token, ip string) (*models.User, error) { - if userToken == "" || userToken == token { - return &models.User{WalletID: walletID}, nil - } - return nil, errors.New(GenericRetrievalErr) - } -} diff --git a/app/users/testing_test.go b/app/users/testing_test.go deleted file mode 100644 index 5ebc199b..00000000 --- a/app/users/testing_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package users - -import ( - "errors" - "net/http" - "testing" - - "github.com/lbryio/lbrytv/app/wallet" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestTestUserRetrieverGetWalletID(t *testing.T) { - testAuth := &Authenticator{Retriever: DummyRetriever("", "123")} - r, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - r.Header.Set(wallet.TokenHeader, "XyZ") - a, err := testAuth.GetWalletID(r) - assert.NoError(t, err) - assert.Equal(t, "123", a) - - r, _ = http.NewRequest("GET", "/", nil) - r.Header.Set(wallet.TokenHeader, "aBc") - a, err = testAuth.GetWalletID(r) - assert.NoError(t, err) - assert.Equal(t, "123", a) - - testAuth = &Authenticator{Retriever: DummyRetriever("XyZ", "123")} - r, _ = http.NewRequest("GET", "/", nil) - r.Header.Set(wallet.TokenHeader, "XyZ") - a, err = testAuth.GetWalletID(r) - assert.NoError(t, err) - assert.Equal(t, "123", a) - - r, _ = http.NewRequest("GET", "/", nil) - r.Header.Set(wallet.TokenHeader, "aBc") - a, err = testAuth.GetWalletID(r) - assert.Equal(t, errors.New(GenericRetrievalErr), err) - assert.Equal(t, "", a) - - r, _ = http.NewRequest("GET", "/", nil) - a, err = testAuth.GetWalletID(r) - assert.NoError(t, err) - assert.Equal(t, "", a) - -} diff --git a/app/wallet/wallet.go b/app/wallet/wallet.go index 2103e5c8..3386ffba 100644 --- a/app/wallet/wallet.go +++ b/app/wallet/wallet.go @@ -7,7 +7,6 @@ import ( ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" @@ -28,10 +27,10 @@ const pgUniqueConstraintViolation = "23505" // Retrieve gets user by internal-apis auth token. If the user does not have a wallet yet, they // are assigned an SDK and a wallet is created for them on that SDK. -func GetUserWithWallet(rt *sdkrouter.Router, token, metaRemoteIP string) (*models.User, error) { +func GetUserWithWallet(rt *sdkrouter.Router, internalAPIHost, token, metaRemoteIP string) (*models.User, error) { log := logger.LogF(monitor.F{monitor.TokenF: token}) - remoteUser, err := getRemoteUser(config.GetInternalAPIHost(), token, metaRemoteIP) + remoteUser, err := getRemoteUser(internalAPIHost, token, metaRemoteIP) if err != nil { msg := "cannot authenticate user with internal-apis: %v" log.Errorf(msg, err) diff --git a/app/wallet/wallet_test.go b/app/wallet/wallet_test.go index ca96cecc..c863f9f7 100644 --- a/app/wallet/wallet_test.go +++ b/app/wallet/wallet_test.go @@ -17,7 +17,6 @@ import ( "github.com/lbryio/lbrytv/internal/test" "github.com/lbryio/lbrytv/models" - jsonrpc2 "github.com/lbryio/lbry.go/v2/extras/jsonrpc" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -45,7 +44,7 @@ func setupDBTables() { storage.Conn.Truncate([]string{"users"}) } -func setupCleanupDummyUser(rt *sdkrouter.Router) func() { +func dummyAPI(rt *sdkrouter.Router) (string, func()) { reqChan := test.ReqChan() ts := test.MockHTTPServer(reqChan) go func() { @@ -63,11 +62,8 @@ func setupCleanupDummyUser(rt *sdkrouter.Router) func() { } }() - config.Override("InternalAPIHost", ts.URL) - - return func() { + return ts.URL, func() { ts.Close() - config.RestoreOverridden() UnloadWallet(rt.GetServer(dummyUserID).Address, dummyUserID) } } @@ -75,10 +71,11 @@ func setupCleanupDummyUser(rt *sdkrouter.Router) func() { func TestWalletServiceRetrieveNewUser(t *testing.T) { rt := sdkrouter.New(config.GetLbrynetServers()) setupDBTables() - defer setupCleanupDummyUser(rt)() + url, cleanup := dummyAPI(rt) + defer cleanup() wid := sdkrouter.WalletID(dummyUserID) - u, err := GetUserWithWallet(rt, "abc", "") + u, err := GetUserWithWallet(rt, url, "abc", "") require.NoError(t, err, errors.Unwrap(err)) require.NotNil(t, u) require.Equal(t, wid, u.WalletID) @@ -87,7 +84,7 @@ func TestWalletServiceRetrieveNewUser(t *testing.T) { require.NoError(t, err) assert.EqualValues(t, 1, count) - u, err = GetUserWithWallet(rt, "abc", "") + u, err = GetUserWithWallet(rt, url, "abc", "") require.NoError(t, err, errors.Unwrap(err)) require.Equal(t, wid, u.WalletID) } @@ -103,11 +100,8 @@ func TestWalletServiceRetrieveNonexistentUser(t *testing.T) { "data": null }` - config.Override("InternalAPIHost", ts.URL) - defer config.RestoreOverridden() - rt := sdkrouter.New(config.GetLbrynetServers()) - u, err := GetUserWithWallet(rt, "non-existent-token", "") + u, err := GetUserWithWallet(rt, ts.URL, "non-existent-token", "") require.Error(t, err) require.Nil(t, u) assert.Equal(t, "cannot authenticate user with internal-apis: could not authenticate user", err.Error()) @@ -116,13 +110,14 @@ func TestWalletServiceRetrieveNonexistentUser(t *testing.T) { func TestWalletServiceRetrieveExistingUser(t *testing.T) { rt := sdkrouter.New(config.GetLbrynetServers()) setupDBTables() - defer setupCleanupDummyUser(rt)() + url, cleanup := dummyAPI(rt) + defer cleanup() - u, err := GetUserWithWallet(rt, "abc", "") + u, err := GetUserWithWallet(rt, url, "abc", "") require.NoError(t, err) require.NotNil(t, u) - u, err = GetUserWithWallet(rt, "abc", "") + u, err = GetUserWithWallet(rt, url, "abc", "") require.NoError(t, err) assert.EqualValues(t, dummyUserID, u.ID) @@ -152,15 +147,12 @@ func TestWalletServiceRetrieveExistingUserMissingWalletID(t *testing.T) { }`, userID) }() - config.Override("InternalAPIHost", ts.URL) - defer config.RestoreOverridden() - rt := sdkrouter.New(config.GetLbrynetServers()) u, err := createDBUser(userID) require.NoError(t, err) require.NotNil(t, u) - u, err = GetUserWithWallet(rt, "abc", "") + u, err = GetUserWithWallet(rt, ts.URL, "abc", "") require.NoError(t, err) assert.NotEqual(t, "", u.WalletID) } @@ -179,11 +171,8 @@ func TestWalletServiceRetrieveNoVerifiedEmail(t *testing.T) { } }` - config.Override("InternalAPIHost", ts.URL) - defer config.RestoreOverridden() - rt := sdkrouter.New(config.GetLbrynetServers()) - u, err := GetUserWithWallet(rt, "abc", "") + u, err := GetUserWithWallet(rt, ts.URL, "abc", "") assert.NoError(t, err) assert.Nil(t, u) } @@ -207,9 +196,6 @@ func BenchmarkWalletCommands(b *testing.B) { }`, req.R.PostFormValue("auth_token")) }() - config.Override("InternalAPIHost", ts.URL) - defer config.RestoreOverridden() - walletsNum := 60 users := make([]*models.User, walletsNum) rt := sdkrouter.New(config.GetLbrynetServers()) @@ -223,7 +209,7 @@ func BenchmarkWalletCommands(b *testing.B) { for i := 0; i < walletsNum; i++ { uid := int(rand.Int31()) - u, err := GetUserWithWallet(rt, fmt.Sprintf("%v", uid), "") + u, err := GetUserWithWallet(rt, ts.URL, fmt.Sprintf("%d", uid), "") require.NoError(b, err, errors.Unwrap(err)) require.NotNil(b, u) users[i] = u @@ -265,20 +251,19 @@ func TestCreateWalletLoadWallet(t *testing.T) { rand.Seed(time.Now().UnixNano()) userID := rand.Int() addr := test.RandServerAddress(t) - client := jsonrpc2.NewClient(addr) - wallet, err := createWallet(client, userID) + wallet, err := createWallet(addr, userID) require.NoError(t, err) assert.Equal(t, wallet.ID, sdkrouter.WalletID(userID)) - wallet, err = createWallet(client, userID) + wallet, err = createWallet(addr, userID) require.NotNil(t, err) assert.True(t, errors.Is(err, lbrynet.ErrWalletExists)) err = UnloadWallet(addr, userID) require.NoError(t, err) - wallet, err = loadWallet(client, userID) + wallet, err = loadWallet(addr, userID) require.NoError(t, err) assert.Equal(t, wallet.ID, sdkrouter.WalletID(userID)) } diff --git a/cmd/serve.go b/cmd/serve.go index b42239c8..5edf152a 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -7,7 +7,6 @@ import ( "os" "time" - "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/server" @@ -23,10 +22,7 @@ var rootCmd = &cobra.Command{ sdkRouter := sdkrouter.New(config.GetLbrynetServers()) go sdkRouter.WatchLoad() - s := server.NewServer(server.Options{ - Address: config.GetAddress(), - ProxyService: proxy.NewService(sdkRouter), - }) + s := server.NewServer(config.GetAddress(), sdkRouter) err := s.Start() if err != nil { log.Fatal(err) diff --git a/go.mod b/go.mod index d354ccde..41973692 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/lbryio/lbrytv require ( github.com/aws/aws-sdk-go v1.23.19 // indirect + github.com/davecgh/go-spew v1.1.1 github.com/getsentry/sentry-go v0.4.0 github.com/gobuffalo/packr v1.30.1 // indirect github.com/gobuffalo/packr/v2 v2.7.1 diff --git a/internal/environment/environment.go b/internal/environment/environment.go deleted file mode 100644 index e0fae9b1..00000000 --- a/internal/environment/environment.go +++ /dev/null @@ -1,26 +0,0 @@ -package environment - -import ( - "github.com/lbryio/lbrytv/app/proxy" - "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/monitor" -) - -type Env struct { - *monitor.ModuleLogger - *config.ConfigWrapper - - proxy *proxy.Service -} - -func NewEnvironment(logger *monitor.ModuleLogger, config *config.ConfigWrapper, ps *proxy.Service) *Env { - if logger == nil { - logger = &monitor.ModuleLogger{} - } - - return &Env{ModuleLogger: logger, ConfigWrapper: config, proxy: ps} -} - -func Null() *Env { - return NewEnvironment(nil, nil, nil) -} diff --git a/app/users/helpers.go b/internal/ip/ip.go similarity index 89% rename from app/users/helpers.go rename to internal/ip/ip.go index 8da2fbfa..e80d9d2b 100644 --- a/app/users/helpers.go +++ b/internal/ip/ip.go @@ -1,4 +1,4 @@ -package users +package ip import ( "bytes" @@ -67,28 +67,28 @@ func IsPrivateSubnet(ipAddress net.IP) bool { } // GetIPAddressForRequest returns the real IP address of the request -func GetIPAddressForRequest(r *http.Request) string { +func AddressForRequest(r *http.Request) string { for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} { addresses := strings.Split(r.Header.Get(h), ",") // march from right to left until we get a public address // that will be the address right before our proxy. for i := len(addresses) - 1; i >= 0; i-- { - ip := strings.TrimSpace(addresses[i]) + addr := strings.TrimSpace(addresses[i]) // header can contain spaces too, strip those out. - realIP := net.ParseIP(ip) + realIP := net.ParseIP(addr) if !realIP.IsGlobalUnicast() || IsPrivateSubnet(realIP) { // bad address, go to next continue } - return ip + return addr } } ipParts := strings.Split(r.RemoteAddr, ":") - ip := strings.Join(ipParts[:len(ipParts)-1], ":") + addr := strings.Join(ipParts[:len(ipParts)-1], ":") - if ip == "[::1]" { + if addr == "[::1]" { return "127.0.0.1" } - return ip + return addr } diff --git a/internal/metrics/routes_test.go b/internal/metrics/routes_test.go index 902fdc2d..62a75450 100644 --- a/internal/metrics/routes_test.go +++ b/internal/metrics/routes_test.go @@ -5,10 +5,8 @@ import ( "net/http/httptest" "testing" - "github.com/lbryio/lbrytv/api" - "github.com/lbryio/lbrytv/app/proxy" - "github.com/gorilla/mux" + "github.com/lbryio/lbrytv/api" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -57,7 +55,7 @@ func testMetricUIEvent(t *testing.T, method, name, value string) *httptest.Respo req.URL.RawQuery = q.Encode() r := mux.NewRouter() - api.InstallRoutes(proxy.NewService(nil), r) + api.InstallRoutes(r, nil) rr := httptest.NewRecorder() r.ServeHTTP(rr, req) return rr diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index f94fc70f..09f67795 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -107,13 +107,13 @@ func NewProxyLogger() *ProxyLogger { return &l } -func (l *ProxyLogger) LogSuccessfulQuery(method, endpoint, wallet string, time float64, params interface{}, response interface{}) { +func (l *ProxyLogger) LogSuccessfulQuery(method, endpoint string, userID int, time float64, params interface{}, response interface{}) { fields := logrus.Fields{ "method": method, "duration": time, "params": params, "endpoint": endpoint, - "wallet": wallet, + "user_id": userID, } if config.ShouldLogResponses() { fields["response"] = response @@ -122,13 +122,13 @@ func (l *ProxyLogger) LogSuccessfulQuery(method, endpoint, wallet string, time f } -func (l *ProxyLogger) LogFailedQuery(method, endpoint, wallet string, time float64, params interface{}, errorResponse interface{}) { +func (l *ProxyLogger) LogFailedQuery(method, endpoint string, userID int, time float64, params interface{}, errorResponse interface{}) { l.entry.WithFields(logrus.Fields{ "method": method, "duration": time, "params": params, "endpoint": endpoint, - "wallet": wallet, + "user_id": userID, "response": errorResponse, }).Error("error from the target endpoint") } diff --git a/internal/monitor/monitor_test.go b/internal/monitor/monitor_test.go index 9b28c9cf..314235cf 100644 --- a/internal/monitor/monitor_test.go +++ b/internal/monitor/monitor_test.go @@ -72,13 +72,13 @@ func TestLogSuccessfulQueryWithResponse(t *testing.T) { }, } - l.LogSuccessfulQuery("resolve", "sdk1.local", "xx.123.wallet", 0.025, map[string]string{"urls": "one"}, response) + l.LogSuccessfulQuery("resolve", "sdk1.local", 123, 0.025, map[string]string{"urls": "one"}, response) require.Equal(t, 1, len(hook.Entries)) require.Equal(t, log.InfoLevel, hook.LastEntry().Level) require.Equal(t, "resolve", hook.LastEntry().Data["method"]) require.Equal(t, "sdk1.local", hook.LastEntry().Data["endpoint"]) - require.Equal(t, "xx.123.wallet", hook.LastEntry().Data["wallet"]) + require.Equal(t, 123, hook.LastEntry().Data["user_id"]) require.Equal(t, map[string]string{"urls": "one"}, hook.LastEntry().Data["params"]) require.Equal(t, 0.025, hook.LastEntry().Data["duration"]) require.Equal(t, response, hook.LastEntry().Data["response"]) @@ -98,13 +98,13 @@ func TestLogFailedQuery(t *testing.T) { Message: "Method Not Found", } queryParams := map[string]string{"param1": "value1"} - l.LogFailedQuery("unknown_method", "sdk2.local", "xx.566.wallet", 2.34, queryParams, response) + l.LogFailedQuery("unknown_method", "sdk2.local", 566, 2.34, queryParams, response) require.Equal(t, 1, len(hook.Entries)) require.Equal(t, log.ErrorLevel, hook.LastEntry().Level) require.Equal(t, "unknown_method", hook.LastEntry().Data["method"]) require.Equal(t, "sdk2.local", hook.LastEntry().Data["endpoint"]) - require.Equal(t, "xx.566.wallet", hook.LastEntry().Data["wallet"]) + require.Equal(t, 566, hook.LastEntry().Data["user_id"]) require.Equal(t, queryParams, hook.LastEntry().Data["params"]) require.Equal(t, response, hook.LastEntry().Data["response"]) require.Equal(t, 2.34, hook.LastEntry().Data["duration"]) diff --git a/internal/monitor/sentry.go b/internal/monitor/sentry.go index df1ad30e..f2c85296 100644 --- a/internal/monitor/sentry.go +++ b/internal/monitor/sentry.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/lbryio/lbrytv/config" + "github.com/lbryio/lbrytv/internal/responses" "github.com/getsentry/sentry-go" ) var IgnoredExceptions = []string{ - "account identifier required", + responses.AuthRequiredErrorMessage, } func configureSentry(release, env string) { diff --git a/internal/responses/responses.go b/internal/responses/responses.go index 218f7ada..5e16eaf0 100644 --- a/internal/responses/responses.go +++ b/internal/responses/responses.go @@ -4,6 +4,8 @@ import ( "net/http" ) +const AuthRequiredErrorMessage = "authentication required" + // AddJSONContentType prepares HTTP response writer for JSON content-type. func AddJSONContentType(w http.ResponseWriter) { w.Header().Add("content-type", "application/json; charset=utf-8") diff --git a/internal/status/status.go b/internal/status/status.go index 0621464c..926d20c8 100644 --- a/internal/status/status.go +++ b/internal/status/status.go @@ -30,7 +30,6 @@ const ( StatusNotReady = "not_ready" StatusOffline = "offline" StatusFailing = "failing" - SDKRouterContextKey = "sdkrouter" statusCacheValidity = 120 * time.Second ) @@ -60,7 +59,7 @@ func GetStatus(w http.ResponseWriter, req *http.Request) { } failureDetected := false - sdks := req.Context().Value(SDKRouterContextKey).(*sdkrouter.Router).GetAll() + sdks := sdkrouter.FromRequest(req).GetAll() for _, s := range sdks { srv := ServerItem{Address: s.Address, Status: StatusOK} services["lbrynet"] = append(services["lbrynet"], srv) diff --git a/server/server.go b/server/server.go index 8198fc7b..a571573c 100644 --- a/server/server.go +++ b/server/server.go @@ -9,7 +9,7 @@ import ( "time" "github.com/lbryio/lbrytv/api" - "github.com/lbryio/lbrytv/app/proxy" + "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/internal/monitor" "github.com/gorilla/mux" @@ -19,63 +19,46 @@ var logger = monitor.NewModuleLogger("server") // Server holds entities that can be used to control the web server type Server struct { - defaultHeaders map[string]string - proxyService *proxy.Service - stopChan chan os.Signal - stopWait time.Duration - address string - router *mux.Router - listener *http.Server -} - -// Options holds basic web server settings. -type Options struct { - Address string - ProxyService *proxy.Service - StopWaitSeconds int + address string + listener *http.Server + stopChan chan os.Signal + stopWait time.Duration } // NewServer returns a server initialized with settings from supplied options. -func NewServer(opts Options) *Server { - s := &Server{ - proxyService: opts.ProxyService, - address: opts.Address, - stopWait: 15 * time.Second, - stopChan: make(chan os.Signal), - defaultHeaders: map[string]string{ - "Server": "api.lbry.tv", - "Access-Control-Allow-Origin": "*", - }, - } - if opts.StopWaitSeconds != 0 { - s.stopWait = time.Duration(opts.StopWaitSeconds) * time.Second - } - +func NewServer(address string, sdkRouter *sdkrouter.Router) *Server { r := mux.NewRouter() - api.InstallRoutes(s.proxyService, r) + api.InstallRoutes(r, sdkRouter) r.Use(monitor.ErrorLoggingMiddleware) - r.Use(s.defaultHeadersMiddleware) - s.router = r - - s.listener = &http.Server{ - Addr: s.address, - Handler: s.router, - // Can't have WriteTimeout set for streaming endpoints - WriteTimeout: 0, - IdleTimeout: 0, - ReadHeaderTimeout: 10 * time.Second, + r.Use(defaultHeadersMiddleware(map[string]string{ + "Server": "api.lbry.tv", + "Access-Control-Allow-Origin": "*", + })) + + return &Server{ + address: address, + stopWait: 15 * time.Second, + stopChan: make(chan os.Signal), + listener: &http.Server{ + Addr: address, + Handler: r, + // Can't have WriteTimeout set for streaming endpoints + WriteTimeout: 0, + IdleTimeout: 0, + ReadHeaderTimeout: 10 * time.Second, + }, } - - return s } -func (s *Server) defaultHeadersMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for k, v := range s.defaultHeaders { - w.Header().Set(k, v) - } - next.ServeHTTP(w, r) - }) +func defaultHeadersMiddleware(defaultHeaders map[string]string) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range defaultHeaders { + w.Header().Set(k, v) + } + next.ServeHTTP(w, r) + }) + } } // Start starts a http server and returns immediately. @@ -87,7 +70,7 @@ func (s *Server) Start() error { } } }() - logger.Log().Infof("http server listening on %v", s.address) + logger.Log().Infof("http server listening on %v", s.listener.Addr) return nil } diff --git a/server/server_test.go b/server/server_test.go index 7c8061bc..5ad37677 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" @@ -15,10 +14,7 @@ import ( ) func TestStartAndServeUntilShutdown(t *testing.T) { - server := NewServer(Options{ - Address: "localhost:40080", - ProxyService: proxy.NewService(sdkrouter.New(config.GetLbrynetServers())), - }) + server := NewServer("localhost:40080", sdkrouter.New(config.GetLbrynetServers())) server.Start() go server.ServeUntilShutdown() @@ -47,10 +43,7 @@ func TestHeaders(t *testing.T) { response *http.Response ) - server := NewServer(Options{ - Address: "localhost:40080", - ProxyService: proxy.NewService(sdkrouter.New(config.GetLbrynetServers())), - }) + server := NewServer("localhost:40080", sdkrouter.New(config.GetLbrynetServers())) server.Start() go server.ServeUntilShutdown() From 4836971c85c061733b25c862438c3c0c9ef66ed3 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Tue, 14 Apr 2020 17:26:10 -0400 Subject: [PATCH 07/18] clarify what it means to send empty string to chan --- app/proxy/caller_test.go | 6 +++--- app/proxy/client_test.go | 4 ++-- internal/test/test.go | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/app/proxy/caller_test.go b/app/proxy/caller_test.go index fb2e5ad1..c25efc59 100644 --- a/app/proxy/caller_test.go +++ b/app/proxy/caller_test.go @@ -119,7 +119,7 @@ func TestCallerCallRelaxedMethods(t *testing.T) { if m == MethodStatus { return } - srv.NextResponse <- "" + srv.RespondWithNothing() caller.Call(jsonrpc.NewRequest(m)) receivedRequest := <-reqChan expectedRequest := test.ReqToStr(t, jsonrpc.RPCRequest{ @@ -153,7 +153,7 @@ func TestCallerCallAttachesWalletID(t *testing.T) { reqChan := test.ReqChan() srv := test.MockHTTPServer(reqChan) defer srv.Close() - srv.NextResponse <- "" + srv.RespondWithNothing() caller := NewCaller(srv.URL, dummyUserID) caller.Call(jsonrpc.NewRequest("channel_create", map[string]interface{}{"name": "test", "bid": "0.1"})) receivedRequest := <-reqChan @@ -187,7 +187,7 @@ func TestCallerSetPreprocessor(t *testing.T) { } } - srv.NextResponse <- "" + srv.RespondWithNothing() c.Call(jsonrpc.NewRequest(relaxedMethods[0])) req := <-reqChan diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index 3de5a722..b68a71d6 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -55,7 +55,7 @@ func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { Message: "Couldn't find wallet: //", }, }) - srv.NextResponse <- "" // for the wallet_add call + srv.RespondWithNothing() // for the wallet_add call srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ JSONRPC: "2.0", Error: &jsonrpc.RPCError{ @@ -88,7 +88,7 @@ func TestClientCallDoesNotReloadWalletIfAlreadyLoaded(t *testing.T) { Message: "Couldn't find wallet: //", }, }) - srv.NextResponse <- "" // for the wallet_add call + srv.RespondWithNothing() // for the wallet_add call srv.NextResponse <- test.ResToStr(t, jsonrpc.RPCResponse{ JSONRPC: "2.0", Error: &jsonrpc.RPCError{ diff --git a/internal/test/test.go b/internal/test/test.go index 88b2276b..75169163 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -17,6 +17,8 @@ type MockServer struct { NextResponse chan<- string } +func (m *MockServer) RespondWithNothing() { m.NextResponse <- "" } + type Request struct { R *http.Request W http.ResponseWriter From cbd60687e58e37b263590d635be5db56d164d228 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Wed, 15 Apr 2020 17:12:34 -0400 Subject: [PATCH 08/18] fixed a bunch more things with auth. finally dropped sdkrouter.GetServer() --- api/routes.go | 53 +++++++++++-- app/auth/auth.go | 82 +++++++++++++------- app/auth/auth_test.go | 52 +++++++------ app/proxy/accounts_test.go | 17 +---- app/proxy/caller_test.go | 28 +------ app/proxy/client_test.go | 4 +- app/proxy/handlers.go | 13 +++- app/publish/handler_test.go | 120 ++++++++++++++++++------------ app/publish/publish.go | 59 ++++++--------- app/publish/publish_test.go | 34 ++------- app/publish/testing.go | 45 +++++------ app/sdkrouter/concurrency_test.go | 6 +- app/sdkrouter/sdkrouter.go | 83 ++------------------- app/sdkrouter/sdkrouter_test.go | 22 +----- app/wallet/wallet.go | 57 +++++++------- app/wallet/wallet_test.go | 67 +++++++++++------ go.mod | 1 + go.sum | 2 + internal/responses/responses.go | 2 + internal/test/test.go | 70 +++++++++++++++++ 20 files changed, 436 insertions(+), 381 deletions(-) diff --git a/api/routes.go b/api/routes.go index 338c9c9d..7b5acd65 100644 --- a/api/routes.go +++ b/api/routes.go @@ -1,7 +1,10 @@ package api import ( + "encoding/json" + "fmt" "net/http" + "runtime/debug" "strings" "time" @@ -11,20 +14,22 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/metrics" + "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/status" + "github.com/ybbus/jsonrpc" "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus/promhttp" ) +var logger = monitor.NewModuleLogger("api") + // InstallRoutes sets up global API handlers func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { - upHandler := &publish.Handler{ - Publisher: &publish.LbrynetPublisher{Router: sdkRouter}, - UploadPath: config.GetPublishSourceDir(), - InternalAPIHost: config.GetInternalAPIHost(), - } + upHandler := &publish.Handler{UploadPath: config.GetPublishSourceDir()} + r.Use(recoveryHandler) r.Use(methodTimer) r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { @@ -33,8 +38,8 @@ func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { v1Router := r.PathPrefix("/api/v1").Subrouter() v1Router.Use(sdkrouter.Middleware(sdkRouter)) - retriever := auth.AllInOneRetrieverThatNeedsRefactoring(sdkRouter, config.GetInternalAPIHost()) - v1Router.Use(auth.Middleware(retriever)) + authProvider := auth.WalletAndInternalAPIProvider(sdkRouter, config.GetInternalAPIHost()) + v1Router.Use(auth.Middleware(authProvider)) v1Router.HandleFunc("/proxy", proxy.HandleCORS).Methods(http.MethodOptions) v1Router.HandleFunc("/proxy", upHandler.Handle).MatcherFunc(upHandler.CanHandle) v1Router.HandleFunc("/proxy", proxy.Handle) @@ -59,3 +64,37 @@ func methodTimer(next http.Handler) http.Handler { metrics.LbrytvCallDurations.WithLabelValues(path).Observe(time.Since(start).Seconds()) }) } + +func recoveryHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recovered, stack := func() (err error, stack []byte) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + if !config.IsProduction() { + stack = debug.Stack() + } + } + }() + next.ServeHTTP(w, r) + return err, nil + }() + if recovered != nil { + logger.Log().Errorf("PANIC %v, trace %s", recovered, stack) + responses.AddJSONContentType(w) + rsp, _ := json.Marshal(jsonrpc.RPCResponse{ + JSONRPC: "2.0", + Error: &jsonrpc.RPCError{ + Code: -1, + Message: recovered.Error(), + Data: stack, + }, + }) + w.Write(rsp) + } + }) +} diff --git a/app/auth/auth.go b/app/auth/auth.go index 955deef2..60d9486f 100644 --- a/app/auth/auth.go +++ b/app/auth/auth.go @@ -17,49 +17,81 @@ var logger = monitor.NewModuleLogger("auth") const ContextKey = "user" -type Result struct { - User *models.User - Err error -} - -func (r *Result) AuthAttempted() bool { return r.User != nil || r.Err != nil } -func (r *Result) AuthFailed() bool { return r.Err != nil } -func (r *Result) Authenticated() bool { return r.User != nil } - -func FromRequest(r *http.Request) *Result { +func FromRequest(r *http.Request) Result { v := r.Context().Value(ContextKey) if v == nil { panic("Auth middleware was not applied") } - return v.(*Result) + return v.(Result) } -// Retriever gets a user by hitting internal-api with the provided auth token +// Provider gets a user by hitting internal-api with the provided auth token // and matching the response to a local user. // NOTE: The retrieved user must come with a wallet that's created and ready to use. -type Retriever func(token, metaRemoteIP string) (*models.User, error) +type Provider func(token, metaRemoteIP string) Result -func AllInOneRetrieverThatNeedsRefactoring(rt *sdkrouter.Router, internalAPIHost string) Retriever { - return func(token, metaRemoteIP string) (user *models.User, err error) { - return wallet.GetUserWithWallet(rt, internalAPIHost, token, metaRemoteIP) +func WalletAndInternalAPIProvider(rt *sdkrouter.Router, internalAPIHost string) Provider { + return func(token, metaRemoteIP string) Result { + user, err := wallet.GetUserWithWallet(rt, internalAPIHost, token, metaRemoteIP) + res := NewResult(user, err) + if err == nil && user != nil && !user.LbrynetServerID.IsZero() && user.R != nil && user.R.LbrynetServer != nil { + res.SDKAddress = user.R.LbrynetServer.Address + } + return res } } -func Middleware(retriever Retriever) mux.MiddlewareFunc { +func Middleware(provider Provider) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ar := &Result{} + var res Result if token, ok := r.Header[wallet.TokenHeader]; ok { addr := ip.AddressForRequest(r) - user, err := retriever(token[0], addr) - if err != nil { - logger.LogF(monitor.F{"ip": addr}).Debugf("failed to authenticate user") - ar.Err = err - } else { - ar.User = user + res = provider(token[0], addr) + if res.err != nil { + logger.LogF(monitor.F{"ip": addr}).Debugf("error authenticating user") } + res.authAttempted = true } - next.ServeHTTP(w, r.Clone(context.WithValue(r.Context(), ContextKey, ar))) + next.ServeHTTP(w, r.Clone(context.WithValue(r.Context(), ContextKey, res))) }) } } + +// wish i could make this non-exported, but then you can't create new providers outside the package +// don't make this struct directly. instead use NewResult +type Result struct { + SDKAddress string + + user *models.User + err error + authAttempted bool +} + +func NewResult(user *models.User, err error) Result { + if err != nil { + user = nil // err and user cannot be non-nil at the same time + } + return Result{user: user, err: err} +} + +func (r Result) AuthAttempted() bool { + return r.authAttempted +} + +func (r Result) Authenticated() bool { + return r.user != nil +} + +func (r Result) User() *models.User { + if !r.authAttempted { + return nil + } + return r.user +} +func (r Result) Err() error { + if !r.authAttempted { + return nil + } + return r.err +} diff --git a/app/auth/auth_test.go b/app/auth/auth_test.go index c8da91e6..42529a74 100644 --- a/app/auth/auth_test.go +++ b/app/auth/auth_test.go @@ -24,16 +24,16 @@ func TestMiddleware(t *testing.T) { r.Header.Set("X-Forwarded-For", "8.8.8.8") var receivedRemoteIP string - retriever := func(token, ip string) (*models.User, error) { + provider := func(token, ip string) Result { receivedRemoteIP = ip if token == "secret-token" { - return &models.User{ID: 16595}, nil + return NewResult(&models.User{ID: 16595}, nil) } - return nil, errors.New("error") + return NewResult(nil, errors.New("error")) } rr := httptest.NewRecorder() - Middleware(retriever)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) + Middleware(provider)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) response := rr.Result() body, err := ioutil.ReadAll(response.Body) @@ -48,13 +48,13 @@ func TestMiddlewareAuthFailure(t *testing.T) { r.Header.Set(wallet.TokenHeader, "wrong-token") rr := httptest.NewRecorder() - retriever := func(token, ip string) (*models.User, error) { + provider := func(token, ip string) Result { if token == "good-token" { - return &models.User{ID: 1}, nil + return NewResult(&models.User{ID: 1}, nil) } - return nil, errors.New("incorrect token") + return NewResult(nil, errors.New("incorrect token")) } - Middleware(retriever)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) + Middleware(provider)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) response := rr.Result() body, err := ioutil.ReadAll(response.Body) @@ -68,13 +68,13 @@ func TestMiddlewareNoAuth(t *testing.T) { require.NoError(t, err) rr := httptest.NewRecorder() - retriever := func(token, ip string) (*models.User, error) { + provider := func(token, ip string) Result { if token == "good-token" { - return &models.User{ID: 1}, nil + return NewResult(&models.User{ID: 1}, nil) } - return nil, errors.New("incorrect token") + return NewResult(nil, errors.New("incorrect token")) } - Middleware(retriever)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) + Middleware(provider)(http.HandlerFunc(authChecker)).ServeHTTP(rr, r) response := rr.Result() body, err := ioutil.ReadAll(response.Body) @@ -84,7 +84,7 @@ func TestMiddlewareNoAuth(t *testing.T) { } func TestFromRequestSuccess(t *testing.T) { - expected := &Result{Err: errors.New("a test")} + expected := NewResult(nil, errors.New("a test")) ctx := context.WithValue(context.Background(), ContextKey, expected) r, err := http.NewRequestWithContext(ctx, http.MethodPost, "", &bytes.Buffer{}) @@ -92,8 +92,9 @@ func TestFromRequestSuccess(t *testing.T) { results := FromRequest(r) assert.NotNil(t, results) - assert.Equal(t, expected.User, results.User) - assert.Equal(t, expected.Err.Error(), results.Err.Error()) + assert.Equal(t, expected.user, results.user) + assert.Equal(t, expected.err.Error(), results.err.Error()) + assert.False(t, results.AuthAttempted()) } func TestFromRequestFail(t *testing.T) { @@ -104,14 +105,23 @@ func TestFromRequestFail(t *testing.T) { func authChecker(w http.ResponseWriter, r *http.Request) { result := FromRequest(r) - if result.Authenticated() { + if result.user != nil && result.err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("this should never happen")) + return + } + + if !result.AuthAttempted() { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("no auth info")) + } else if result.Authenticated() { w.WriteHeader(http.StatusAccepted) - w.Write([]byte(fmt.Sprintf("%d", result.User.ID))) - } else if result.AuthFailed() { + w.Write([]byte(fmt.Sprintf("%d", result.user.ID))) + } else if result.Err() != nil { w.WriteHeader(http.StatusForbidden) - w.Write([]byte(result.Err.Error())) + w.Write([]byte(result.Err().Error())) } else { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("no auth info")) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("no user and no error. what happened?")) } } diff --git a/app/proxy/accounts_test.go b/app/proxy/accounts_test.go index fa1efc53..223f2580 100644 --- a/app/proxy/accounts_test.go +++ b/app/proxy/accounts_test.go @@ -9,10 +9,8 @@ import ( "github.com/lbryio/lbrytv/app/auth" "github.com/lbryio/lbrytv/app/sdkrouter" - "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/test" - "github.com/lbryio/lbrytv/models" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" @@ -40,13 +38,8 @@ func TestWithWrongAuthToken(t *testing.T) { r.Header.Add("X-Lbry-Auth-Token", "xXxXxXx") rr := httptest.NewRecorder() - rt := sdkrouter.New(config.GetLbrynetServers()) - retriever := func(token, ip string) (*models.User, error) { - return wallet.GetUserWithWallet(rt, ts.URL, token, "") - } - - handler := sdkrouter.Middleware(rt)(auth.Middleware(retriever)(http.HandlerFunc(Handle))) + handler := sdkrouter.Middleware(rt)(auth.Middleware(auth.WalletAndInternalAPIProvider(rt, ts.URL))(http.HandlerFunc(Handle))) handler.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) @@ -65,7 +58,6 @@ func TestWithoutToken(t *testing.T) { require.NoError(t, err) rr := httptest.NewRecorder() - rt := sdkrouter.New(config.GetLbrynetServers()) handler := sdkrouter.Middleware(rt)(http.HandlerFunc(Handle)) handler.ServeHTTP(rr, r) @@ -92,12 +84,11 @@ func TestAccountSpecificWithoutToken(t *testing.T) { require.NoError(t, err) rr := httptest.NewRecorder() - rt := sdkrouter.New(config.GetLbrynetServers()) - retriever := func(token, ip string) (*models.User, error) { - return nil, nil + provider := func(token, ip string) auth.Result { + return auth.NewResult(nil, nil) } - handler := sdkrouter.Middleware(rt)(auth.Middleware(retriever)(http.HandlerFunc(Handle))) + handler := sdkrouter.Middleware(rt)(auth.Middleware(provider)(http.HandlerFunc(Handle))) handler.ServeHTTP(rr, r) require.Equal(t, http.StatusOK, rr.Code) diff --git a/app/proxy/caller_test.go b/app/proxy/caller_test.go index c25efc59..ce79ef3b 100644 --- a/app/proxy/caller_test.go +++ b/app/proxy/caller_test.go @@ -10,7 +10,6 @@ import ( "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/wallet" - "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/test" @@ -31,22 +30,6 @@ func parseRawResponse(t *testing.T, rawCallResponse []byte, v interface{}) { require.NoError(t, err) } -func TestNewCaller(t *testing.T) { - servers := map[string]string{ - "first": "http://lbrynet1", - "second": "http://lbrynet2", - } - rt := sdkrouter.New(servers) - sList := rt.GetAll() - rand.Seed(time.Now().UnixNano()) - for i := 1; i <= 100; i++ { - id := rand.Intn(10^6-10^3) + 10 ^ 3 - wc := NewCaller(rt.GetServer(id).Address, id) - lastDigit := id % 10 - assert.Equal(t, sList[lastDigit%len(sList)].Address, wc.endpoint) - } -} - func TestCallerCallRaw(t *testing.T) { c := NewCaller(test.RandServerAddress(t), 0) for _, rawQ := range []string{``, ` `, `[]`, `[{}]`, `[""]`, `""`, `" "`} { @@ -64,13 +47,11 @@ func TestCallerCallRaw(t *testing.T) { } func TestCallerCallResolve(t *testing.T) { - rt := sdkrouter.New(config.GetLbrynetServers()) - resolvedURL := "what#6769855a9aa43b67086f9ff3c1a5bacb5698a27a" resolvedClaimID := "6769855a9aa43b67086f9ff3c1a5bacb5698a27a" request := jsonrpc.NewRequest("resolve", map[string]interface{}{"urls": resolvedURL}) - rawCallResponse := NewCaller(rt.RandomServer().Address, 0).Call(request) + rawCallResponse := NewCaller(test.RandServerAddress(t), 0).Call(request) var errorResponse jsonrpc.RPCResponse err := json.Unmarshal(rawCallResponse, &errorResponse) @@ -85,15 +66,14 @@ func TestCallerCallResolve(t *testing.T) { func TestCallerCallWalletBalance(t *testing.T) { rand.Seed(time.Now().UnixNano()) dummyUserID := rand.Intn(10^6-10^3) + 10 ^ 3 - rt := sdkrouter.New(config.GetLbrynetServers()) request := jsonrpc.NewRequest("wallet_balance") - result := NewCaller(rt.RandomServer().Address, 0).Call(request) + result := NewCaller(test.RandServerAddress(t), 0).Call(request) assert.Contains(t, string(result), `"message": "authentication required"`) addr := test.RandServerAddress(t) - walletID, err := wallet.Create(addr, dummyUserID) + err := wallet.Create(addr, dummyUserID) require.NoError(t, err) hook := logrusTest.NewLocal(Logger.Logger()) @@ -104,7 +84,7 @@ func TestCallerCallWalletBalance(t *testing.T) { } parseRawResponse(t, result, &accountBalanceResponse) assert.EqualValues(t, "0.0", accountBalanceResponse.Available) - assert.Equal(t, map[string]interface{}{"wallet_id": walletID}, hook.LastEntry().Data["params"]) + assert.Equal(t, map[string]interface{}{"wallet_id": sdkrouter.WalletID(dummyUserID)}, hook.LastEntry().Data["params"]) assert.Equal(t, "wallet_balance", hook.LastEntry().Data["method"]) } diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index b68a71d6..43178198 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -18,14 +18,14 @@ func TestClientCallDoesReloadWallet(t *testing.T) { dummyUserID := rand.Intn(100) addr := test.RandServerAddress(t) - walletID, err := wallet.Create(addr, dummyUserID) + err := wallet.Create(addr, dummyUserID) require.NoError(t, err) err = wallet.UnloadWallet(addr, dummyUserID) require.NoError(t, err) q, err := NewQuery(jsonrpc.NewRequest("wallet_balance")) require.NoError(t, err) - q.WalletID = walletID + q.WalletID = sdkrouter.WalletID(dummyUserID) c := NewCaller(addr, dummyUserID) r, err := c.callQueryWithRetry(q) diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index 58e2c076..47b27485 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -43,6 +43,7 @@ func Handle(w http.ResponseWriter, r *http.Request) { } var userID int + var sdkAddress string if MethodNeedsAuth(req.Method) { authResult := auth.FromRequest(r) if !authResult.AuthAttempted() { @@ -50,14 +51,20 @@ func Handle(w http.ResponseWriter, r *http.Request) { return } if !authResult.Authenticated() { - w.Write(NewForbiddenError(authResult.Err).JSON()) + w.Write(NewForbiddenError(authResult.Err()).JSON()) return } - userID = authResult.User.ID + userID = authResult.User().ID + sdkAddress = authResult.SDKAddress } rt := sdkrouter.FromRequest(r) - c := NewCaller(rt.GetServer(userID).Address, userID) + + if sdkAddress == "" { + sdkAddress = rt.RandomServer().Address + } + + c := NewCaller(sdkAddress, userID) w.Write(c.Call(&req)) } diff --git a/app/publish/handler_test.go b/app/publish/handler_test.go index 06fa98c1..7b1ea39a 100644 --- a/app/publish/handler_test.go +++ b/app/publish/handler_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "io/ioutil" "mime/multipart" @@ -14,8 +15,11 @@ import ( "testing" "github.com/lbryio/lbrytv/app/auth" + "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/wallet" + "github.com/lbryio/lbrytv/internal/test" "github.com/lbryio/lbrytv/models" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" @@ -24,16 +28,8 @@ import ( type DummyPublisher struct { called bool filePath string - userID int - rawQuery []byte -} - -func (p *DummyPublisher) Publish(filePath string, userID int, rawQuery []byte) []byte { - p.called = true - p.filePath = filePath - p.userID = userID - p.rawQuery = rawQuery - return []byte(expectedStreamCreateResponse) + walletID string + rawQuery string } func TestUploadHandler(t *testing.T) { @@ -41,45 +37,58 @@ func TestUploadHandler(t *testing.T) { r.Header.Set(wallet.TokenHeader, "uPldrToken") publisher := &DummyPublisher{} - handler := &Handler{ - Publisher: publisher, - UploadPath: os.TempDir(), - } - retriever := func(token, ip string) (*models.User, error) { + reqChan := test.ReqChan() + ts := test.MockHTTPServer(reqChan) + go func() { + req := <-reqChan + publisher.called = true + rpcReq := test.StrToReq(t, req.Body) + params, ok := rpcReq.Params.(map[string]interface{}) + require.True(t, ok) + publisher.filePath = params["file_path"].(string) + publisher.walletID = params["wallet_id"].(string) + publisher.rawQuery = req.Body + ts.NextResponse <- expectedStreamCreateResponse + }() + + handler := &Handler{UploadPath: os.TempDir()} + + provider := func(token, ip string) auth.Result { if token == "uPldrToken" { - return &models.User{ID: 20404}, nil + res := auth.NewResult(&models.User{ID: 20404}, nil) + res.SDKAddress = ts.URL + return res } - return nil, errors.New("error") + return auth.NewResult(nil, errors.New("error")) } rr := httptest.NewRecorder() - auth.Middleware(retriever)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, r) + auth.Middleware(provider)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, r) response := rr.Result() - respBody, _ := ioutil.ReadAll(response.Body) + respBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) assert.Equal(t, http.StatusOK, response.StatusCode) - assert.Equal(t, expectedStreamCreateResponse, string(respBody)) + test.AssertJsonEqual(t, expectedStreamCreateResponse, respBody) require.True(t, publisher.called) expectedPath := path.Join(os.TempDir(), "20404", ".*_lbry_auto_test_file") assert.Regexp(t, expectedPath, publisher.filePath) - assert.Equal(t, 20404, publisher.userID) - assert.Equal(t, expectedStreamCreateRequest, string(publisher.rawQuery)) + assert.Equal(t, sdkrouter.WalletID(20404), publisher.walletID) + expectedReq := fmt.Sprintf(expectedStreamCreateRequest, sdkrouter.WalletID(20404), publisher.filePath) + test.AssertJsonEqual(t, expectedReq, publisher.rawQuery) - _, err := os.Stat(publisher.filePath) + _, err = os.Stat(publisher.filePath) assert.True(t, os.IsNotExist(err)) } -func TestUploadHandlerNoAuthMiddleware(t *testing.T) { - r := CreatePublishRequest(t, []byte("test file")) +func TestHandler_NoAuthMiddleware(t *testing.T) { + r, err := http.NewRequest("POST", "/api/v1/proxy", &bytes.Buffer{}) + require.NoError(t, err) r.Header.Set(wallet.TokenHeader, "uPldrToken") - publisher := &DummyPublisher{} - handler := &Handler{ - Publisher: publisher, - UploadPath: os.TempDir(), - } + handler := &Handler{UploadPath: os.TempDir()} rr := httptest.NewRecorder() assert.Panics(t, func() { @@ -87,24 +96,40 @@ func TestUploadHandlerNoAuthMiddleware(t *testing.T) { }) } -func TestUploadHandlerAuthRequired(t *testing.T) { +func TestHandler_NoSDKAddress(t *testing.T) { r := CreatePublishRequest(t, []byte("test file")) + r.Header.Set(wallet.TokenHeader, "x") + rr := httptest.NewRecorder() - publisher := &DummyPublisher{} - handler := &Handler{ - Publisher: publisher, - UploadPath: os.TempDir(), + handler := &Handler{UploadPath: os.TempDir()} + provider := func(token, ip string) auth.Result { + return auth.NewResult(&models.User{ID: 20404}, nil) } - retriever := func(token, ip string) (*models.User, error) { + auth.Middleware(provider)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, r) + response := rr.Result() + respBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, response.StatusCode) + assert.Contains(t, string(respBody), "user does not have sdk address assigned") +} + +func TestHandler_AuthRequired(t *testing.T) { + r := CreatePublishRequest(t, []byte("test file")) + + publisher := &DummyPublisher{} + handler := &Handler{UploadPath: os.TempDir()} + + provider := func(token, ip string) auth.Result { if token == "uPldrToken" { - return &models.User{ID: 20404}, nil + return auth.NewResult(&models.User{ID: 20404}, nil) } - return nil, errors.New("error") + return auth.NewResult(nil, errors.New("error")) } rr := httptest.NewRecorder() - auth.Middleware(retriever)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, r) + auth.Middleware(provider)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, r) response := rr.Result() assert.Equal(t, http.StatusOK, response.StatusCode) @@ -128,7 +153,7 @@ func TestUploadHandlerSystemError(t *testing.T) { jsonPayload, err := writer.CreateFormField(jsonRPCFieldName) require.NoError(t, err) - jsonPayload.Write([]byte(expectedStreamCreateRequest)) + jsonPayload.Write([]byte(fmt.Sprintf(expectedStreamCreateRequest, sdkrouter.WalletID(20404), "arst"))) // <--- Not calling writer.Close() here to create an unexpected EOF @@ -139,20 +164,19 @@ func TestUploadHandlerSystemError(t *testing.T) { req.Header.Set("Content-Type", writer.FormDataContentType()) publisher := &DummyPublisher{} - handler := &Handler{ - Publisher: publisher, - UploadPath: os.TempDir(), - } + handler := &Handler{UploadPath: os.TempDir()} - retriever := func(token, ip string) (*models.User, error) { + provider := func(token, ip string) auth.Result { if token == "uPldrToken" { - return &models.User{ID: 20404}, nil + res := auth.NewResult(&models.User{ID: 20404}, nil) + res.SDKAddress = "whatever" + return res } - return nil, errors.New("error") + return auth.NewResult(nil, errors.New("error")) } rr := httptest.NewRecorder() - auth.Middleware(retriever)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, req) + auth.Middleware(provider)(http.HandlerFunc(handler.Handle)).ServeHTTP(rr, req) response := rr.Result() require.False(t, publisher.called) diff --git a/app/publish/publish.go b/app/publish/publish.go index 20114a22..339e37a8 100644 --- a/app/publish/publish.go +++ b/app/publish/publish.go @@ -11,7 +11,6 @@ import ( "github.com/lbryio/lbrytv/app/auth" "github.com/lbryio/lbrytv/app/proxy" - "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/internal/responses" @@ -29,35 +28,9 @@ const ( fileNameParam = "file_path" ) -// Publisher is responsible for sending data to lbrynet -// and should take file path, account ID and client query as a slice of bytes. -type Publisher interface { - Publish(filePath string, userID int, query []byte) []byte -} - -// LbrynetPublisher is an implementation of SDK publisher. -type LbrynetPublisher struct { - Router *sdkrouter.Router -} - -// Publish takes a file path, account ID and client JSON-RPC query, -// patches the query and sends it to the SDK for processing. -// Resulting response is then returned back as a slice of bytes. -func (p *LbrynetPublisher) Publish(filePath string, userID int, rawQuery []byte) []byte { - c := proxy.NewCaller(p.Router.GetServer(userID).Address, userID) - c.Preprocessor = func(q *proxy.Query) { - params := q.ParamsAsMap() - params[fileNameParam] = filePath - q.Request.Params = params - } - return c.CallRaw(rawQuery) -} - -// Handler glues HTTP uploads to the Publisher. +// Handler has path to save uploads to type Handler struct { - Publisher Publisher - UploadPath string - InternalAPIHost string + UploadPath string } // Handle is where HTTP upload is handled and passed on to Publisher. @@ -73,25 +46,39 @@ func (h Handler) Handle(w http.ResponseWriter, r *http.Request) { return } if !authResult.Authenticated() { - w.Write(proxy.NewForbiddenError(authResult.Err).JSON()) + w.Write(proxy.NewForbiddenError(authResult.Err()).JSON()) + return + } + if authResult.SDKAddress == "" { + w.Write(proxy.NewInternalError(errors.New("user does not have sdk address assigned")).JSON()) + logger.Log().Errorf("user %d does not have sdk address assigned", authResult.User().ID) return } - f, err := h.saveFile(r, authResult.User.ID) + f, err := h.saveFile(r, authResult.User().ID) if err != nil { logger.Log().Error(err) monitor.CaptureException(err) w.Write(proxy.NewInternalError(err).JSON()) return } + defer func() { + if err := os.Remove(f.Name()); err != nil { + monitor.CaptureException(err, map[string]string{"file_path": f.Name()}) + } + }() - response := h.Publisher.Publish(f.Name(), authResult.User.ID, []byte(r.FormValue(jsonRPCFieldName))) + w.Write(publish(authResult.SDKAddress, f.Name(), authResult.User().ID, []byte(r.FormValue(jsonRPCFieldName)))) +} - if err := os.Remove(f.Name()); err != nil { - monitor.CaptureException(err, map[string]string{"file_path": f.Name()}) +func publish(sdkAddress, filename string, userID int, rawQuery []byte) []byte { + c := proxy.NewCaller(sdkAddress, userID) + c.Preprocessor = func(q *proxy.Query) { + params := q.ParamsAsMap() + params[fileNameParam] = filename + q.Request.Params = params } - - w.Write(response) + return c.CallRaw(rawQuery) } // CanHandle checks if http.Request contains POSTed data in an accepted format. diff --git a/app/publish/publish_test.go b/app/publish/publish_test.go index 13446767..5f34eb29 100644 --- a/app/publish/publish_test.go +++ b/app/publish/publish_test.go @@ -8,13 +8,10 @@ import ( "path" "testing" - "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/responses" "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/internal/test" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -27,9 +24,6 @@ func copyToDocker(t *testing.T, fileName string) { } func TestLbrynetPublisher(t *testing.T) { - // dummyUserID := 751365 - authToken := "zzz" - dbConfig := config.GetDatabase() params := storage.ConnParams{ Connection: dbConfig.Connection, @@ -40,27 +34,6 @@ func TestLbrynetPublisher(t *testing.T) { c.SetDefaultConnection() defer connCleanup() - reqChan := test.ReqChan() - ts := test.MockHTTPServer(reqChan) - defer ts.Close() - go func() { - req := <-reqChan - responses.AddJSONContentType(req.W) - ts.NextResponse <- fmt.Sprintf(`{ - "success": true, - "error": null, - "data": { - "user_id": %v, - "has_verified_email": true - } - }`, 751365) - }() - - rt := sdkrouter.New(config.GetLbrynetServers()) - p := &LbrynetPublisher{rt} - u, err := wallet.GetUserWithWallet(rt, ts.URL, authToken, "") - require.NoError(t, err) - data := []byte("test file") f, err := ioutil.TempFile(os.TempDir(), "*") require.NoError(t, err) @@ -92,7 +65,12 @@ func TestLbrynetPublisher(t *testing.T) { "id": 1567580184168 }`) - rawResp := p.Publish(path.Join("/storage", path.Base(f.Name())), u.ID, query) + userID := 751365 + server := test.RandServerAddress(t) + err = wallet.Create(server, userID) + require.NoError(t, err) + + rawResp := publish(server, path.Join("/storage", path.Base(f.Name())), userID, query) // This is all we can check for now without running on testnet or crediting some funds to the test account assert.Regexp(t, "Not enough funds to cover this transaction", string(rawResp)) diff --git a/app/publish/testing.go b/app/publish/testing.go index 02aa5ce3..717fd54d 100644 --- a/app/publish/testing.go +++ b/app/publish/testing.go @@ -35,8 +35,31 @@ func CreatePublishRequest(t *testing.T, data []byte) *http.Request { return req } +var expectedStreamCreateRequest = ` +{ + "id": 1567580184168, + "jsonrpc": "2.0", + "method": "stream_create", + "params": { + "name": "test", + "title": "test", + "description": "test description", + "bid": "0.10000000", + "languages": [ + "en" + ], + "tags": [], + "thumbnail_url": "http://smallmedia.com/thumbnail.jpg", + "license": "None", + "release_time": 1567580184, + "wallet_id": "%s", + "file_path": "%s" + } +}` + var expectedStreamCreateResponse = ` { + "id": 0, "jsonrpc": "2.0", "result": { "height": -2, @@ -158,25 +181,3 @@ var expectedStreamCreateResponse = ` } } ` - -var expectedStreamCreateRequest = ` -{ - "jsonrpc": "2.0", - "method": "stream_create", - "params": { - "name": "test", - "title": "test", - "description": "test description", - "bid": "0.10000000", - "languages": [ - "en" - ], - "tags": [], - "thumbnail_url": "http://smallmedia.com/thumbnail.jpg", - "license": "None", - "release_time": 1567580184, - "file_path": "/Users/silence/Desktop/tenor.gif" - }, - "id": 1567580184168 -} - ` diff --git a/app/sdkrouter/concurrency_test.go b/app/sdkrouter/concurrency_test.go index c0d524da..8ad6d351 100644 --- a/app/sdkrouter/concurrency_test.go +++ b/app/sdkrouter/concurrency_test.go @@ -47,13 +47,13 @@ func TestRouterConcurrency(t *testing.T) { case 0: r.RandomServer() r.GetAll() - r.GetServer(123) + r.LeastLoaded() case 1: r.GetAll() - r.GetServer(123) + r.LeastLoaded() r.RandomServer() case 2: - r.GetServer(123) + r.LeastLoaded() r.RandomServer() r.GetAll() } diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index a367dcaf..9bf36a8b 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -1,13 +1,9 @@ package sdkrouter import ( - "database/sql" - "errors" "fmt" "math/rand" - "regexp" "sort" - "strconv" "sync" "time" @@ -16,9 +12,6 @@ import ( "github.com/lbryio/lbrytv/models" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" - - "github.com/volatiletech/sqlboiler/boil" - "github.com/volatiletech/sqlboiler/queries/qm" ) var logger = monitor.NewModuleLogger("sdkrouter") @@ -34,7 +27,6 @@ type Router struct { useDB bool lastLoaded time.Time - rpcClient *ljsonrpc.Client } func New(servers map[string]string) *Router { @@ -63,55 +55,6 @@ func (r *Router) GetAll() []*models.LbrynetServer { return r.servers } -func (r *Router) GetServer(userID int) *models.LbrynetServer { - r.reloadServersFromDB() - - var sdk *models.LbrynetServer - if userID == 0 { - sdk = r.RandomServer() - } else { - sdk = r.serverForUser(userID) - if sdk.Address == "" { - logger.Log().Errorf("user %d has server but server has no address.", userID) - sdk = r.RandomServer() - } - } - - logger.Log().Tracef("Using server %s for user %d", sdk.Address, userID) - return sdk -} - -func (r *Router) serverForUser(userID int) *models.LbrynetServer { - var user *models.User - var err error - if boil.GetDB() != nil { - user, err = models.Users(qm.Load(models.UserRels.LbrynetServer), models.UserWhere.ID.EQ(userID)).OneG() - if err != nil && !errors.Is(err, sql.ErrNoRows) { - logger.Log().Errorf("Error getting user %d from db: %v", userID, err.Error()) - } - } - - r.mu.RLock() - defer r.mu.RUnlock() - - if user == nil || user.R == nil || user.R.LbrynetServer == nil { - srv := r.servers[getServerForUserID(userID, len(r.servers))] - logger.Log().Debugf("User %d has no server assigned in db. Giving them server %s", userID, srv.Address) - return srv - } - - for _, s := range r.servers { - if s.ID == user.R.LbrynetServer.ID { - logger.Log().Debugf("User %d has server %s assigned in db", userID, s.Address) - return s - } - } - - srv := r.servers[getServerForUserID(userID, len(r.servers))] - logger.Log().Errorf("User %d has server assigned in db but is not in current servers list. Giving them server %s", userID, srv.Address) - return srv -} - func (r *Router) RandomServer() *models.LbrynetServer { r.reloadServersFromDB() r.mu.RLock() @@ -174,7 +117,7 @@ func (r *Router) updateLoadAndMetrics() { delete(r.load, server) r.loadMu.Unlock() metric.Set(-1.0) - // TODO: maybe mark this instance as unresponsive so new traffic is routed to other instances + // TODO: maybe mark this instance as unresponsive so new users are assigned to other instances } else { r.loadMu.Lock() r.load[server] = walletList.TotalPages @@ -196,7 +139,7 @@ func (r *Router) LeastLoaded() *models.LbrynetServer { if len(r.load) == 0 { // updateLoadAndMetrics() was never run, so return a random server - logger.Log().Debugf("LeastLoaded() called before updating load metrics. Returning random server.") + logger.Log().Warnf("LeastLoaded() called before updating load metrics. Returning random server.") return r.RandomServer() } @@ -210,25 +153,9 @@ func (r *Router) LeastLoaded() *models.LbrynetServer { return best } -func (r *Router) Client(userID int) *ljsonrpc.Client { - c := ljsonrpc.NewClient(r.GetServer(userID).Address) - //c.SetRPCTimeout(5 * time.Second) - return c -} - // WalletID formats user ID to use as an LbrynetServer wallet ID. func WalletID(userID int) string { - return fmt.Sprintf("lbrytv-id.%d.wallet", userID) -} - -func UserID(walletID string) int { - userID, err := strconv.ParseInt(regexp.MustCompile(`\d+`).FindString(walletID), 10, 64) - if err != nil { - return 0 - } - return int(userID) -} - -func getServerForUserID(userID, numServers int) int { - return userID % numServers + // warning: changing this template will require renaming the stored wallet files in lbrytv + const template = "lbrytv-id.%d.wallet" + return fmt.Sprintf(template, userID) } diff --git a/app/sdkrouter/sdkrouter_test.go b/app/sdkrouter/sdkrouter_test.go index 94c079b1..1c2e8ddf 100644 --- a/app/sdkrouter/sdkrouter_test.go +++ b/app/sdkrouter/sdkrouter_test.go @@ -31,26 +31,14 @@ func TestInitializeWithYML(t *testing.T) { } func TestServerOrder(t *testing.T) { - servers := map[string]string{ - // internally, servers will be sorted in lexical order by name - "b": "1", - "a": "0", - "d": "3", - "c": "2", - } - r := New(servers) - - for i := 1; i < 100; i++ { - server := r.GetServer(i).Address - assert.Equal(t, fmt.Sprintf("%d", i%len(servers)), server) - } + t.Skip("might bring this back when servers have an order") } func TestOverrideLbrynetDefaultConf(t *testing.T) { address := "http://space.com:1234" config.Override("LbrynetServers", map[string]string{"x": address}) defer config.RestoreOverridden() - server := New(config.GetLbrynetServers()).GetServer(343465345) + server := New(config.GetLbrynetServers()).RandomServer() assert.Equal(t, address, server.Address) } @@ -59,14 +47,10 @@ func TestOverrideLbrynetConf(t *testing.T) { config.Override("Lbrynet", address) config.Override("LbrynetServers", map[string]string{}) defer config.RestoreOverridden() - server := New(config.GetLbrynetServers()).GetServer(1343465345) + server := New(config.GetLbrynetServers()).RandomServer() assert.Equal(t, address, server.Address) } -func TestGetUserID(t *testing.T) { - assert.Equal(t, 1234235, UserID("sjdfkjhsdkjs.1234235.sdfsgf")) -} - func TestLeastLoaded(t *testing.T) { rpcServer := test.MockHTTPServer(nil) defer rpcServer.Close() diff --git a/app/wallet/wallet.go b/app/wallet/wallet.go index 3386ffba..3c629da8 100644 --- a/app/wallet/wallet.go +++ b/app/wallet/wallet.go @@ -10,6 +10,7 @@ import ( "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" + "github.com/volatiletech/sqlboiler/queries/qm" "github.com/lib/pq" pkgerrors "github.com/pkg/errors" @@ -48,8 +49,7 @@ func GetUserWithWallet(rt *sdkrouter.Router, internalAPIHost, token, metaRemoteI return nil, err } - if localUser.WalletID == "" { - log := logger.LogF(monitor.F{monitor.TokenF: token}) + if localUser.LbrynetServerID.IsZero() { err := assignSDKServerToUser(localUser, rt, log) if err != nil { return nil, err @@ -69,9 +69,9 @@ func getOrCreateLocalUser(remoteUserID int, log *logrus.Entry) (*models.User, er if err != nil { return nil, err } - } else if localUser.WalletID == "" { - // This scenario may happen for legacy users who are present in the database but don't have a wallet yet - log.Warnf("user %d doesn't have wallet ID set", localUser.ID) + } else if localUser.LbrynetServerID.IsZero() { + // This scenario may happen for legacy users who are present in the database but don't have a server assigned + log.Warnf("user %d found in db but doesn't have sdk assigned", localUser.ID) } return localUser, nil @@ -82,18 +82,17 @@ func assignSDKServerToUser(user *models.User, router *sdkrouter.Router, log *log if server.ID > 0 { // Ensure server is from DB user.LbrynetServerID.SetValid(server.ID) } else { - // THIS SERVER CAME FROM A CONFIG FILE (prolly for testing) + // THIS SERVER CAME FROM A CONFIG FILE (prolly during testing) // TODO: handle this case better - //return fmt.Errorf("user %d is getting a wallet server with no ID", user.ID) + log.Warnf("user %d is getting an sdk with no ID. could happen if servers came from config file", user.ID) } - walletID, err := Create(server.Address, user.ID) + err := Create(server.Address, user.ID) if err != nil { return err } log.Infof("assigning sdk %s to user %d", server.Address, user.ID) - user.WalletID = walletID _, err = user.UpdateG(boil.Infer()) return err } @@ -122,41 +121,43 @@ func createDBUser(id int) (*models.User, error) { } func getDBUser(id int) (*models.User, error) { - return models.Users(models.UserWhere.ID.EQ(id)).OneG() + return models.Users( + models.UserWhere.ID.EQ(id), + qm.Load(models.UserRels.LbrynetServer), + ).OneG() } // Create creates a wallet on an sdk that can be immediately used in subsequent commands. // It can recover from errors like existing wallets, but if a wallet is known to exist // (eg. a wallet ID stored in the database already), loadWallet() should be called instead. -func Create(serverAddress string, userID int) (string, error) { - wallet, err := createWallet(serverAddress, userID) +func Create(serverAddress string, userID int) error { + err := createWallet(serverAddress, userID) if err == nil { - return wallet.ID, nil + return nil } - walletID := sdkrouter.WalletID(userID) log := logger.LogF(monitor.F{"user_id": userID, "sdk": serverAddress}) if errors.Is(err, lbrynet.ErrWalletExists) { log.Warn(err.Error()) - return walletID, nil + return nil } if errors.Is(err, lbrynet.ErrWalletNeedsLoading) { log.Info(err.Error()) - wallet, err = loadWallet(serverAddress, userID) + err = loadWallet(serverAddress, userID) if err != nil { if errors.Is(err, lbrynet.ErrWalletAlreadyLoaded) { log.Info(err.Error()) - return walletID, nil + return nil } - return "", err + return err } - return wallet.ID, nil + return nil } log.Errorf("don't know how to recover from error: %v", err) - return "", err + return err } // createWallet creates a new wallet on the LbrynetServer. @@ -169,27 +170,27 @@ func Create(serverAddress string, userID int) (string, error) { // if errors.Is(err, lbrynet.WalletNeedsLoading) { // // loadWallet() needs to be called before the wallet can be used // } -func createWallet(addr string, userID int) (*ljsonrpc.Wallet, error) { - wallet, err := ljsonrpc.NewClient(addr).WalletCreate(sdkrouter.WalletID(userID), &ljsonrpc.WalletCreateOpts{ +func createWallet(addr string, userID int) error { + _, err := ljsonrpc.NewClient(addr).WalletCreate(sdkrouter.WalletID(userID), &ljsonrpc.WalletCreateOpts{ SkipOnStartup: true, CreateAccount: true, SingleKey: true}) if err != nil { - return nil, lbrynet.NewWalletError(userID, err) + return lbrynet.NewWalletError(userID, err) } logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet created") - return wallet, nil + return nil } // loadWallet loads an existing wallet in the LbrynetServer. // May return errors: // WalletAlreadyLoaded - wallet is already loaded and operational // WalletNotFound - wallet file does not exist and won't be loaded. -func loadWallet(addr string, userID int) (*ljsonrpc.Wallet, error) { - wallet, err := ljsonrpc.NewClient(addr).WalletAdd(sdkrouter.WalletID(userID)) +func loadWallet(addr string, userID int) error { + _, err := ljsonrpc.NewClient(addr).WalletAdd(sdkrouter.WalletID(userID)) if err != nil { - return nil, lbrynet.NewWalletError(userID, err) + return lbrynet.NewWalletError(userID, err) } logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet loaded") - return wallet, nil + return nil } // UnloadWallet unloads an existing wallet from the LbrynetServer. diff --git a/app/wallet/wallet_test.go b/app/wallet/wallet_test.go index c863f9f7..64e2f3b8 100644 --- a/app/wallet/wallet_test.go +++ b/app/wallet/wallet_test.go @@ -16,7 +16,6 @@ import ( "github.com/lbryio/lbrytv/internal/storage" "github.com/lbryio/lbrytv/internal/test" "github.com/lbryio/lbrytv/models" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -44,7 +43,7 @@ func setupDBTables() { storage.Conn.Truncate([]string{"users"}) } -func dummyAPI(rt *sdkrouter.Router) (string, func()) { +func dummyAPI(sdkAddress string) (string, func()) { reqChan := test.ReqChan() ts := test.MockHTTPServer(reqChan) go func() { @@ -55,7 +54,7 @@ func dummyAPI(rt *sdkrouter.Router) (string, func()) { "success": true, "error": null, "data": { - "user_id": %v, + "user_id": %d, "has_verified_email": true } }`, dummyUserID) @@ -64,29 +63,48 @@ func dummyAPI(rt *sdkrouter.Router) (string, func()) { return ts.URL, func() { ts.Close() - UnloadWallet(rt.GetServer(dummyUserID).Address, dummyUserID) + UnloadWallet(sdkAddress, dummyUserID) } } func TestWalletServiceRetrieveNewUser(t *testing.T) { - rt := sdkrouter.New(config.GetLbrynetServers()) + srv := test.RandServerAddress(t) + rt := sdkrouter.New(map[string]string{"a": srv}) setupDBTables() - url, cleanup := dummyAPI(rt) + url, cleanup := dummyAPI(srv) defer cleanup() - wid := sdkrouter.WalletID(dummyUserID) u, err := GetUserWithWallet(rt, url, "abc", "") require.NoError(t, err, errors.Unwrap(err)) require.NotNil(t, u) - require.Equal(t, wid, u.WalletID) count, err := models.Users(models.UserWhere.ID.EQ(u.ID)).CountG() require.NoError(t, err) assert.EqualValues(t, 1, count) + assert.True(t, u.LbrynetServerID.IsZero()) // because the server came from a config, it should not have an id set - u, err = GetUserWithWallet(rt, url, "abc", "") + // now assign the user a new server thats set in the db + // rand.Intn(99999), + sdk := &models.LbrynetServer{ + Name: "testing", + Address: "test.test.test.test", + } + err = u.SetLbrynetServerG(true, sdk) + require.NoError(t, err) + require.NotEqual(t, 0, sdk.ID) + require.Equal(t, u.LbrynetServerID.Int, sdk.ID) + + // now fetch it all back from the db + + u2, err := GetUserWithWallet(rt, url, "abc", "") require.NoError(t, err, errors.Unwrap(err)) - require.Equal(t, wid, u.WalletID) + require.NotNil(t, u2) + + sdk2, err := u.LbrynetServer().OneG() + require.NoError(t, err) + require.Equal(t, sdk.ID, sdk2.ID) + require.Equal(t, sdk.Address, sdk2.Address) + require.Equal(t, u.LbrynetServerID.Int, sdk2.ID) } func TestWalletServiceRetrieveNonexistentUser(t *testing.T) { @@ -108,9 +126,10 @@ func TestWalletServiceRetrieveNonexistentUser(t *testing.T) { } func TestWalletServiceRetrieveExistingUser(t *testing.T) { - rt := sdkrouter.New(config.GetLbrynetServers()) + srv := test.RandServerAddress(t) + rt := sdkrouter.New(map[string]string{"a": srv}) setupDBTables() - url, cleanup := dummyAPI(rt) + url, cleanup := dummyAPI(srv) defer cleanup() u, err := GetUserWithWallet(rt, url, "abc", "") @@ -126,7 +145,7 @@ func TestWalletServiceRetrieveExistingUser(t *testing.T) { assert.EqualValues(t, 1, count) } -func TestWalletServiceRetrieveExistingUserMissingWalletID(t *testing.T) { +func TestGetUserWithWallet_ExistingUserWithSDKGetsAssignedOneOnRetrieve(t *testing.T) { setupDBTables() userID := int(rand.Int31()) @@ -141,7 +160,7 @@ func TestWalletServiceRetrieveExistingUserMissingWalletID(t *testing.T) { "success": true, "error": null, "data": { - "user_id": %v, + "user_id": %d, "has_verified_email": true } }`, userID) @@ -154,7 +173,7 @@ func TestWalletServiceRetrieveExistingUserMissingWalletID(t *testing.T) { u, err = GetUserWithWallet(rt, ts.URL, "abc", "") require.NoError(t, err) - assert.NotEqual(t, "", u.WalletID) + assert.NotEqual(t, "", u.LbrynetServerID) } func TestWalletServiceRetrieveNoVerifiedEmail(t *testing.T) { @@ -230,21 +249,23 @@ func BenchmarkWalletCommands(b *testing.B) { b.StopTimer() } +func TestCreate_CorrectWalletID(t *testing.T) { + // test that calling Create() sends the correct wallet id to the server +} + func TestInitializeWallet(t *testing.T) { rand.Seed(time.Now().UnixNano()) userID := rand.Int() addr := test.RandServerAddress(t) - walletID, err := Create(addr, userID) + err := Create(addr, userID) require.NoError(t, err) - assert.Equal(t, walletID, sdkrouter.WalletID(userID)) err = UnloadWallet(addr, userID) require.NoError(t, err) - walletID, err = Create(addr, userID) + err = Create(addr, userID) require.NoError(t, err) - assert.Equal(t, walletID, sdkrouter.WalletID(userID)) } func TestCreateWalletLoadWallet(t *testing.T) { @@ -252,18 +273,16 @@ func TestCreateWalletLoadWallet(t *testing.T) { userID := rand.Int() addr := test.RandServerAddress(t) - wallet, err := createWallet(addr, userID) + err := createWallet(addr, userID) require.NoError(t, err) - assert.Equal(t, wallet.ID, sdkrouter.WalletID(userID)) - wallet, err = createWallet(addr, userID) + err = createWallet(addr, userID) require.NotNil(t, err) assert.True(t, errors.Is(err, lbrynet.ErrWalletExists)) err = UnloadWallet(addr, userID) require.NoError(t, err) - wallet, err = loadWallet(addr, userID) + err = loadWallet(addr, userID) require.NoError(t, err) - assert.Equal(t, wallet.ID, sdkrouter.WalletID(userID)) } diff --git a/go.mod b/go.mod index 41973692..b236e471 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/lbryio/lbry.go/v2 v2.4.4 github.com/lbryio/reflector.go v1.1.1 github.com/lib/pq v1.2.0 + github.com/nsf/jsondiff v0.0.0-20190712045011-8443391ee9b6 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pelletier/go-toml v1.6.0 // indirect github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 // indirect diff --git a/go.sum b/go.sum index 1fb18988..f4b46f79 100644 --- a/go.sum +++ b/go.sum @@ -366,6 +366,8 @@ github.com/nlopes/slack v0.5.0 h1:NbIae8Kd0NpqaEI3iUrsuS0KbcEDhzhc939jLW5fNm0= github.com/nlopes/slack v0.5.0/go.mod h1:jVI4BBK3lSktibKahxBF74txcK2vyvkza1z/+rRnVAM= github.com/nlopes/slack v0.6.0 h1:jt0jxVQGhssx1Ib7naAOZEZcGdtIhTzkP0nopK0AsRA= github.com/nlopes/slack v0.6.0/go.mod h1:JzQ9m3PMAqcpeCam7UaHSuBuupz7CmpjehYMayT6YOk= +github.com/nsf/jsondiff v0.0.0-20190712045011-8443391ee9b6 h1:qsqscDgSJy+HqgMTR+3NwjYJBbp1+honwDsszLoS+pA= +github.com/nsf/jsondiff v0.0.0-20190712045011-8443391ee9b6/go.mod h1:uFMI8w+ref4v2r9jz+c9i1IfIttS/OkmLfrk1jne5hs= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= diff --git a/internal/responses/responses.go b/internal/responses/responses.go index 5e16eaf0..e7ff00c9 100644 --- a/internal/responses/responses.go +++ b/internal/responses/responses.go @@ -4,6 +4,8 @@ import ( "net/http" ) +// this is the message to show when authentication info is required but was not provided in the request +// this is NOT the message for when auth info is provided but is not correct const AuthRequiredErrorMessage = "authentication required" // AddJSONContentType prepares HTTP response writer for JSON content-type. diff --git a/internal/test/test.go b/internal/test/test.go index 75169163..8a675a6f 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,14 +1,17 @@ package test import ( + "bytes" "encoding/json" "fmt" "io/ioutil" "net/http" "net/http/httptest" + "regexp" "testing" "github.com/lbryio/lbrytv/config" + "github.com/nsf/jsondiff" "github.com/ybbus/jsonrpc" ) @@ -85,3 +88,70 @@ func RandServerAddress(t *testing.T) string { t.Fatal("no lbrynet servers configured") return "" } + +// JsonCompact removes insignificant space characters from a JSON string +// It helps compare JSON strings without worrying about whitespace differences +func JsonCompact(jsonStr string) string { + dst := &bytes.Buffer{} + err := json.Compact(dst, []byte(jsonStr)) + if err != nil { + panic(err) + } + return dst.String() +} + +// JsonCompare compares two json strings. +func jsonCompare(a, b []byte) (bool, string) { + opts := jsondiff.DefaultConsoleOptions() + diff, str := jsondiff.Compare(a, b, &opts) + return diff == jsondiff.FullMatch, str +} + +// assert.Equal for JSON - more accurate comparison, pretty diff +func AssertJsonEqual(t *testing.T, expected, actual interface{}, msgAndArgs ...interface{}) bool { + t.Helper() + same, diff := jsonCompare(toBytes(expected), toBytes(actual)) + if same { + return true + } + + indent := "\t\t" + diffIndented := regexp.MustCompile("(?m)^").ReplaceAll([]byte(diff), []byte("\t"+indent))[len(indent)+1:] + tmpl := "\n\tError:" + indent + "JSON not equal\n\tDiff:" + indent + "%s" + msg := messageFromMsgAndArgs(msgAndArgs...) + if len(msg) > 0 { + t.Errorf(tmpl+"\n\tMessages:"+indent+"%s", diffIndented, msg) + } else { + t.Errorf(tmpl, diffIndented) + } + return false +} + +func toBytes(v interface{}) []byte { + switch s := v.(type) { + case string: + return []byte(s) + case []byte: + return s + default: + panic(fmt.Sprintf("cannot convert %T to byte slice", v)) + } +} + +// copied from assert.Fail() +func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { + if len(msgAndArgs) == 0 || msgAndArgs == nil { + return "" + } + if len(msgAndArgs) == 1 { + msg := msgAndArgs[0] + if msgAsStr, ok := msg.(string); ok { + return msgAsStr + } + return fmt.Sprintf("%+v", msg) + } + if len(msgAndArgs) > 1 { + return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } + return "" +} From d46134d7b165e227271fb4a97680edd6c32a59c9 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Thu, 16 Apr 2020 09:52:17 -0400 Subject: [PATCH 09/18] add tests for recovery handler --- api/routes.go | 2 +- api/routes_test.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/api/routes.go b/api/routes.go index 7b5acd65..92b01006 100644 --- a/api/routes.go +++ b/api/routes.go @@ -91,7 +91,7 @@ func recoveryHandler(next http.Handler) http.Handler { Error: &jsonrpc.RPCError{ Code: -1, Message: recovered.Error(), - Data: stack, + Data: string(stack), }, }) w.Write(rsp) diff --git a/api/routes_test.go b/api/routes_test.go index 8f52e7d7..36fca749 100644 --- a/api/routes_test.go +++ b/api/routes_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -12,6 +13,7 @@ import ( "github.com/lbryio/lbrytv/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ybbus/jsonrpc" ) func TestRoutesProxy(t *testing.T) { @@ -64,3 +66,53 @@ func TestRoutesOptions(t *testing.T) { rr.Result().Header.Get("Access-Control-Allow-Headers"), ) } + +func TestRecoveryHandler_Panic(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("xoxox") + }) + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", &bytes.Buffer{}) + require.NoError(t, err) + logger.Disable() + assert.NotPanics(t, func() { + recoveryHandler(h).ServeHTTP(rr, r) + }) + var res jsonrpc.RPCResponse + err = json.Unmarshal(rr.Body.Bytes(), &res) + require.NoError(t, err) + require.NotNil(t, res.Error) + assert.Contains(t, res.Error.Message, "xoxox") +} + +func TestRecoveryHandler_NoPanic(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("no panic recovery")) + }) + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", &bytes.Buffer{}) + require.NoError(t, err) + assert.NotPanics(t, func() { + recoveryHandler(h).ServeHTTP(rr, r) + }) + assert.Equal(t, rr.Body.String(), "no panic recovery") + +} + +func TestRecoveryHandler_RecoveredPanic(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if r := recover(); r != nil { + w.Write([]byte("panic recovered in here")) + } + }() + panic("xoxoxo") + }) + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", &bytes.Buffer{}) + require.NoError(t, err) + assert.NotPanics(t, func() { + recoveryHandler(h).ServeHTTP(rr, r) + }) + assert.Equal(t, rr.Body.String(), "panic recovered in here") +} From 2d4c559d97b24193379c80c0cc6b6d8aa491dcb1 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Thu, 16 Apr 2020 10:15:43 -0400 Subject: [PATCH 10/18] minor --- app/auth/auth.go | 7 ++++--- app/wallet/wallet.go | 15 ++++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/app/auth/auth.go b/app/auth/auth.go index 60d9486f..7cd428b3 100644 --- a/app/auth/auth.go +++ b/app/auth/auth.go @@ -25,11 +25,12 @@ func FromRequest(r *http.Request) Result { return v.(Result) } -// Provider gets a user by hitting internal-api with the provided auth token -// and matching the response to a local user. -// NOTE: The retrieved user must come with a wallet that's created and ready to use. +// Provider tries to authenticate using the provided auth token type Provider func(token, metaRemoteIP string) Result +// WalletAndInternalAPIProvider auths a user by hitting internal-api with the auth token +// and matching the response to a local user. If auth is successful, the user will have a +// lbrynet server assigned and a wallet that's created and ready to use. func WalletAndInternalAPIProvider(rt *sdkrouter.Router, internalAPIHost string) Provider { return func(token, metaRemoteIP string) Result { user, err := wallet.GetUserWithWallet(rt, internalAPIHost, token, metaRemoteIP) diff --git a/app/wallet/wallet.go b/app/wallet/wallet.go index 3c629da8..04d0ffd6 100644 --- a/app/wallet/wallet.go +++ b/app/wallet/wallet.go @@ -80,21 +80,18 @@ func getOrCreateLocalUser(remoteUserID int, log *logrus.Entry) (*models.User, er func assignSDKServerToUser(user *models.User, router *sdkrouter.Router, log *logrus.Entry) error { server := router.LeastLoaded() if server.ID > 0 { // Ensure server is from DB - user.LbrynetServerID.SetValid(server.ID) + log.Infof("assigning sdk %s to user %d", server.Address, user.ID) + err := user.SetLbrynetServerG(false, server) + if err != nil { + return err + } } else { // THIS SERVER CAME FROM A CONFIG FILE (prolly during testing) // TODO: handle this case better log.Warnf("user %d is getting an sdk with no ID. could happen if servers came from config file", user.ID) } - err := Create(server.Address, user.ID) - if err != nil { - return err - } - - log.Infof("assigning sdk %s to user %d", server.Address, user.ID) - _, err = user.UpdateG(boil.Infer()) - return err + return Create(server.Address, user.ID) } func createDBUser(id int) (*models.User, error) { From b8af1e9a99756d3c1a659c4ff64fce85d8ea2ad9 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Thu, 16 Apr 2020 15:18:51 -0400 Subject: [PATCH 11/18] add user info to /internal/status --- api/routes.go | 24 +++++++++++++--- app/auth/auth.go | 2 +- app/proxy/accounts_test.go | 2 +- go.mod | 1 + go.sum | 2 ++ internal/status/status.go | 59 +++++++++++++++++++++++++------------- 6 files changed, 64 insertions(+), 26 deletions(-) diff --git a/api/routes.go b/api/routes.go index 92b01006..f9a925f2 100644 --- a/api/routes.go +++ b/api/routes.go @@ -36,10 +36,14 @@ func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { http.Redirect(w, req, config.GetProjectURL(), http.StatusSeeOther) }) + authProvider := auth.NewWalletAndInternalAPIProvider(sdkRouter, config.GetInternalAPIHost()) + middlewareStack := middlewares( + sdkrouter.Middleware(sdkRouter), + auth.Middleware(authProvider), + ) + v1Router := r.PathPrefix("/api/v1").Subrouter() - v1Router.Use(sdkrouter.Middleware(sdkRouter)) - authProvider := auth.WalletAndInternalAPIProvider(sdkRouter, config.GetInternalAPIHost()) - v1Router.Use(auth.Middleware(authProvider)) + v1Router.Use(middlewareStack) v1Router.HandleFunc("/proxy", proxy.HandleCORS).Methods(http.MethodOptions) v1Router.HandleFunc("/proxy", upHandler.Handle).MatcherFunc(upHandler.CanHandle) v1Router.HandleFunc("/proxy", proxy.Handle) @@ -47,10 +51,22 @@ func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { internalRouter := r.PathPrefix("/internal").Subrouter() internalRouter.Handle("/metrics", promhttp.Handler()) - internalRouter.HandleFunc("/status", sdkrouter.AddToRequest(sdkRouter, status.GetStatus)) + internalRouter.Handle("/status", middlewareStack(http.HandlerFunc(status.GetStatus))) internalRouter.HandleFunc("/whoami", status.WhoAMI) } +// applies several middleware in order +func middlewares(mws ...mux.MiddlewareFunc) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + for _, mw := range mws { + next = mw(next) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } +} + func methodTimer(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/app/auth/auth.go b/app/auth/auth.go index 7cd428b3..e252cc5f 100644 --- a/app/auth/auth.go +++ b/app/auth/auth.go @@ -31,7 +31,7 @@ type Provider func(token, metaRemoteIP string) Result // WalletAndInternalAPIProvider auths a user by hitting internal-api with the auth token // and matching the response to a local user. If auth is successful, the user will have a // lbrynet server assigned and a wallet that's created and ready to use. -func WalletAndInternalAPIProvider(rt *sdkrouter.Router, internalAPIHost string) Provider { +func NewWalletAndInternalAPIProvider(rt *sdkrouter.Router, internalAPIHost string) Provider { return func(token, metaRemoteIP string) Result { user, err := wallet.GetUserWithWallet(rt, internalAPIHost, token, metaRemoteIP) res := NewResult(user, err) diff --git a/app/proxy/accounts_test.go b/app/proxy/accounts_test.go index 223f2580..51b65fe3 100644 --- a/app/proxy/accounts_test.go +++ b/app/proxy/accounts_test.go @@ -39,7 +39,7 @@ func TestWithWrongAuthToken(t *testing.T) { rr := httptest.NewRecorder() rt := sdkrouter.New(config.GetLbrynetServers()) - handler := sdkrouter.Middleware(rt)(auth.Middleware(auth.WalletAndInternalAPIProvider(rt, ts.URL))(http.HandlerFunc(Handle))) + handler := sdkrouter.Middleware(rt)(auth.Middleware(auth.NewWalletAndInternalAPIProvider(rt, ts.URL))(http.HandlerFunc(Handle))) handler.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) diff --git a/go.mod b/go.mod index b236e471..facfa49f 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/gobuffalo/packr/v2 v2.7.1 github.com/gofrs/uuid v3.2.0+incompatible // indirect github.com/gorilla/mux v1.7.3 + github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a github.com/jinzhu/gorm v1.9.9 github.com/jmoiron/sqlx v1.2.0 github.com/kat-co/vala v0.0.0-20170210184112-42e1d8b61f12 diff --git a/go.sum b/go.sum index f4b46f79..b9c5b883 100644 --- a/go.sum +++ b/go.sum @@ -240,6 +240,8 @@ github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/ github.com/iris-contrib/i18n v0.0.0-20171121225848-987a633949d0/go.mod h1:pMCz62A0xJL6I+umB2YTlFRwWXaDFA0jy+5HzGiJjqI= github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a h1:zPPuIq2jAWWPTrGt70eK/BSch+gFAGrNzecsoENgu2o= +github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a/go.mod h1:yL958EeXv8Ylng6IfnvG4oflryUi3vgA3xPs9hmII1s= github.com/jinzhu/gorm v1.9.9 h1:Gc8bP20O+vroFUzZEXA1r7vNGQZGQ+RKgOnriuNF3ds= github.com/jinzhu/gorm v1.9.9/go.mod h1:Kh6hTsSGffh4ui079FHrR5Gg+5D0hgihqDcsDN2BBJY= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= diff --git a/internal/status/status.go b/internal/status/status.go index 926d20c8..806a1397 100644 --- a/internal/status/status.go +++ b/internal/status/status.go @@ -6,8 +6,12 @@ import ( "net/http" "time" + "github.com/lbryio/lbrytv/app/auth" "github.com/lbryio/lbrytv/app/sdkrouter" "github.com/lbryio/lbrytv/internal/monitor" + "github.com/lbryio/lbrytv/internal/responses" + + "github.com/jinzhu/copier" ) var logger = monitor.NewModuleLogger("status") @@ -21,15 +25,15 @@ var PlayerServers = []string{ } var ( - cachedResponse *statusResponse = nil + cachedResponse *statusResponse lastUpdate time.Time ) const ( - StatusOK = "ok" - StatusNotReady = "not_ready" - StatusOffline = "offline" - StatusFailing = "failing" + statusOK = "ok" + statusNotReady = "not_ready" + statusOffline = "offline" + statusFailing = "failing" statusCacheValidity = 120 * time.Second ) @@ -43,38 +47,38 @@ type statusResponse map[string]interface{} func GetStatus(w http.ResponseWriter, req *http.Request) { respStatus := http.StatusOK - var response *statusResponse + var response statusResponse if cachedResponse != nil && lastUpdate.After(time.Now().Add(statusCacheValidity)) { - response = cachedResponse + //response = *cachedResponse + copier.Copy(&response, cachedResponse) } else { services := map[string]ServerList{ "lbrynet": {}, "player": {}, } - response = &statusResponse{ + response = statusResponse{ "timestamp": fmt.Sprintf("%v", time.Now().UTC()), "services": services, - "general_state": StatusOK, + "general_state": statusOK, } failureDetected := false sdks := sdkrouter.FromRequest(req).GetAll() for _, s := range sdks { - srv := ServerItem{Address: s.Address, Status: StatusOK} - services["lbrynet"] = append(services["lbrynet"], srv) + services["lbrynet"] = append(services["lbrynet"], ServerItem{Address: s.Address, Status: statusOK}) } for _, ps := range PlayerServers { r, err := http.Get(ps) - srv := ServerItem{Address: ps, Status: StatusOK} + srv := ServerItem{Address: ps, Status: statusOK} if err != nil { srv.Error = fmt.Sprintf("%v", err) - srv.Status = StatusOffline + srv.Status = statusOffline respStatus = http.StatusServiceUnavailable failureDetected = true } else if r.StatusCode != http.StatusNotFound { - srv.Status = StatusNotReady + srv.Status = statusNotReady srv.Error = fmt.Sprintf("http status %v", r.StatusCode) respStatus = http.StatusServiceUnavailable failureDetected = true @@ -82,14 +86,26 @@ func GetStatus(w http.ResponseWriter, req *http.Request) { services["player"] = append(services["player"], srv) } if failureDetected { - (*response)["general_state"] = StatusFailing + response["general_state"] = statusFailing } - cachedResponse = response + cachedResponse = &response lastUpdate = time.Now() } - w.Header().Add("content-type", "application/json; charset=utf-8") + + authResult := auth.FromRequest(req) + if authResult.Authenticated() { + response["user"] = map[string]interface{}{ + "user_id": authResult.User().ID, + "assigned_sdk": authResult.SDKAddress, + } + } + + responses.AddJSONContentType(w) w.WriteHeader(respStatus) - respByte, _ := json.MarshalIndent(&response, "", " ") + respByte, err := json.MarshalIndent(response, "", " ") + if err != nil { + logger.Log().Error(err) + } w.Write(respByte) } @@ -100,7 +116,10 @@ func WhoAMI(w http.ResponseWriter, req *http.Request) { "X-Real-Ip": req.Header.Get("X-Real-Ip"), } - w.Header().Add("content-type", "application/json; charset=utf-8") - respByte, _ := json.MarshalIndent(&details, "", " ") + responses.AddJSONContentType(w) + respByte, err := json.MarshalIndent(&details, "", " ") + if err != nil { + logger.Log().Error(err) + } w.Write(respByte) } From 797d0245ebcb19a49c0f9ce50a955a5729b971ed Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Thu, 16 Apr 2020 15:57:47 -0400 Subject: [PATCH 12/18] added trace logging for sdkrouter locks --- app/sdkrouter/sdkrouter.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index 9bf36a8b..b336547b 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -50,14 +50,18 @@ func New(servers map[string]string) *Router { func (r *Router) GetAll() []*models.LbrynetServer { r.reloadServersFromDB() + logger.WithField("lock", "mu").Trace("waiting for read lock in GetAll") r.mu.RLock() + logger.WithField("lock", "mu").Trace("got read lock in GetAll") defer r.mu.RUnlock() return r.servers } func (r *Router) RandomServer() *models.LbrynetServer { r.reloadServersFromDB() + logger.WithField("lock", "mu").Trace("waiting for read lock in RandomServer") r.mu.RLock() + logger.WithField("lock", "mu").Trace("got read lock in RandomServer") defer r.mu.RUnlock() return r.servers[rand.Intn(len(r.servers))] } @@ -86,7 +90,9 @@ func (r *Router) setServers(servers []*models.LbrynetServer) { // we do this partially to make sure that ids are assigned to servers more consistently, // and partially to make tests consistent (since Go maps are not ordered) sort.Slice(servers, func(i, j int) bool { return servers[i].Name < servers[j].Name }) + logger.WithField("lock", "mu").Trace("waiting for write lock in setServers") r.mu.Lock() + logger.WithField("lock", "mu").Trace("got write lock in setServers") defer r.mu.Unlock() r.servers = servers logger.Log().Debugf("updated server list to %d servers", len(r.servers)) @@ -113,13 +119,17 @@ func (r *Router) updateLoadAndMetrics() { walletList, err := ljsonrpc.NewClient(server.Address).WalletList("", 1, 1) if err != nil { logger.Log().Errorf("lbrynet instance %s is not responding: %v", server.Address, err) + logger.WithField("lock", "loadMu").Trace("waiting for write lock in updateLoadAndMetrics 1") r.loadMu.Lock() + logger.WithField("lock", "loadMu").Trace("got write lock in updateLoadAndMetrics 1") delete(r.load, server) r.loadMu.Unlock() metric.Set(-1.0) // TODO: maybe mark this instance as unresponsive so new users are assigned to other instances } else { + logger.WithField("lock", "loadMu").Trace("waiting for write lock in updateLoadAndMetrics 2") r.loadMu.Lock() + logger.WithField("lock", "loadMu").Trace("got write lock in updateLoadAndMetrics 2") r.load[server] = walletList.TotalPages r.loadMu.Unlock() metric.Set(float64(walletList.TotalPages)) @@ -134,7 +144,9 @@ func (r *Router) LeastLoaded() *models.LbrynetServer { var best *models.LbrynetServer var min uint64 + logger.WithField("lock", "loadMu").Trace("waiting for read lock in LeastLoaded") r.loadMu.RLock() + logger.WithField("lock", "loadMu").Trace("got read lock in LeastLoaded") defer r.loadMu.RUnlock() if len(r.load) == 0 { From 84157a6c40379e54b0e1304bb9c1aa09b6462525 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Mon, 20 Apr 2020 11:14:30 -0400 Subject: [PATCH 13/18] some fixes --- api/routes.go | 2 +- app/auth/auth.go | 9 +- app/proxy/accounts_test.go | 2 +- app/proxy/cache.go | 31 ++--- app/proxy/cache_test.go | 16 +-- app/proxy/const.go | 1 - app/proxy/errors.go | 2 +- app/proxy/handlers.go | 3 + app/proxy/{caller.go => proxy.go} | 39 ++++-- app/proxy/{caller_test.go => proxy_test.go} | 4 +- app/proxy/query.go | 8 +- app/publish/publish.go | 3 +- app/sdkrouter/sdkrouter.go | 25 ++-- app/wallet/wallet.go | 12 +- internal/ip/ip.go | 2 +- internal/monitor/middleware.go | 11 +- internal/monitor/module_logger.go | 42 ++---- internal/monitor/monitor.go | 104 ++------------- internal/monitor/monitor_test.go | 134 ++++++++++---------- internal/storage/conn.go | 5 +- internal/storage/maintenance.go | 11 +- internal/test/test.go | 1 + internal/test/test_test.go | 2 +- server/server.go | 7 +- 24 files changed, 194 insertions(+), 282 deletions(-) rename app/proxy/{caller.go => proxy.go} (84%) rename app/proxy/{caller_test.go => proxy_test.go} (98%) diff --git a/api/routes.go b/api/routes.go index f9a925f2..8633714e 100644 --- a/api/routes.go +++ b/api/routes.go @@ -36,7 +36,7 @@ func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { http.Redirect(w, req, config.GetProjectURL(), http.StatusSeeOther) }) - authProvider := auth.NewWalletAndInternalAPIProvider(sdkRouter, config.GetInternalAPIHost()) + authProvider := auth.NewIAPIProvider(sdkRouter, config.GetInternalAPIHost()) middlewareStack := middlewares( sdkrouter.Middleware(sdkRouter), auth.Middleware(authProvider), diff --git a/app/auth/auth.go b/app/auth/auth.go index e252cc5f..4051211c 100644 --- a/app/auth/auth.go +++ b/app/auth/auth.go @@ -9,6 +9,7 @@ import ( "github.com/lbryio/lbrytv/internal/ip" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" + "github.com/sirupsen/logrus" "github.com/gorilla/mux" ) @@ -20,7 +21,7 @@ const ContextKey = "user" func FromRequest(r *http.Request) Result { v := r.Context().Value(ContextKey) if v == nil { - panic("Auth middleware was not applied") + panic("auth.Middleware is required") } return v.(Result) } @@ -28,10 +29,10 @@ func FromRequest(r *http.Request) Result { // Provider tries to authenticate using the provided auth token type Provider func(token, metaRemoteIP string) Result -// WalletAndInternalAPIProvider auths a user by hitting internal-api with the auth token +// NewIAPIProvider authenticates a user by hitting internal-api with the auth token // and matching the response to a local user. If auth is successful, the user will have a // lbrynet server assigned and a wallet that's created and ready to use. -func NewWalletAndInternalAPIProvider(rt *sdkrouter.Router, internalAPIHost string) Provider { +func NewIAPIProvider(rt *sdkrouter.Router, internalAPIHost string) Provider { return func(token, metaRemoteIP string) Result { user, err := wallet.GetUserWithWallet(rt, internalAPIHost, token, metaRemoteIP) res := NewResult(user, err) @@ -50,7 +51,7 @@ func Middleware(provider Provider) mux.MiddlewareFunc { addr := ip.AddressForRequest(r) res = provider(token[0], addr) if res.err != nil { - logger.LogF(monitor.F{"ip": addr}).Debugf("error authenticating user") + logger.WithFields(logrus.Fields{"ip": addr}).Debugf("error authenticating user") } res.authAttempted = true } diff --git a/app/proxy/accounts_test.go b/app/proxy/accounts_test.go index 51b65fe3..5a135992 100644 --- a/app/proxy/accounts_test.go +++ b/app/proxy/accounts_test.go @@ -39,7 +39,7 @@ func TestWithWrongAuthToken(t *testing.T) { rr := httptest.NewRecorder() rt := sdkrouter.New(config.GetLbrynetServers()) - handler := sdkrouter.Middleware(rt)(auth.Middleware(auth.NewWalletAndInternalAPIProvider(rt, ts.URL))(http.HandlerFunc(Handle))) + handler := sdkrouter.Middleware(rt)(auth.Middleware(auth.NewIAPIProvider(rt, ts.URL))(http.HandlerFunc(Handle))) handler.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) diff --git a/app/proxy/cache.go b/app/proxy/cache.go index e5fc293a..5fb52f6b 100644 --- a/app/proxy/cache.go +++ b/app/proxy/cache.go @@ -10,16 +10,24 @@ import ( "github.com/lbryio/lbrytv/internal/monitor" "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" ) -// CacheLogger is for logging query cache-related messages -var CacheLogger = monitor.NewModuleLogger("proxy_cache") +var ( + globalCache responseCache + cacheLogger = monitor.NewModuleLogger("proxy_cache") +) + +func init() { + globalCache = cacheStorage{c: cache.New(2*time.Minute, 10*time.Minute)} +} -// ResponseCache interface describes methods for SDK response cache saving and retrieval -type ResponseCache interface { +// responseCache interface describes methods for SDK response cache saving and retrieval +type responseCache interface { Save(method string, params interface{}, r interface{}) Retrieve(method string, params interface{}) interface{} Count() int + getKey(method string, params interface{}) (string, error) flush() } @@ -28,16 +36,9 @@ type cacheStorage struct { c *cache.Cache } -var responseCache ResponseCache - -// InitResponseCache initializes module-level responseCache variable -func InitResponseCache(c ResponseCache) { - responseCache = c -} - // Save puts a response object into cache, making it available for a later retrieval by method and query params func (s cacheStorage) Save(method string, params interface{}, r interface{}) { - l := CacheLogger.LogF(monitor.F{"method": method}) + l := cacheLogger.WithFields(logrus.Fields{"method": method}) cacheKey, err := s.getKey(method, params) if err != nil { l.Errorf("unable to produce key for params: %v", params) @@ -49,7 +50,7 @@ func (s cacheStorage) Save(method string, params interface{}, r interface{}) { // Retrieve earlier saved server response by method and query params func (s cacheStorage) Retrieve(method string, params interface{}) interface{} { - l := CacheLogger.LogF(monitor.F{"method": method}) + l := cacheLogger.WithFields(logrus.Fields{"method": method}) cacheKey, err := s.getKey(method, params) if err != nil { l.Errorf("unable to produce key for params: %v", params) @@ -88,7 +89,3 @@ func (s cacheStorage) flush() { func (s cacheStorage) Count() int { return s.c.ItemCount() } - -func init() { - InitResponseCache(cacheStorage{c: cache.New(2*time.Minute, 10*time.Minute)}) -} diff --git a/app/proxy/cache_test.go b/app/proxy/cache_test.go index 7b2ab088..424d90e7 100644 --- a/app/proxy/cache_test.go +++ b/app/proxy/cache_test.go @@ -33,20 +33,20 @@ func TestCache(t *testing.T) { t.Fatal(err) } - responseCache.flush() - assert.Nil(t, responseCache.Retrieve("resolve", query.Params)) - responseCache.Save("resolve", query.Params, response.Result) - assert.Equal(t, 1, responseCache.Count()) - assert.Equal(t, response.Result, responseCache.Retrieve("resolve", query.Params)) + globalCache.flush() + assert.Nil(t, globalCache.Retrieve("resolve", query.Params)) + globalCache.Save("resolve", query.Params, response.Result) + assert.Equal(t, 1, globalCache.Count()) + assert.Equal(t, response.Result, globalCache.Retrieve("resolve", query.Params)) } func TestCacheGetKey(t *testing.T) { - responseCache.flush() - key, err := responseCache.getKey("resolve", map[string]interface{}{"urls": "one"}) + globalCache.flush() + key, err := globalCache.getKey("resolve", map[string]interface{}{"urls": "one"}) assert.Equal(t, "resolve|3600a4eed065d3ae3dd503cca56ce56ae6bd4778047fa1b17c999301681d3a1d", key) assert.NoError(t, err) - key, err = responseCache.getKey("wallet_balance", nil) + key, err = globalCache.getKey("wallet_balance", nil) assert.Equal(t, "wallet_balance|nil", key) assert.NoError(t, err) } diff --git a/app/proxy/const.go b/app/proxy/const.go index 6782afd2..c5958526 100644 --- a/app/proxy/const.go +++ b/app/proxy/const.go @@ -129,7 +129,6 @@ const MethodAccountBalance = "account_balance" const MethodStatus = "status" const MethodResolve = "resolve" const MethodClaimSearch = "claim_search" -const MethodCommentList = "comment_list" const paramAccountID = "account_id" const paramWalletID = "wallet_id" diff --git a/app/proxy/errors.go b/app/proxy/errors.go index 65ef5a95..e2a15122 100644 --- a/app/proxy/errors.go +++ b/app/proxy/errors.go @@ -35,7 +35,7 @@ func (e RPCError) JSON() []byte { JSONRPC: "2.0", }, "", " ") if err != nil { - Logger.Errorf("rpc error to json: %v", err) + logger.Log().Errorf("rpc error to json: %v", err) } return b } diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index 47b27485..79b00639 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -11,6 +11,7 @@ import ( "github.com/lbryio/lbrytv/app/wallet" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/internal/responses" + "github.com/ybbus/jsonrpc" ) @@ -42,6 +43,8 @@ func Handle(w http.ResponseWriter, r *http.Request) { return } + logger.Log().Tracef("call to method %s", req.Method) + var userID int var sdkAddress string if MethodNeedsAuth(req.Method) { diff --git a/app/proxy/caller.go b/app/proxy/proxy.go similarity index 84% rename from app/proxy/caller.go rename to app/proxy/proxy.go index 5cb84ba5..5f5753e4 100644 --- a/app/proxy/caller.go +++ b/app/proxy/proxy.go @@ -20,21 +20,23 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lbryio/lbrytv/app/sdkrouter" + "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/lbrynet" "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/monitor" + "github.com/sirupsen/logrus" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" - "github.com/sirupsen/logrus" "github.com/ybbus/jsonrpc" ) -var Logger = monitor.NewProxyLogger() +var logger = monitor.NewModuleLogger("proxy") const ( walletLoadRetries = 3 walletLoadRetryWait = 100 * time.Millisecond + rpcTimeout = 30 * time.Second ) // Caller patches through JSON-RPC requests from clients, doing pre/post-processing, @@ -51,7 +53,7 @@ type Caller struct { func NewCaller(endpoint string, userID int) *Caller { return &Caller{ client: jsonrpc.NewClientWithOpts(endpoint, &jsonrpc.RPCClientOpts{ - HTTPClient: &http.Client{Timeout: 30 * time.Second}, + HTTPClient: &http.Client{Timeout: rpcTimeout}, }), endpoint: endpoint, userID: userID, @@ -73,14 +75,14 @@ func (c *Caller) Call(req *jsonrpc.RPCRequest) []byte { r, err := c.call(req) if err != nil { monitor.CaptureException(err, map[string]string{"req": spew.Sdump(req), "response": fmt.Sprintf("%v", r)}) - Logger.Errorf("error calling lbrynet: %v, request: %s", err, spew.Sdump(req)) + logger.Log().Errorf("error calling lbrynet: %v, request: %s", err, spew.Sdump(req)) return marshalError(err) } serialized, err := json.MarshalIndent(r, "", " ") if err != nil { monitor.CaptureException(err) - Logger.Errorf("error marshaling response: %v", err) + logger.Log().Errorf("error marshaling response: %v", err) return marshalError(NewInternalError(err)) } @@ -124,7 +126,7 @@ func (c *Caller) call(req *jsonrpc.RPCRequest) (*jsonrpc.RPCResponse, error) { } if q.isCacheable() { - responseCache.Save(q.Method(), q.Params(), r) + globalCache.Save(q.Method(), q.Params(), r) } return r, nil } @@ -149,7 +151,7 @@ func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) { // Generally a HTTP transport failure (connect error etc) if err != nil { - Logger.Errorf("error sending query to %v: %v", c.endpoint, err) + logger.Log().Errorf("error sending query to %v: %v", c.endpoint, err) return nil, err } @@ -163,7 +165,7 @@ func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) { // Alert sentry on the last failed wallet load attempt if err != nil && i >= walletLoadRetries-1 { errMsg := "gave up on manually adding a wallet: %v" - Logger.Logger().WithFields(logrus.Fields{ + logger.WithFields(logrus.Fields{ "user_id": c.userID, "endpoint": c.endpoint, }).Errorf(errMsg, err) @@ -182,10 +184,27 @@ func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) { } if (r != nil && r.Error != nil) || err != nil { - Logger.LogFailedQuery(q.Method(), c.endpoint, c.userID, duration, q.Params(), r.Error) + logger.WithFields(logrus.Fields{ + "method": q.Method(), + "params": q.Params(), + "endpoint": c.endpoint, + "user_id": c.userID, + "duration": duration, + "response": r.Error, + }).Error("error from the target endpoint") failureMetrics.Observe(duration) } else { - Logger.LogSuccessfulQuery(q.Method(), c.endpoint, c.userID, duration, q.Params(), r) + fields := logrus.Fields{ + "method": q.Method(), + "params": q.Params(), + "endpoint": c.endpoint, + "user_id": c.userID, + "duration": duration, + } + if config.ShouldLogResponses() { + fields["response"] = r + } + logger.WithFields(fields).Info("call processed") } return r, err diff --git a/app/proxy/caller_test.go b/app/proxy/proxy_test.go similarity index 98% rename from app/proxy/caller_test.go rename to app/proxy/proxy_test.go index ce79ef3b..e0134913 100644 --- a/app/proxy/caller_test.go +++ b/app/proxy/proxy_test.go @@ -76,7 +76,7 @@ func TestCallerCallWalletBalance(t *testing.T) { err := wallet.Create(addr, dummyUserID) require.NoError(t, err) - hook := logrusTest.NewLocal(Logger.Logger()) + hook := logrusTest.NewLocal(logger.Logger) result = NewCaller(addr, dummyUserID).Call(request) var accountBalanceResponse struct { @@ -210,7 +210,7 @@ func TestCallerCallSDKError(t *testing.T) { }` c := NewCaller(srv.URL, 0) - hook := logrusTest.NewLocal(Logger.Logger()) + hook := logrusTest.NewLocal(logger.Logger) response := c.Call(jsonrpc.NewRequest("resolve", map[string]interface{}{"urls": "what"})) var rpcResponse jsonrpc.RPCResponse json.Unmarshal(response, &rpcResponse) diff --git a/app/proxy/query.go b/app/proxy/query.go index ac92432b..524f6f29 100644 --- a/app/proxy/query.go +++ b/app/proxy/query.go @@ -6,8 +6,8 @@ import ( "fmt" "strings" - "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/internal/responses" + "github.com/sirupsen/logrus" "github.com/ybbus/jsonrpc" ) @@ -100,14 +100,14 @@ func (q *Query) cacheHit() *jsonrpc.RPCResponse { return nil } - cached := responseCache.Retrieve(q.Method(), q.Params()) + cached := globalCache.Retrieve(q.Method(), q.Params()) if cached == nil { return nil } s, err := json.Marshal(cached) if err != nil { - Logger.Errorf("error marshalling cached response") + logger.Log().Errorf("error marshalling cached response") return nil } @@ -117,7 +117,7 @@ func (q *Query) cacheHit() *jsonrpc.RPCResponse { return nil } - monitor.LogCachedQuery(q.Method()) + logger.WithFields(logrus.Fields{"method": q.Method()}).Debug("cached query") return response } diff --git a/app/publish/publish.go b/app/publish/publish.go index 339e37a8..e4bacc65 100644 --- a/app/publish/publish.go +++ b/app/publish/publish.go @@ -13,6 +13,7 @@ import ( "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/internal/responses" + "github.com/sirupsen/logrus" "github.com/gorilla/mux" ) @@ -89,7 +90,7 @@ func (h Handler) CanHandle(r *http.Request, _ *mux.RouteMatch) bool { } func (h Handler) saveFile(r *http.Request, userID int) (*os.File, error) { - log := logger.LogF(monitor.F{"user_id": userID}) + log := logger.WithFields(logrus.Fields{"user_id": userID}) file, header, err := r.FormFile(fileFieldName) if err != nil { diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index b336547b..03c007ec 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -10,6 +10,7 @@ import ( "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" + "github.com/sirupsen/logrus" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" ) @@ -50,18 +51,18 @@ func New(servers map[string]string) *Router { func (r *Router) GetAll() []*models.LbrynetServer { r.reloadServersFromDB() - logger.WithField("lock", "mu").Trace("waiting for read lock in GetAll") + logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("waiting for read lock in GetAll") r.mu.RLock() - logger.WithField("lock", "mu").Trace("got read lock in GetAll") + logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("got read lock in GetAll") defer r.mu.RUnlock() return r.servers } func (r *Router) RandomServer() *models.LbrynetServer { r.reloadServersFromDB() - logger.WithField("lock", "mu").Trace("waiting for read lock in RandomServer") + logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("waiting for read lock in RandomServer") r.mu.RLock() - logger.WithField("lock", "mu").Trace("got read lock in RandomServer") + logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("got read lock in RandomServer") defer r.mu.RUnlock() return r.servers[rand.Intn(len(r.servers))] } @@ -90,9 +91,9 @@ func (r *Router) setServers(servers []*models.LbrynetServer) { // we do this partially to make sure that ids are assigned to servers more consistently, // and partially to make tests consistent (since Go maps are not ordered) sort.Slice(servers, func(i, j int) bool { return servers[i].Name < servers[j].Name }) - logger.WithField("lock", "mu").Trace("waiting for write lock in setServers") + logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("waiting for write lock in setServers") r.mu.Lock() - logger.WithField("lock", "mu").Trace("got write lock in setServers") + logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("got write lock in setServers") defer r.mu.Unlock() r.servers = servers logger.Log().Debugf("updated server list to %d servers", len(r.servers)) @@ -119,17 +120,17 @@ func (r *Router) updateLoadAndMetrics() { walletList, err := ljsonrpc.NewClient(server.Address).WalletList("", 1, 1) if err != nil { logger.Log().Errorf("lbrynet instance %s is not responding: %v", server.Address, err) - logger.WithField("lock", "loadMu").Trace("waiting for write lock in updateLoadAndMetrics 1") + logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("waiting for write lock in updateLoadAndMetrics 1") r.loadMu.Lock() - logger.WithField("lock", "loadMu").Trace("got write lock in updateLoadAndMetrics 1") + logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("got write lock in updateLoadAndMetrics 1") delete(r.load, server) r.loadMu.Unlock() metric.Set(-1.0) // TODO: maybe mark this instance as unresponsive so new users are assigned to other instances } else { - logger.WithField("lock", "loadMu").Trace("waiting for write lock in updateLoadAndMetrics 2") + logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("waiting for write lock in updateLoadAndMetrics 2") r.loadMu.Lock() - logger.WithField("lock", "loadMu").Trace("got write lock in updateLoadAndMetrics 2") + logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("got write lock in updateLoadAndMetrics 2") r.load[server] = walletList.TotalPages r.loadMu.Unlock() metric.Set(float64(walletList.TotalPages)) @@ -144,9 +145,9 @@ func (r *Router) LeastLoaded() *models.LbrynetServer { var best *models.LbrynetServer var min uint64 - logger.WithField("lock", "loadMu").Trace("waiting for read lock in LeastLoaded") + logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("waiting for read lock in LeastLoaded") r.loadMu.RLock() - logger.WithField("lock", "loadMu").Trace("got read lock in LeastLoaded") + logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("got read lock in LeastLoaded") defer r.loadMu.RUnlock() if len(r.load) == 0 { diff --git a/app/wallet/wallet.go b/app/wallet/wallet.go index 04d0ffd6..4dae75eb 100644 --- a/app/wallet/wallet.go +++ b/app/wallet/wallet.go @@ -29,7 +29,7 @@ const pgUniqueConstraintViolation = "23505" // Retrieve gets user by internal-apis auth token. If the user does not have a wallet yet, they // are assigned an SDK and a wallet is created for them on that SDK. func GetUserWithWallet(rt *sdkrouter.Router, internalAPIHost, token, metaRemoteIP string) (*models.User, error) { - log := logger.LogF(monitor.F{monitor.TokenF: token}) + log := logger.WithFields(logrus.Fields{monitor.TokenF: token}) remoteUser, err := getRemoteUser(internalAPIHost, token, metaRemoteIP) if err != nil { @@ -95,7 +95,7 @@ func assignSDKServerToUser(user *models.User, router *sdkrouter.Router, log *log } func createDBUser(id int) (*models.User, error) { - log := logger.LogF(monitor.F{"id": id}) + log := logger.WithFields(logrus.Fields{"id": id}) u := &models.User{ID: id} err := u.InsertG(boil.Infer()) @@ -133,7 +133,7 @@ func Create(serverAddress string, userID int) error { return nil } - log := logger.LogF(monitor.F{"user_id": userID, "sdk": serverAddress}) + log := logger.WithFields(logrus.Fields{"user_id": userID, "sdk": serverAddress}) if errors.Is(err, lbrynet.ErrWalletExists) { log.Warn(err.Error()) @@ -173,7 +173,7 @@ func createWallet(addr string, userID int) error { if err != nil { return lbrynet.NewWalletError(userID, err) } - logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet created") + logger.WithFields(logrus.Fields{"user_id": userID, "sdk": addr}).Info("wallet created") return nil } @@ -186,7 +186,7 @@ func loadWallet(addr string, userID int) error { if err != nil { return lbrynet.NewWalletError(userID, err) } - logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet loaded") + logger.WithFields(logrus.Fields{"user_id": userID, "sdk": addr}).Info("wallet loaded") return nil } @@ -199,6 +199,6 @@ func UnloadWallet(addr string, userID int) error { if err != nil { return lbrynet.NewWalletError(userID, err) } - logger.LogF(monitor.F{"user_id": userID, "sdk": addr}).Info("wallet unloaded") + logger.WithFields(logrus.Fields{"user_id": userID, "sdk": addr}).Info("wallet unloaded") return nil } diff --git a/internal/ip/ip.go b/internal/ip/ip.go index e80d9d2b..4b1618bf 100644 --- a/internal/ip/ip.go +++ b/internal/ip/ip.go @@ -66,7 +66,7 @@ func IsPrivateSubnet(ipAddress net.IP) bool { return false } -// GetIPAddressForRequest returns the real IP address of the request +// AddressForRequest returns the real IP address of the request func AddressForRequest(r *http.Request) string { for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} { addresses := strings.Split(r.Header.Get(h), ",") diff --git a/internal/monitor/middleware.go b/internal/monitor/middleware.go index 4cc2bf56..6a2fcefb 100644 --- a/internal/monitor/middleware.go +++ b/internal/monitor/middleware.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/getsentry/sentry-go" + "github.com/sirupsen/logrus" ) const responseSnippetLength = 500 @@ -86,8 +87,7 @@ func ErrorLoggingMiddleware(next http.Handler) http.Handler { } func CaptureRequestError(err error, r *http.Request, w http.ResponseWriter, params ...map[string]interface{}) { - extra := map[string]interface{}{} - + extra := logrus.Fields{} if len(params) > 0 { extra = params[0] } @@ -100,12 +100,7 @@ func CaptureRequestError(err error, r *http.Request, w http.ResponseWriter, para extra["response"] = lw.ResponseSnippet } - logFields := F{} - - for k, v := range extra { - logFields[k] = v - } - httpLogger.LogF(logFields).Error(err) + httpLogger.WithFields(extra).Error(err) CaptureException(err) // if hub := sentry.GetHubFromContext(r.Context()); hub != nil { // hub.WithScope(func(scope *sentry.Scope) { diff --git a/internal/monitor/module_logger.go b/internal/monitor/module_logger.go index 7b2acd8a..e5fb7c71 100644 --- a/internal/monitor/module_logger.go +++ b/internal/monitor/module_logger.go @@ -10,50 +10,32 @@ import ( // ModuleLogger contains module-specific logger details. type ModuleLogger struct { - ModuleName string Logger *logrus.Logger - Level logrus.Level + moduleName string } -// F can be supplied to ModuleLogger's Log function for providing additional log context. -type F map[string]interface{} - // NewModuleLogger creates a new ModuleLogger instance carrying module name // for later `Log()` calls. func NewModuleLogger(moduleName string) ModuleLogger { - logger := getBaseLogger() - l := ModuleLogger{ - ModuleName: moduleName, + logger := logrus.New() + configureLogger(logger) + return ModuleLogger{ + moduleName: moduleName, Logger: logger, - Level: logger.GetLevel(), } - l.Logger.SetLevel(l.Level) - return l } -// LogF is a deprecated method, an equivalent WithFields/WithField should be used. -func (l ModuleLogger) LogF(fields F) *logrus.Entry { return l.WithFields(fields) } - // WithFields returns a new log entry containing additional info provided by fields, // which can be called upon with a corresponding logLevel. // Example: // logger.WithFields(F{"query": "..."}).Info("query error") -func (l ModuleLogger) WithFields(fields F) *logrus.Entry { - logFields := logrus.Fields{} - logFields["module"] = l.ModuleName - for k, v := range fields { - if k == TokenF && v != "" && config.IsProduction() { - logFields[k] = ValueMask - } else { - logFields[k] = v - } - } - return l.Logger.WithFields(logFields) -} +func (l ModuleLogger) WithFields(fields logrus.Fields) *logrus.Entry { + fields["module"] = l.moduleName -// WithField is a shortcut for when a single log entry field is needed. -func (l ModuleLogger) WithField(key string, value interface{}) *logrus.Entry { - return l.WithFields(F{key: value}) + if v, ok := fields[TokenF]; ok && v != "" && config.IsProduction() { + fields[TokenF] = ValueMask + } + return l.Logger.WithFields(fields) } // Log returns a new log entry for the module @@ -61,7 +43,7 @@ func (l ModuleLogger) WithField(key string, value interface{}) *logrus.Entry { // Example: // Log().Info("query error") func (l ModuleLogger) Log() *logrus.Entry { - return l.Logger.WithFields(logrus.Fields{"module": l.ModuleName}) + return l.Logger.WithFields(logrus.Fields{"module": l.moduleName}) } // Disable turns off logging output for this module logger diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index 09f67795..6ddf01b1 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -21,31 +21,27 @@ var textFormatter = logrus.TextFormatter{FullTimestamp: true, TimestampFormat: " // init magic is needed so logging is set up without calling it in every package explicitly func init() { - SetupLogging() + configureLogger(logrus.StandardLogger()) } -// SetupLogging initializes and sets a few parameters for the logging subsystem. -func SetupLogging() { +// configureLogger sets a few parameters for the logging subsystem. +func configureLogger(l *logrus.Logger) { var mode string if config.IsProduction() { mode = "production" - logrus.SetLevel(logrus.InfoLevel) - Logger.SetLevel(logrus.InfoLevel) - logrus.SetFormatter(&jsonFormatter) - Logger.SetFormatter(&jsonFormatter) + l.SetLevel(logrus.InfoLevel) + l.SetFormatter(&jsonFormatter) } else { mode = "develop" - logrus.SetLevel(logrus.TraceLevel) - Logger.SetLevel(logrus.TraceLevel) - logrus.SetFormatter(&textFormatter) - Logger.SetFormatter(&textFormatter) + l.SetLevel(logrus.TraceLevel) + l.SetFormatter(&textFormatter) } - Logger.Infof("%v, running in %v mode", version.GetFullBuildName(), mode) - Logger.Infof("logging initialized (loglevel=%v)", Logger.Level.String()) + l.Infof("%s, running in %s mode", version.GetFullBuildName(), mode) + l.Infof("logging initialized (loglevel=%s)", l.Level) configureSentry(version.GetDevVersion(), mode) } @@ -62,85 +58,3 @@ func LogSuccessfulQuery(method string, time float64, params interface{}, respons } Logger.WithFields(fields).Info("call processed") } - -// LogCachedQuery logs a cache hit for a given method -func LogCachedQuery(method string) { - Logger.WithFields(logrus.Fields{ - "method": method, - }).Debug("cached query") -} - -type QueryMonitor interface { - LogSuccessfulQuery(method string, time float64, params interface{}, response interface{}) - LogFailedQuery(method string, params interface{}, errorResponse interface{}) - Error(message string) - Errorf(message string, args ...interface{}) - Logger() *logrus.Logger -} - -func getBaseLogger() *logrus.Logger { - logger := logrus.New() - if config.IsProduction() { - logger.SetLevel(logrus.InfoLevel) - logger.SetFormatter(&jsonFormatter) - } else { - logger.SetLevel(logrus.DebugLevel) - logger.SetFormatter(&textFormatter) - } - return logger -} - -type ProxyLogger struct { - logger *logrus.Logger - entry *logrus.Entry - Level logrus.Level -} - -func NewProxyLogger() *ProxyLogger { - logger := getBaseLogger() - - l := ProxyLogger{ - logger: logger, - entry: logger.WithFields(logrus.Fields{"module": "proxy"}), - Level: logger.GetLevel(), - } - return &l -} - -func (l *ProxyLogger) LogSuccessfulQuery(method, endpoint string, userID int, time float64, params interface{}, response interface{}) { - fields := logrus.Fields{ - "method": method, - "duration": time, - "params": params, - "endpoint": endpoint, - "user_id": userID, - } - if config.ShouldLogResponses() { - fields["response"] = response - } - l.entry.WithFields(fields).Info("call processed") - -} - -func (l *ProxyLogger) LogFailedQuery(method, endpoint string, userID int, time float64, params interface{}, errorResponse interface{}) { - l.entry.WithFields(logrus.Fields{ - "method": method, - "duration": time, - "params": params, - "endpoint": endpoint, - "user_id": userID, - "response": errorResponse, - }).Error("error from the target endpoint") -} - -func (l *ProxyLogger) Error(message string) { - l.entry.Error(message) -} - -func (l *ProxyLogger) Errorf(message string, args ...interface{}) { - l.entry.Errorf(message, args...) -} - -func (l *ProxyLogger) Logger() *logrus.Logger { - return l.logger -} diff --git a/internal/monitor/monitor_test.go b/internal/monitor/monitor_test.go index 314235cf..9b5f8275 100644 --- a/internal/monitor/monitor_test.go +++ b/internal/monitor/monitor_test.go @@ -5,7 +5,7 @@ import ( "github.com/lbryio/lbrytv/config" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" "github.com/ybbus/jsonrpc" @@ -33,7 +33,7 @@ func TestLogSuccessfulQuery(t *testing.T) { LogSuccessfulQuery("resolve", 0.025, map[string]string{"urls": "one"}, response) require.Equal(t, 1, len(hook.Entries)) - require.Equal(t, log.InfoLevel, hook.LastEntry().Level) + require.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) require.Equal(t, "resolve", hook.LastEntry().Data["method"]) require.Equal(t, map[string]string{"urls": "one"}, hook.LastEntry().Data["params"]) require.Equal(t, 0.025, hook.LastEntry().Data["duration"]) @@ -42,7 +42,7 @@ func TestLogSuccessfulQuery(t *testing.T) { LogSuccessfulQuery("account_balance", 0.025, nil, nil) require.Equal(t, 2, len(hook.Entries)) - require.Equal(t, log.InfoLevel, hook.LastEntry().Level) + require.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) require.Equal(t, "account_balance", hook.LastEntry().Data["method"]) require.Equal(t, nil, hook.LastEntry().Data["params"]) require.Equal(t, 0.025, hook.LastEntry().Data["duration"]) @@ -52,74 +52,74 @@ func TestLogSuccessfulQuery(t *testing.T) { hook.Reset() } -func TestLogSuccessfulQueryWithResponse(t *testing.T) { - l := NewProxyLogger() - hook := test.NewLocal(l.logger) - - config.Override("ShouldLogResponses", true) - defer config.RestoreOverridden() - - response := &jsonrpc.RPCResponse{ - Result: map[string]interface{}{ - "available": "20.02", - "reserved": "0.0", - "reserved_subtotals": map[string]string{ - "claims": "0.0", - "supports": "0.0", - "tips": "0.0", - }, - "total": "20.02", - }, - } - - l.LogSuccessfulQuery("resolve", "sdk1.local", 123, 0.025, map[string]string{"urls": "one"}, response) - - require.Equal(t, 1, len(hook.Entries)) - require.Equal(t, log.InfoLevel, hook.LastEntry().Level) - require.Equal(t, "resolve", hook.LastEntry().Data["method"]) - require.Equal(t, "sdk1.local", hook.LastEntry().Data["endpoint"]) - require.Equal(t, 123, hook.LastEntry().Data["user_id"]) - require.Equal(t, map[string]string{"urls": "one"}, hook.LastEntry().Data["params"]) - require.Equal(t, 0.025, hook.LastEntry().Data["duration"]) - require.Equal(t, response, hook.LastEntry().Data["response"]) - require.Equal(t, "call processed", hook.LastEntry().Message) - - hook.Reset() -} - -func TestLogFailedQuery(t *testing.T) { - l := NewProxyLogger() - hook := test.NewLocal(l.logger) - - response := &jsonrpc.RPCError{ - Code: 111, - // TODO: Uncomment after lbrynet 0.31 release - // Message: "Invalid method requested: unknown_method.", - Message: "Method Not Found", - } - queryParams := map[string]string{"param1": "value1"} - l.LogFailedQuery("unknown_method", "sdk2.local", 566, 2.34, queryParams, response) - - require.Equal(t, 1, len(hook.Entries)) - require.Equal(t, log.ErrorLevel, hook.LastEntry().Level) - require.Equal(t, "unknown_method", hook.LastEntry().Data["method"]) - require.Equal(t, "sdk2.local", hook.LastEntry().Data["endpoint"]) - require.Equal(t, 566, hook.LastEntry().Data["user_id"]) - require.Equal(t, queryParams, hook.LastEntry().Data["params"]) - require.Equal(t, response, hook.LastEntry().Data["response"]) - require.Equal(t, 2.34, hook.LastEntry().Data["duration"]) - require.Equal(t, "error from the target endpoint", hook.LastEntry().Message) - - hook.Reset() -} +//func TestLogSuccessfulQueryWithResponse(t *testing.T) { +// l := NewProxyLogger() +// hook := test.NewLocal(l.logger) +// +// config.Override("ShouldLogResponses", true) +// defer config.RestoreOverridden() +// +// response := &jsonrpc.RPCResponse{ +// Result: map[string]interface{}{ +// "available": "20.02", +// "reserved": "0.0", +// "reserved_subtotals": map[string]string{ +// "claims": "0.0", +// "supports": "0.0", +// "tips": "0.0", +// }, +// "total": "20.02", +// }, +// } +// +// l.LogSuccessfulQuery("resolve", "sdk1.local", 123, 0.025, map[string]string{"urls": "one"}, response) +// +// require.Equal(t, 1, len(hook.Entries)) +// require.Equal(t, log.InfoLevel, hook.LastEntry().Level) +// require.Equal(t, "resolve", hook.LastEntry().Data["method"]) +// require.Equal(t, "sdk1.local", hook.LastEntry().Data["endpoint"]) +// require.Equal(t, 123, hook.LastEntry().Data["user_id"]) +// require.Equal(t, map[string]string{"urls": "one"}, hook.LastEntry().Data["params"]) +// require.Equal(t, 0.025, hook.LastEntry().Data["duration"]) +// require.Equal(t, response, hook.LastEntry().Data["response"]) +// require.Equal(t, "call processed", hook.LastEntry().Message) +// +// hook.Reset() +//} +// +//func TestLogFailedQuery(t *testing.T) { +// l := NewProxyLogger() +// hook := test.NewLocal(l.logger) +// +// response := &jsonrpc.RPCError{ +// Code: 111, +// // TODO: Uncomment after lbrynet 0.31 release +// // Message: "Invalid method requested: unknown_method.", +// Message: "Method Not Found", +// } +// queryParams := map[string]string{"param1": "value1"} +// l.LogFailedQuery("unknown_method", "sdk2.local", 566, 2.34, queryParams, response) +// +// require.Equal(t, 1, len(hook.Entries)) +// require.Equal(t, log.ErrorLevel, hook.LastEntry().Level) +// require.Equal(t, "unknown_method", hook.LastEntry().Data["method"]) +// require.Equal(t, "sdk2.local", hook.LastEntry().Data["endpoint"]) +// require.Equal(t, 566, hook.LastEntry().Data["user_id"]) +// require.Equal(t, queryParams, hook.LastEntry().Data["params"]) +// require.Equal(t, response, hook.LastEntry().Data["response"]) +// require.Equal(t, 2.34, hook.LastEntry().Data["duration"]) +// require.Equal(t, "error from the target endpoint", hook.LastEntry().Message) +// +// hook.Reset() +//} func TestModuleLoggerLogF(t *testing.T) { l := NewModuleLogger("storage") hook := test.NewLocal(l.Logger) - l.LogF(F{"number": 1}).Info("error!") + l.WithFields(logrus.Fields{"number": 1}).Info("error!") require.Equal(t, 1, len(hook.Entries)) - require.Equal(t, log.InfoLevel, hook.LastEntry().Level) + require.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) require.Equal(t, 1, hook.LastEntry().Data["number"]) require.Equal(t, "storage", hook.LastEntry().Data["module"]) require.Equal(t, "error!", hook.LastEntry().Message) @@ -133,7 +133,7 @@ func TestModuleLoggerLog(t *testing.T) { l.Log().Info("error!") require.Equal(t, 1, len(hook.Entries)) - require.Equal(t, log.InfoLevel, hook.LastEntry().Level) + require.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) require.Equal(t, "storage", hook.LastEntry().Data["module"]) require.Equal(t, "error!", hook.LastEntry().Message) @@ -147,7 +147,7 @@ func TestModuleLoggerMasksTokens(t *testing.T) { config.Override("Debug", false) defer config.RestoreOverridden() - l.LogF(F{"token": "SecRetT0Ken", "email": "abc@abc.com"}).Info("something happened") + l.WithFields(logrus.Fields{"token": "SecRetT0Ken", "email": "abc@abc.com"}).Info("something happened") require.Equal(t, "abc@abc.com", hook.LastEntry().Data["email"]) require.Equal(t, ValueMask, hook.LastEntry().Data["token"]) diff --git a/internal/storage/conn.go b/internal/storage/conn.go index 45a7d516..8a2aef79 100644 --- a/internal/storage/conn.go +++ b/internal/storage/conn.go @@ -5,6 +5,7 @@ import ( "time" "github.com/lbryio/lbrytv/internal/monitor" + "github.com/sirupsen/logrus" _ "github.com/jinzhu/gorm/dialects/postgres" // Dialect import "github.com/jmoiron/sqlx" @@ -61,7 +62,7 @@ func InitConn(params ConnParams) *Connection { // Connect initiates a connection to the database server defined in c.params. func (c *Connection) Connect() error { dsn := MakeDSN(c.params) - c.logger.LogF(monitor.F{"dsn": dsn}).Info("connecting to the DB") + c.logger.WithFields(logrus.Fields{"dsn": dsn}).Info("connecting to the DB") var err error var db *sqlx.DB for i := 0; i < maxDBConnectAttempts; i++ { @@ -75,7 +76,7 @@ func (c *Connection) Connect() error { } if err != nil { - c.logger.LogF(monitor.F{"dsn": dsn}).Info("DB connection failed") + c.logger.WithFields(logrus.Fields{"dsn": dsn}).Info("DB connection failed") return err } c.DB = db diff --git a/internal/storage/maintenance.go b/internal/storage/maintenance.go index d683d30e..3af53f70 100644 --- a/internal/storage/maintenance.go +++ b/internal/storage/maintenance.go @@ -4,11 +4,10 @@ import ( "fmt" "strings" - "github.com/lbryio/lbrytv/internal/monitor" - "github.com/gobuffalo/packr/v2" "github.com/lib/pq" migrate "github.com/rubenv/sql-migrate" + "github.com/sirupsen/logrus" "github.com/volatiletech/sqlboiler/queries" ) @@ -24,7 +23,7 @@ func (c *Connection) MigrateUp() { if err != nil { c.logger.Log().Panicf("failed to migrate the database up: %v", err) } - c.logger.LogF(monitor.F{"migrations_number": n}).Info("migrated the database up") + c.logger.WithFields(logrus.Fields{"migrations_number": n}).Info("migrated the database up") } // MigrateDown undoes the previous migration. @@ -37,7 +36,7 @@ func (c *Connection) MigrateDown() { if err != nil { c.logger.Log().Panicf("failed to migrate the database down: %v", err) } - c.logger.LogF(monitor.F{"migrations_number": n}).Info("migrated the database down") + c.logger.WithFields(logrus.Fields{"migrations_number": n}).Info("migrated the database down") } // Truncate purges records from the requested tables. @@ -56,7 +55,7 @@ func (c *Connection) CreateDB(dbName string) error { // fmt.Sprintf is used instead of query placeholders because postgres does not // handle them in schema-modifying queries. _, err = utilConn.DB.Exec(fmt.Sprintf("create database %s;", pq.QuoteIdentifier(dbName))) - c.logger.LogF(monitor.F{"db_name": dbName}).Info("created the database") + c.logger.WithFields(logrus.Fields{"db_name": dbName}).Info("created the database") return err } @@ -70,6 +69,6 @@ func (c *Connection) DropDB(dbName string) error { // fmt.Sprintf is used instead of query placeholders because postgres does not // handle them in schema-modifying queries. _, err = utilConn.DB.Exec(fmt.Sprintf("drop database %s;", pq.QuoteIdentifier(dbName))) - c.logger.LogF(monitor.F{"db_name": dbName}).Info("dropped the database") + c.logger.WithFields(logrus.Fields{"db_name": dbName}).Info("dropped the database") return err } diff --git a/internal/test/test.go b/internal/test/test.go index 8a675a6f..c000911a 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/lbryio/lbrytv/config" + "github.com/nsf/jsondiff" "github.com/ybbus/jsonrpc" ) diff --git a/internal/test/test_test.go b/internal/test/test_test.go index f8e441e2..ed092c5a 100644 --- a/internal/test/test_test.go +++ b/internal/test/test_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMockRPCServer(t *testing.T) { +func TestMockHTTPServer(t *testing.T) { reqChan := ReqChan() rpcServer := MockHTTPServer(reqChan) defer rpcServer.Close() diff --git a/server/server.go b/server/server.go index a571573c..9799eeb1 100644 --- a/server/server.go +++ b/server/server.go @@ -40,10 +40,9 @@ func NewServer(address string, sdkRouter *sdkrouter.Router) *Server { stopWait: 15 * time.Second, stopChan: make(chan os.Signal), listener: &http.Server{ - Addr: address, - Handler: r, - // Can't have WriteTimeout set for streaming endpoints - WriteTimeout: 0, + Addr: address, + Handler: r, + WriteTimeout: 30 * time.Second, IdleTimeout: 0, ReadHeaderTimeout: 10 * time.Second, }, From 3b842f8378f1643c490e4dfa104eea4bde53cc59 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Mon, 20 Apr 2020 13:48:54 -0400 Subject: [PATCH 14/18] drop models from coveralls --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index f6faf3e6..648192d3 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ test_circleci: go get github.com/mattn/goveralls go run . db_migrate_up go test -covermode=count -coverprofile=coverage.out ./... - goveralls -coverprofile=coverage.out -service=circle-ci -repotoken $(COVERALLS_TOKEN) + goveralls -coverprofile=coverage.out -service=circle-ci -ignore=models/ -repotoken $(COVERALLS_TOKEN) release: GO111MODULE=on goreleaser --rm-dist From ebd1b78ed4bf8fed5762f33911e9c96ee64dc0fa Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Mon, 20 Apr 2020 15:47:52 -0400 Subject: [PATCH 15/18] tests for test.AssertJsonEqual --- config/config.go | 31 +++++++------------------------ internal/test/test_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/config/config.go b/config/config.go index 230516b4..9463d290 100644 --- a/config/config.go +++ b/config/config.go @@ -11,10 +11,9 @@ import ( "github.com/spf13/viper" ) -type ConfigWrapper struct { +type configWrapper struct { Viper *viper.Viper overridden map[string]interface{} - ReadDone bool } type DBConfig struct { @@ -27,7 +26,7 @@ const lbrynetServers = "LbrynetServers" const deprecatedLbrynet = "Lbrynet" var once sync.Once -var Config *ConfigWrapper +var Config *configWrapper // overriddenValues stores overridden v values // and is initialized as an empty map in the read method @@ -37,21 +36,21 @@ func init() { Config = GetConfig() } -func GetConfig() *ConfigWrapper { +func GetConfig() *configWrapper { once.Do(func() { Config = NewConfig() }) return Config } -func NewConfig() *ConfigWrapper { - c := &ConfigWrapper{} +func NewConfig() *configWrapper { + c := &configWrapper{} c.Init() c.Read() return c } -func (c *ConfigWrapper) Init() { +func (c *configWrapper) Init() { c.overridden = make(map[string]interface{}) c.Viper = viper.New() @@ -80,12 +79,11 @@ func (c *ConfigWrapper) Init() { c.Viper.AddConfigPath("$HOME/.lbrytv") } -func (c *ConfigWrapper) Read() { +func (c *configWrapper) Read() { err := c.Viper.ReadInConfig() if err != nil { panic(err) } - c.ReadDone = true } // IsProduction is true if we are running in a production environment @@ -193,21 +191,6 @@ func GetReflectorAddress() string { return Config.Viper.GetString("ReflectorAddress") } -// GetReflectorTimeout returns reflector TCP timeout in seconds. -func GetReflectorTimeout() int64 { - return Config.Viper.GetInt64("ReflectorTimeout") -} - -// GetRefractorAddress returns refractor address in the format of host:port. -func GetRefractorAddress() string { - return Config.Viper.GetString("RefractorAddress") -} - -// GetRefractorTimeout returns refractor TCP timeout in seconds. -func GetRefractorTimeout() int64 { - return Config.Viper.GetInt64("RefractorTimeout") -} - // ShouldLogResponses enables or disables full SDK responses logging func ShouldLogResponses() bool { return Config.Viper.GetBool("ShouldLogResponses") diff --git a/internal/test/test_test.go b/internal/test/test_test.go index ed092c5a..16fbcaad 100644 --- a/internal/test/test_test.go +++ b/internal/test/test_test.go @@ -43,3 +43,28 @@ func TestMockHTTPServer(t *testing.T) { require.NoError(t, err) assert.Equal(t, string(body), "ok") } + +func TestAssertJsonEqual(t *testing.T) { + + testCases := []struct { + a, b string + same bool + }{ + {"{}", "12", false}, + {"{}", "{}", true}, + {"{}", "", false}, + {`{"a":1,"b":2}`, `{"b":2,"a":1}`, true}, + } + + for i, tc := range testCases { + testT := &testing.T{} + same := AssertJsonEqual(testT, tc.a, tc.b) + if tc.same { + assert.True(t, same, "Case %d same", i) + assert.False(t, testT.Failed(), "Case %d failure", i) + } else { + assert.False(t, same, "Case %d same", i) + assert.True(t, testT.Failed(), "Case %d failure", i) + } + } +} From bf40725f0ebad1037f8bd86a9c48bdaa5b52fb38 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Tue, 21 Apr 2020 09:00:25 -0400 Subject: [PATCH 16/18] deduping --- api/routes.go | 3 ++- app/auth/auth.go | 2 +- app/proxy/client_test.go | 7 +++---- app/proxy/errors.go | 16 ++++++++++++++++ app/proxy/handlers.go | 8 +------- app/publish/publish.go | 20 ++++++++++---------- internal/monitor/middleware.go | 18 +++++++----------- internal/monitor/sentry.go | 14 -------------- 8 files changed, 40 insertions(+), 48 deletions(-) diff --git a/api/routes.go b/api/routes.go index 8633714e..e0d7ad3f 100644 --- a/api/routes.go +++ b/api/routes.go @@ -48,10 +48,11 @@ func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { v1Router.HandleFunc("/proxy", upHandler.Handle).MatcherFunc(upHandler.CanHandle) v1Router.HandleFunc("/proxy", proxy.Handle) v1Router.HandleFunc("/metric/ui", metrics.TrackUIMetric).Methods(http.MethodPost) + v1Router.HandleFunc("/status", status.GetStatus) internalRouter := r.PathPrefix("/internal").Subrouter() internalRouter.Handle("/metrics", promhttp.Handler()) - internalRouter.Handle("/status", middlewareStack(http.HandlerFunc(status.GetStatus))) + internalRouter.Handle("/status", middlewareStack(http.HandlerFunc(status.GetStatus))) // deprecated. moved to /api/v1/status internalRouter.HandleFunc("/whoami", status.WhoAMI) } diff --git a/app/auth/auth.go b/app/auth/auth.go index 4051211c..e7ee1e89 100644 --- a/app/auth/auth.go +++ b/app/auth/auth.go @@ -9,9 +9,9 @@ import ( "github.com/lbryio/lbrytv/internal/ip" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" - "github.com/sirupsen/logrus" "github.com/gorilla/mux" + "github.com/sirupsen/logrus" ) var logger = monitor.NewModuleLogger("auth") diff --git a/app/proxy/client_test.go b/app/proxy/client_test.go index 43178198..059a1191 100644 --- a/app/proxy/client_test.go +++ b/app/proxy/client_test.go @@ -13,7 +13,7 @@ import ( "github.com/ybbus/jsonrpc" ) -func TestClientCallDoesReloadWallet(t *testing.T) { +func TestClient_CallQueryWithRetry(t *testing.T) { rand.Seed(time.Now().UnixNano()) dummyUserID := rand.Intn(100) addr := test.RandServerAddress(t) @@ -27,13 +27,12 @@ func TestClientCallDoesReloadWallet(t *testing.T) { require.NoError(t, err) q.WalletID = sdkrouter.WalletID(dummyUserID) + // check that sdk loads the wallet and retries the query if the wallet was not initially loaded + c := NewCaller(addr, dummyUserID) r, err := c.callQueryWithRetry(q) - // err = json.Unmarshal(result, response) require.NoError(t, err) require.Nil(t, r.Error) - - // TODO: check that wallet is actually reloaded? what is this test even testing? } func TestClientCallDoesNotReloadWalletAfterOtherErrors(t *testing.T) { diff --git a/app/proxy/errors.go b/app/proxy/errors.go index e2a15122..5d6e4900 100644 --- a/app/proxy/errors.go +++ b/app/proxy/errors.go @@ -3,6 +3,10 @@ package proxy import ( "encoding/json" "errors" + "net/http" + + "github.com/lbryio/lbrytv/app/auth" + "github.com/lbryio/lbrytv/internal/responses" "github.com/ybbus/jsonrpc" ) @@ -52,3 +56,15 @@ func isJSONParseError(err error) bool { var e RPCError return err != nil && errors.As(err, &e) && e.code == rpcErrorCodeJSONParse } + +func EnsureAuthenticated(ar auth.Result, w http.ResponseWriter) bool { + if !ar.AuthAttempted() { + w.Write(NewAuthRequiredError(errors.New(responses.AuthRequiredErrorMessage)).JSON()) + return false + } + if !ar.Authenticated() { + w.Write(NewForbiddenError(ar.Err()).JSON()) + return false + } + return true +} diff --git a/app/proxy/handlers.go b/app/proxy/handlers.go index 79b00639..92eb67af 100644 --- a/app/proxy/handlers.go +++ b/app/proxy/handlers.go @@ -2,7 +2,6 @@ package proxy import ( "encoding/json" - "errors" "io/ioutil" "net/http" @@ -49,12 +48,7 @@ func Handle(w http.ResponseWriter, r *http.Request) { var sdkAddress string if MethodNeedsAuth(req.Method) { authResult := auth.FromRequest(r) - if !authResult.AuthAttempted() { - w.Write(NewAuthRequiredError(errors.New(responses.AuthRequiredErrorMessage)).JSON()) - return - } - if !authResult.Authenticated() { - w.Write(NewForbiddenError(authResult.Err()).JSON()) + if !EnsureAuthenticated(authResult, w) { return } userID = authResult.User().ID diff --git a/app/publish/publish.go b/app/publish/publish.go index e4bacc65..cbc85f88 100644 --- a/app/publish/publish.go +++ b/app/publish/publish.go @@ -12,10 +12,9 @@ import ( "github.com/lbryio/lbrytv/app/auth" "github.com/lbryio/lbrytv/app/proxy" "github.com/lbryio/lbrytv/internal/monitor" - "github.com/lbryio/lbrytv/internal/responses" - "github.com/sirupsen/logrus" "github.com/gorilla/mux" + "github.com/sirupsen/logrus" ) var logger = monitor.NewModuleLogger("publish") @@ -41,13 +40,7 @@ func (h Handler) Handle(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) authResult := auth.FromRequest(r) - - if !authResult.AuthAttempted() { - w.Write(proxy.NewAuthRequiredError(errors.New(responses.AuthRequiredErrorMessage)).JSON()) - return - } - if !authResult.Authenticated() { - w.Write(proxy.NewForbiddenError(authResult.Err()).JSON()) + if !proxy.EnsureAuthenticated(authResult, w) { return } if authResult.SDKAddress == "" { @@ -69,7 +62,14 @@ func (h Handler) Handle(w http.ResponseWriter, r *http.Request) { } }() - w.Write(publish(authResult.SDKAddress, f.Name(), authResult.User().ID, []byte(r.FormValue(jsonRPCFieldName)))) + res := publish( + authResult.SDKAddress, + f.Name(), + authResult.User().ID, + []byte(r.FormValue(jsonRPCFieldName)), + ) + + w.Write(res) } func publish(sdkAddress, filename string, userID int, rawQuery []byte) []byte { diff --git a/internal/monitor/middleware.go b/internal/monitor/middleware.go index 6a2fcefb..b95f68cc 100644 --- a/internal/monitor/middleware.go +++ b/internal/monitor/middleware.go @@ -86,21 +86,17 @@ func ErrorLoggingMiddleware(next http.Handler) http.Handler { }) } -func CaptureRequestError(err error, r *http.Request, w http.ResponseWriter, params ...map[string]interface{}) { - extra := logrus.Fields{} - if len(params) > 0 { - extra = params[0] +func CaptureRequestError(err error, r *http.Request, w http.ResponseWriter) { + fields := logrus.Fields{ + "method": r.Method, + "url": r.URL.Path, } - - extra["method"] = r.Method - extra["url"] = r.URL.Path - if lw, ok := w.(*loggingWriter); ok { - extra["status"] = fmt.Sprintf("%v", lw.Status) - extra["response"] = lw.ResponseSnippet + fields["status"] = fmt.Sprintf("%v", lw.Status) + fields["response"] = lw.ResponseSnippet } - httpLogger.WithFields(extra).Error(err) + httpLogger.WithFields(fields).Error(err) CaptureException(err) // if hub := sentry.GetHubFromContext(r.Context()); hub != nil { // hub.WithScope(func(scope *sentry.Scope) { diff --git a/internal/monitor/sentry.go b/internal/monitor/sentry.go index f2c85296..b3717b02 100644 --- a/internal/monitor/sentry.go +++ b/internal/monitor/sentry.go @@ -1,8 +1,6 @@ package monitor import ( - "fmt" - "github.com/lbryio/lbrytv/config" "github.com/lbryio/lbrytv/internal/responses" @@ -58,15 +56,3 @@ func CaptureException(err error, params ...map[string]string) { sentry.CaptureException(err) }) } - -// CaptureFailedQuery sends to Sentry details of a failed daemon call. -func CaptureFailedQuery(method string, query interface{}, errorResponse interface{}) { - CaptureException( - fmt.Errorf("daemon responded with an error when calling method %v", method), - map[string]string{ - "method": method, - "query": fmt.Sprintf("%v", query), - "response": fmt.Sprintf("%v", errorResponse), - }, - ) -} From 026a8f7904d0b872a0848dac57a5446165300a80 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Tue, 21 Apr 2020 09:25:16 -0400 Subject: [PATCH 17/18] only init sentry once --- api/routes.go | 2 +- app/proxy/processors.go | 3 +- app/sdkrouter/sdkrouter.go | 25 ++++++++-------- config/config.go | 5 ---- internal/monitor/module_logger.go | 4 +-- internal/monitor/monitor.go | 48 +++++++++++++++++-------------- internal/monitor/monitor_test.go | 4 +-- internal/monitor/sentry.go | 5 ++-- version/version.go | 22 ++++++-------- 9 files changed, 57 insertions(+), 61 deletions(-) diff --git a/api/routes.go b/api/routes.go index e0d7ad3f..0cb2f65f 100644 --- a/api/routes.go +++ b/api/routes.go @@ -33,7 +33,7 @@ func InstallRoutes(r *mux.Router, sdkRouter *sdkrouter.Router) { r.Use(methodTimer) r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - http.Redirect(w, req, config.GetProjectURL(), http.StatusSeeOther) + w.Write([]byte("lbrytv api")) }) authProvider := auth.NewIAPIProvider(sdkRouter, config.GetInternalAPIHost()) diff --git a/app/proxy/processors.go b/app/proxy/processors.go index 8abc7370..4a02f820 100644 --- a/app/proxy/processors.go +++ b/app/proxy/processors.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/lbryio/lbrytv/config" - "github.com/lbryio/lbrytv/internal/monitor" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" @@ -72,7 +71,7 @@ func responseProcessorFileList(response *jsonrpc.RPCResponse) error { } func responseProcessorAccountList(response *jsonrpc.RPCResponse, query *jsonrpc.RPCRequest) error { - monitor.Logger.WithFields(log.Fields{"params": query.Params}).Info("got account_list query") + logger.WithFields(log.Fields{"params": query.Params}).Info("got account_list query") if query.Params == nil { accounts := new(ljsonrpc.AccountListResponse) diff --git a/app/sdkrouter/sdkrouter.go b/app/sdkrouter/sdkrouter.go index 03c007ec..4ffec02d 100644 --- a/app/sdkrouter/sdkrouter.go +++ b/app/sdkrouter/sdkrouter.go @@ -10,7 +10,6 @@ import ( "github.com/lbryio/lbrytv/internal/metrics" "github.com/lbryio/lbrytv/internal/monitor" "github.com/lbryio/lbrytv/models" - "github.com/sirupsen/logrus" ljsonrpc "github.com/lbryio/lbry.go/v2/extras/jsonrpc" ) @@ -51,18 +50,18 @@ func New(servers map[string]string) *Router { func (r *Router) GetAll() []*models.LbrynetServer { r.reloadServersFromDB() - logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("waiting for read lock in GetAll") + logger.Log().Trace("waiting for read lock in GetAll") r.mu.RLock() - logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("got read lock in GetAll") + logger.Log().Trace("got read lock in GetAll") defer r.mu.RUnlock() return r.servers } func (r *Router) RandomServer() *models.LbrynetServer { r.reloadServersFromDB() - logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("waiting for read lock in RandomServer") + logger.Log().Trace("waiting for read lock in RandomServer") r.mu.RLock() - logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("got read lock in RandomServer") + logger.Log().Trace("got read lock in RandomServer") defer r.mu.RUnlock() return r.servers[rand.Intn(len(r.servers))] } @@ -91,9 +90,9 @@ func (r *Router) setServers(servers []*models.LbrynetServer) { // we do this partially to make sure that ids are assigned to servers more consistently, // and partially to make tests consistent (since Go maps are not ordered) sort.Slice(servers, func(i, j int) bool { return servers[i].Name < servers[j].Name }) - logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("waiting for write lock in setServers") + logger.Log().Trace("waiting for write lock in setServers") r.mu.Lock() - logger.WithFields(logrus.Fields{"lock": "mu"}).Trace("got write lock in setServers") + logger.Log().Trace("got write lock in setServers") defer r.mu.Unlock() r.servers = servers logger.Log().Debugf("updated server list to %d servers", len(r.servers)) @@ -120,17 +119,17 @@ func (r *Router) updateLoadAndMetrics() { walletList, err := ljsonrpc.NewClient(server.Address).WalletList("", 1, 1) if err != nil { logger.Log().Errorf("lbrynet instance %s is not responding: %v", server.Address, err) - logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("waiting for write lock in updateLoadAndMetrics 1") + logger.Log().Trace("waiting for write lock in updateLoadAndMetrics 1") r.loadMu.Lock() - logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("got write lock in updateLoadAndMetrics 1") + logger.Log().Trace("got write lock in updateLoadAndMetrics 1") delete(r.load, server) r.loadMu.Unlock() metric.Set(-1.0) // TODO: maybe mark this instance as unresponsive so new users are assigned to other instances } else { - logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("waiting for write lock in updateLoadAndMetrics 2") + logger.Log().Trace("waiting for write lock in updateLoadAndMetrics 2") r.loadMu.Lock() - logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("got write lock in updateLoadAndMetrics 2") + logger.Log().Trace("got write lock in updateLoadAndMetrics 2") r.load[server] = walletList.TotalPages r.loadMu.Unlock() metric.Set(float64(walletList.TotalPages)) @@ -145,9 +144,9 @@ func (r *Router) LeastLoaded() *models.LbrynetServer { var best *models.LbrynetServer var min uint64 - logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("waiting for read lock in LeastLoaded") + logger.Log().Trace("waiting for read lock in LeastLoaded") r.loadMu.RLock() - logger.WithFields(logrus.Fields{"lock": "loadMu"}).Trace("got read lock in LeastLoaded") + logger.Log().Trace("got read lock in LeastLoaded") defer r.loadMu.RUnlock() if len(r.load) == 0 { diff --git a/config/config.go b/config/config.go index 9463d290..3f4f6a98 100644 --- a/config/config.go +++ b/config/config.go @@ -170,11 +170,6 @@ func GetSentryDSN() string { return Config.Viper.GetString("SentryDSN") } -// GetProjectURL returns publicly accessible URL for the project -func GetProjectURL() string { - return Config.Viper.GetString("ProjectURL") -} - // GetPublishSourceDir returns directory for storing published files before they're uploaded to lbrynet. // The directory needs to be accessed by the running SDK instance. func GetPublishSourceDir() string { diff --git a/internal/monitor/module_logger.go b/internal/monitor/module_logger.go index e5fb7c71..39ea75f4 100644 --- a/internal/monitor/module_logger.go +++ b/internal/monitor/module_logger.go @@ -18,7 +18,7 @@ type ModuleLogger struct { // for later `Log()` calls. func NewModuleLogger(moduleName string) ModuleLogger { logger := logrus.New() - configureLogger(logger) + configureLogLevelAndFormat(logger) return ModuleLogger{ moduleName: moduleName, Logger: logger, @@ -33,7 +33,7 @@ func (l ModuleLogger) WithFields(fields logrus.Fields) *logrus.Entry { fields["module"] = l.moduleName if v, ok := fields[TokenF]; ok && v != "" && config.IsProduction() { - fields[TokenF] = ValueMask + fields[TokenF] = valueMask } return l.Logger.WithFields(fields) } diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index 6ddf01b1..45df026c 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -7,43 +7,49 @@ import ( "github.com/sirupsen/logrus" ) -// Logger is a global instance of logrus object. -var Logger = logrus.New() +var logger = NewModuleLogger("monitor") -// TokenF is a token field name that will be stripped from logs in production mode. -const TokenF = "token" - -// ValueMask is what replaces sensitive fields contents in logs. -const ValueMask = "****" +const ( + // TokenF is a token field name that will be stripped from logs in production mode. + TokenF = "token" + // valueMask is what replaces sensitive fields contents in logs. + valueMask = "****" +) var jsonFormatter = logrus.JSONFormatter{DisableTimestamp: true} var textFormatter = logrus.TextFormatter{FullTimestamp: true, TimestampFormat: "15:04:05"} // init magic is needed so logging is set up without calling it in every package explicitly func init() { - configureLogger(logrus.StandardLogger()) -} + l := logrus.StandardLogger() + configureLogLevelAndFormat(l) -// configureLogger sets a few parameters for the logging subsystem. -func configureLogger(l *logrus.Logger) { - var mode string + l.WithFields( + version.BuildInfo(), + ).WithFields(logrus.Fields{ + "mode": mode(), + "logLevel": l.Level, + }).Infof("standard logger configured") + configureSentry(version.GetDevVersion(), mode()) +} + +func mode() string { if config.IsProduction() { - mode = "production" + return "production" + } else { + return "develop" + } +} +func configureLogLevelAndFormat(l *logrus.Logger) { + if config.IsProduction() { l.SetLevel(logrus.InfoLevel) l.SetFormatter(&jsonFormatter) } else { - mode = "develop" - l.SetLevel(logrus.TraceLevel) l.SetFormatter(&textFormatter) } - - l.Infof("%s, running in %s mode", version.GetFullBuildName(), mode) - l.Infof("logging initialized (loglevel=%s)", l.Level) - - configureSentry(version.GetDevVersion(), mode) } // LogSuccessfulQuery takes a remote method name, execution time and params and logs it @@ -56,5 +62,5 @@ func LogSuccessfulQuery(method string, time float64, params interface{}, respons if config.ShouldLogResponses() { fields["response"] = response } - Logger.WithFields(fields).Info("call processed") + logger.WithFields(fields).Info("call processed") } diff --git a/internal/monitor/monitor_test.go b/internal/monitor/monitor_test.go index 9b5f8275..36c57d26 100644 --- a/internal/monitor/monitor_test.go +++ b/internal/monitor/monitor_test.go @@ -12,7 +12,7 @@ import ( ) func TestLogSuccessfulQuery(t *testing.T) { - hook := test.NewLocal(Logger) + hook := test.NewLocal(logger.Logger) config.Override("ShouldLogResponses", false) defer config.RestoreOverridden() @@ -149,7 +149,7 @@ func TestModuleLoggerMasksTokens(t *testing.T) { l.WithFields(logrus.Fields{"token": "SecRetT0Ken", "email": "abc@abc.com"}).Info("something happened") require.Equal(t, "abc@abc.com", hook.LastEntry().Data["email"]) - require.Equal(t, ValueMask, hook.LastEntry().Data["token"]) + require.Equal(t, valueMask, hook.LastEntry().Data["token"]) hook.Reset() } diff --git a/internal/monitor/sentry.go b/internal/monitor/sentry.go index b3717b02..4806cafe 100644 --- a/internal/monitor/sentry.go +++ b/internal/monitor/sentry.go @@ -14,6 +14,7 @@ var IgnoredExceptions = []string{ func configureSentry(release, env string) { dsn := config.GetSentryDSN() if dsn == "" { + logger.Log().Info("sentry disabled (no DNS configured)") return } @@ -34,9 +35,9 @@ func configureSentry(release, env string) { }, }) if err != nil { - Logger.Errorf("sentry initialization failed: %v", err) + logger.Log().Errorf("sentry initialization failed: %v", err) } else { - Logger.Info("Sentry initialized") + logger.Log().Info("sentry initialized") } } diff --git a/version/version.go b/version/version.go index fd5ff0bc..fc53ee6e 100644 --- a/version/version.go +++ b/version/version.go @@ -3,18 +3,11 @@ package version import "fmt" var ( - version = "unknown" - commit = "unknown" - date = "unknown" + version = "unknown" + commit = "unknown" + buildDate = "unknown" ) -var appName = "lbrytv" - -// GetAppName returns main application name -func GetAppName() string { - return appName -} - // GetVersion returns current application version func GetVersion() string { return version @@ -28,7 +21,10 @@ func GetDevVersion() string { return "unknown" } -// GetFullBuildName returns current app version, commit and build time -func GetFullBuildName() string { - return fmt.Sprintf("%v %v, commit %v, built at %v", GetAppName(), GetVersion(), commit, date) +func BuildInfo() map[string]interface{} { + return map[string]interface{}{ + "buildVersion": GetVersion(), + "buildCommit": commit, + "buildDate": buildDate, + } } From 601d8e299dbfd60534fd27b778c1eb9f81d16e7e Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Wed, 22 Apr 2020 10:04:09 -0400 Subject: [PATCH 18/18] save wallet IDs in db for BC. revert this later --- app/wallet/wallet.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/app/wallet/wallet.go b/app/wallet/wallet.go index 4dae75eb..24fb8d7e 100644 --- a/app/wallet/wallet.go +++ b/app/wallet/wallet.go @@ -85,6 +85,14 @@ func assignSDKServerToUser(user *models.User, router *sdkrouter.Router, log *log if err != nil { return err } + if user.ID > 0 { + // retain BC for now. can remove this after sdk selection refactor has shown itself solid + user.WalletID = sdkrouter.WalletID(user.ID) + _, err := user.UpdateG(boil.Infer()) + if err != nil { + return err + } + } } else { // THIS SERVER CAME FROM A CONFIG FILE (prolly during testing) // TODO: handle this case better