diff --git a/services/horizon/internal/actions/submit_transaction.go b/services/horizon/internal/actions/submit_transaction.go index 97b3dd3580..8b08fc4638 100644 --- a/services/horizon/internal/actions/submit_transaction.go +++ b/services/horizon/internal/actions/submit_transaction.go @@ -78,7 +78,7 @@ func (handler SubmitTransactionHandler) response(r *http.Request, info envelopeI } if result.Err == txsub.ErrCanceled { - return nil, &hProblem.Timeout + return nil, &hProblem.ClientDisconnected } switch err := result.Err.(type) { @@ -153,6 +153,6 @@ func (handler SubmitTransactionHandler) GetResource(w HeaderWriter, r *http.Requ case result := <-submission: return handler.response(r, info, result) case <-r.Context().Done(): - return nil, &hProblem.Timeout + return nil, &hProblem.ClientDisconnected } } diff --git a/services/horizon/internal/actions_transaction_test.go b/services/horizon/internal/actions_transaction_test.go index 605a0a8f69..1a69d8cf7e 100644 --- a/services/horizon/internal/actions_transaction_test.go +++ b/services/horizon/internal/actions_transaction_test.go @@ -1,7 +1,9 @@ package horizon import ( + "context" "encoding/json" + "net/http" "net/url" "testing" @@ -289,6 +291,62 @@ func TestTransactionActions_Post(t *testing.T) { ht.Assert.Equal(200, w.Code) } + + + +func TestTransactionActions_Post_ClientDisconnect(t *testing.T) { + ht := StartHTTPTest(t, "base") + defer ht.Finish() + + // Pass Synced check + ht.App.coreState.SetState(corestate.State{Synced: true}) + + tx := xdr.TransactionEnvelope{ + Type: xdr.EnvelopeTypeEnvelopeTypeTxV0, + V0: &xdr.TransactionV0Envelope{ + Tx: xdr.TransactionV0{ + SourceAccountEd25519: *xdr.MustAddress("GBRPYHIL2CI3FNQ4BXLFMNDLFJUNPU2HY3ZMFSHONUCEOASW7QC7OX2H").Ed25519, + Fee: 100, + SeqNum: 1, + Operations: []xdr.Operation{ + { + Body: xdr.OperationBody{ + Type: xdr.OperationTypeCreateAccount, + CreateAccountOp: &xdr.CreateAccountOp{ + Destination: xdr.MustAddress("GCXKG6RN4ONIEPCMNFB732A436Z5PNDSRLGWK7GBLCMQLIFO4S7EYWVU"), + StartingBalance: 1000000000, + }, + }, + }, + }, + }, + Signatures: []xdr.DecoratedSignature{ + { + Hint: xdr.SignatureHint{86, 252, 5, 247}, + Signature: xdr.Signature{131, 206, 171, 228, 64, 20, 40, 52, 2, 98, 124, 244, 87, 14, 130, 225, 190, 220, 156, 79, 121, 69, 60, 36, 57, 214, 9, 29, 176, 81, 218, 4, 213, 176, 211, 148, 191, 86, 21, 180, 94, 9, 43, 208, 32, 79, 19, 131, 90, 21, 93, 138, 153, 203, 55, 103, 2, 230, 137, 190, 19, 70, 179, 11}, + }, + }, + }, + } + + txStr, err := xdr.MarshalBase64(tx) + assert.NoError(t, err) + form := url.Values{"tx": []string{txStr}} + + // existing transaction + + w := ht.Post("/transactions", form, + func (req *http.Request) (*http.Request){ + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + cancel() + return req + }) + + ht.Assert.Equal(499, w.Code) +} + + func TestTransactionActions_PostSuccessful(t *testing.T) { ht := StartHTTPTest(t, "failed_transactions") defer ht.Finish() diff --git a/services/horizon/internal/app_test.go b/services/horizon/internal/app_test.go index d6a2194d1d..5abf2295d1 100644 --- a/services/horizon/internal/app_test.go +++ b/services/horizon/internal/app_test.go @@ -20,8 +20,9 @@ func TestGenericHTTPFeatures(t *testing.T) { ht.Assert.Empty(w.HeaderMap.Get("Access-Control-Allow-Origin")) } - w = ht.Get("/", func(r *http.Request) { + w = ht.Get("/", func(r *http.Request) (*http.Request){ r.Header.Set("Origin", "somewhere.com") + return r }) if ht.Assert.Equal(200, w.Code) { diff --git a/services/horizon/internal/httpt_test.go b/services/horizon/internal/httpt_test.go index a0ff96f521..a93c718d41 100644 --- a/services/horizon/internal/httpt_test.go +++ b/services/horizon/internal/httpt_test.go @@ -68,7 +68,7 @@ func StartHTTPTestWithoutScenario(t *testing.T) *HTTPT { // Get delegates to the test's request helper func (ht *HTTPT) Get( path string, - fn ...func(*http.Request), + fn ...func(*http.Request)(*http.Request), ) *httptest.ResponseRecorder { return ht.RH.Get(path, fn...) } @@ -77,7 +77,7 @@ func (ht *HTTPT) Get( func (ht *HTTPT) GetWithParams( path string, queryParams url.Values, - fn ...func(*http.Request), + fn ...func(*http.Request)(*http.Request), ) *httptest.ResponseRecorder { return ht.RH.Get(path+"?"+queryParams.Encode(), fn...) } @@ -93,7 +93,7 @@ func (ht *HTTPT) Finish() { func (ht *HTTPT) Post( path string, form url.Values, - mods ...func(*http.Request), + mods ...func(*http.Request)(*http.Request), ) *httptest.ResponseRecorder { return ht.RH.Post(path, form, mods...) } diff --git a/services/horizon/internal/httpx/server.go b/services/horizon/internal/httpx/server.go index a5a51da5eb..85eaa5bc96 100644 --- a/services/horizon/internal/httpx/server.go +++ b/services/horizon/internal/httpx/server.go @@ -55,7 +55,7 @@ func init() { problem.RegisterError(db2.ErrInvalidOrder, problem.BadRequest) problem.RegisterError(sse.ErrRateLimited, hProblem.RateLimitExceeded) problem.RegisterError(context.DeadlineExceeded, hProblem.Timeout) - problem.RegisterError(context.Canceled, hProblem.ServiceUnavailable) + problem.RegisterError(context.Canceled, hProblem.ClientDisconnected) problem.RegisterError(db.ErrCancelled, hProblem.ServiceUnavailable) problem.RegisterError(db.ErrConflictWithRecovery, hProblem.ServiceUnavailable) problem.RegisterError(db.ErrBadConnection, hProblem.ServiceUnavailable) diff --git a/services/horizon/internal/middleware_test.go b/services/horizon/internal/middleware_test.go index b411b23e3d..6567e5347c 100644 --- a/services/horizon/internal/middleware_test.go +++ b/services/horizon/internal/middleware_test.go @@ -27,15 +27,17 @@ import ( "github.com/stellar/go/xdr" ) -func requestHelperRemoteAddr(ip string) func(r *http.Request) { - return func(r *http.Request) { +func requestHelperRemoteAddr(ip string) func(r *http.Request)(*http.Request) { + return func(r *http.Request) (*http.Request) { r.RemoteAddr = ip + return r } } -func requestHelperXFF(xff string) func(r *http.Request) { - return func(r *http.Request) { +func requestHelperXFF(xff string) func(r *http.Request) (*http.Request) { + return func(r *http.Request) (*http.Request) { r.Header.Set("X-Forwarded-For", xff) + return r } } @@ -306,6 +308,35 @@ func TestStateMiddleware(t *testing.T) { } } +func TestClientDisconnect(t *testing.T) { + tt := test.Start(t) + defer tt.Finish() + test.ResetHorizonDB(t, tt.HorizonDB) + + request, err := http.NewRequest("GET", "http://localhost/", nil) + tt.Assert.NoError(err) + + endpoint := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + stateMiddleware := &httpx.StateMiddleware{ + HorizonSession: tt.HorizonSession(), + NoStateVerification: true, + } + handler := chi.NewRouter() + handler.With(stateMiddleware.Wrap).MethodFunc("GET", "/", endpoint) + w := httptest.NewRecorder() + + ctx, cancel := context.WithCancel(request.Context()) + request = request.WithContext(ctx) + // cancel invocation simulates client disconnect in the context + cancel() + + handler.ServeHTTP(w, request) + tt.Assert.Equal(499, w.Code) +} + func TestCheckHistoryStaleMiddleware(t *testing.T) { tt := test.Start(t) defer tt.Finish() diff --git a/services/horizon/internal/render/problem/problem.go b/services/horizon/internal/render/problem/problem.go index 046c3b6604..f099c54c5e 100644 --- a/services/horizon/internal/render/problem/problem.go +++ b/services/horizon/internal/render/problem/problem.go @@ -8,6 +8,17 @@ import ( // Well-known and reused problems below: var ( + + // ClientDisconnected, represented by a non-standard HTTP status code of 499, which was introduced by + // nginix.org(https://www.nginx.com/resources/wiki/extending/api/http/) as a way to capture this state. Use it as a shortcut + // in your actions. + ClientDisconnected = problem.P{ + Type: "client_disconnected", + Title: "Client Disconnected", + Status: 499, + Detail: "The client has closed the connection.", + } + // ServiceUnavailable is a well-known problem type. Use it as a shortcut // in your actions. ServiceUnavailable = problem.P{ diff --git a/services/horizon/internal/render/problem/problem_test.go b/services/horizon/internal/render/problem/problem_test.go index d193079f41..707c75b658 100644 --- a/services/horizon/internal/render/problem/problem_test.go +++ b/services/horizon/internal/render/problem/problem_test.go @@ -25,6 +25,7 @@ func TestCommonProblems(t *testing.T) { }{ {"NotFound", problem.NotFound, 404}, {"RateLimitExceeded", RateLimitExceeded, 429}, + {"ClientDisconneted", ClientDisconnected, 499}, } for _, tc := range testCases { diff --git a/services/horizon/internal/test/http.go b/services/horizon/internal/test/http.go index 2e42e681ee..0524ae5ed2 100644 --- a/services/horizon/internal/test/http.go +++ b/services/horizon/internal/test/http.go @@ -10,20 +10,22 @@ import ( ) type RequestHelper interface { - Get(string, ...func(*http.Request)) *httptest.ResponseRecorder - Post(string, url.Values, ...func(*http.Request)) *httptest.ResponseRecorder + Get(string, ...func(*http.Request)(*http.Request)) *httptest.ResponseRecorder + Post(string, url.Values, ...func(*http.Request)(*http.Request)) *httptest.ResponseRecorder } type requestHelper struct { router *chi.Mux } -func RequestHelperRaw(r *http.Request) { +func RequestHelperRaw(r *http.Request) (*http.Request) { r.Header.Set("Accept", "application/octet-stream") + return r } -func RequestHelperStreaming(r *http.Request) { +func RequestHelperStreaming(r *http.Request) (*http.Request){ r.Header.Set("Accept", "text/event-stream") + return r } func NewRequestHelper(router *chi.Mux) RequestHelper { @@ -32,7 +34,7 @@ func NewRequestHelper(router *chi.Mux) RequestHelper { func (rh *requestHelper) Get( path string, - mods ...func(*http.Request), + mods ...func(*http.Request)(*http.Request), ) *httptest.ResponseRecorder { req, _ := http.NewRequest("GET", path, nil) @@ -42,7 +44,7 @@ func (rh *requestHelper) Get( func (rh *requestHelper) Post( path string, form url.Values, - mods ...func(*http.Request), + mods ...func(*http.Request)(*http.Request), ) *httptest.ResponseRecorder { body := strings.NewReader(form.Encode()) @@ -53,13 +55,13 @@ func (rh *requestHelper) Post( func (rh *requestHelper) Execute( req *http.Request, - requestModFns []func(*http.Request), + requestModFns []func(*http.Request)(*http.Request), ) *httptest.ResponseRecorder { req.RemoteAddr = "127.0.0.1" req.Host = "localhost" for _, fn := range requestModFns { - fn(req) + req = fn(req) } w := httptest.NewRecorder()