diff --git a/database/challenge.go b/database/challenge.go index f63f8650..d6a83c3f 100644 --- a/database/challenge.go +++ b/database/challenge.go @@ -247,14 +247,14 @@ func (cr *ChallengeResponse) LoadFromReader(r io.Reader) error { // LoadString loads a PubKey from its hex-encoded string form. func (pk *PubKey) LoadString(s string) error { - bb, err := hex.DecodeString(s) + b, err := hex.DecodeString(s) if err != nil { return errors.AddContext(err, ErrInvalidPublicKey.Error()) } - if len(bb) != PubKeySize { + if len(b) != PubKeySize { return ErrInvalidPublicKey } - *pk = bb[:] + *pk = b[:] return nil } diff --git a/test/api/api_test.go b/test/api/api_test.go index 0cf7e0fd..f7356a3c 100644 --- a/test/api/api_test.go +++ b/test/api/api_test.go @@ -24,7 +24,7 @@ import ( func TestWithDBSession(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index bbf545f9..8aac6925 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -345,22 +345,9 @@ func testAPIKeysAcceptance(t *testing.T, at *test.AccountsTester) { } for _, tt := range tests { - switch tt.verb { - case http.MethodGet: - r, err = at.Request(http.MethodGet, tt.endpoint, nil, nil, nil, nil) - case http.MethodPost: - r, err = at.Request(http.MethodPost, tt.endpoint, nil, nil, nil, nil) - case http.MethodPut: - r, err = at.Request(http.MethodPut, tt.endpoint, nil, nil, nil, nil) - case http.MethodPatch: - r, err = at.Request(http.MethodPatch, tt.endpoint, nil, nil, nil, nil) - case http.MethodDelete: - r, err = at.Request(http.MethodDelete, tt.endpoint, nil, nil, nil, nil) - default: - t.Fatalf("Invalid verb: %+v", tt) - } + r, err = at.Request(tt.verb, tt.endpoint, nil, nil, nil, nil) if err == nil || r.StatusCode != http.StatusUnauthorized || !strings.Contains(err.Error(), api.ErrAPIKeyNotAllowed.Error()) { - t.Fatalf("Expected error '%s' with status %d, got '%s' with status %d. Endpoint %s %s", api.ErrAPIKeyNotAllowed, http.StatusUnauthorized, err, r.StatusCode, tt.verb, tt.endpoint) + t.Errorf("Expected error '%s' with status %d, got '%s' with status %d. Endpoint %s %s", api.ErrAPIKeyNotAllowed, http.StatusUnauthorized, err, r.StatusCode, tt.verb, tt.endpoint) } } } diff --git a/test/api/challenge_test.go b/test/api/challenge_test.go index e0f9b0f9..be8315b5 100644 --- a/test/api/challenge_test.go +++ b/test/api/challenge_test.go @@ -198,9 +198,9 @@ func testUserAddPubKey(t *testing.T, at *test.AccountsTester) { // Try to solve the challenge while logged in as a different user. // NOTE: This will consume the challenge and the user will need to request // a new one. - r, bb, err := at.UserPOST(name+"_user3@siasky.net", name+"_pass") + r, b, err := at.UserPOST(name+"_user3@siasky.net", name+"_pass") if err != nil || r.StatusCode != http.StatusOK { - t.Fatal(r.Status, err, string(bb)) + t.Fatal(r.Status, err, string(b)) } at.SetCookie(test.ExtractCookie(r)) _, status, err = at.UserPubkeyRegisterPOST(response, ed25519.Sign(sk[:], response)) diff --git a/test/api/upload_test.go b/test/api/upload_test.go index 5b15b077..98e34e07 100644 --- a/test/api/upload_test.go +++ b/test/api/upload_test.go @@ -16,12 +16,12 @@ func testUploadInfo(t *testing.T, at *test.AccountsTester) { name2 := name + "2" email := name + "@siasky.net" email2 := name2 + "@siasky.net" - r, _, err := at.CreateUserPost(email, name+"_pass") + r, _, err := at.UserPOST(email, name+"_pass") if err != nil { t.Fatal(err) } c1 := test.ExtractCookie(r) - r, _, err = at.CreateUserPost(email2, name2+"_pass") + r, _, err = at.UserPOST(email2, name2+"_pass") if err != nil { t.Fatal(err) } diff --git a/test/database/apikeys_test.go b/test/database/apikeys_test.go index c252240c..3060d361 100644 --- a/test/database/apikeys_test.go +++ b/test/database/apikeys_test.go @@ -12,7 +12,7 @@ import ( func TestAPIKeys(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } diff --git a/test/database/challenge_test.go b/test/database/challenge_test.go index 1ae6db13..83557636 100644 --- a/test/database/challenge_test.go +++ b/test/database/challenge_test.go @@ -20,7 +20,7 @@ import ( func TestValidateChallengeResponse(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -128,7 +128,7 @@ func TestValidateChallengeResponse(t *testing.T) { func TestUnconfirmedUserUpdate(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } diff --git a/test/database/configuration_test.go b/test/database/configuration_test.go index 83bae233..45de7439 100644 --- a/test/database/configuration_test.go +++ b/test/database/configuration_test.go @@ -14,7 +14,7 @@ import ( func TestConfiguration(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } diff --git a/test/database/upload_test.go b/test/database/upload_test.go index 6d47bdc3..5bbc0ccb 100644 --- a/test/database/upload_test.go +++ b/test/database/upload_test.go @@ -15,7 +15,7 @@ import ( func TestUploadsByUser(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -143,7 +143,7 @@ func TestUploadsByUser(t *testing.T) { func TestUnpinUploads(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } diff --git a/test/database/user_test.go b/test/database/user_test.go index b1d1e7a5..1926f1f6 100644 --- a/test/database/user_test.go +++ b/test/database/user_test.go @@ -22,7 +22,7 @@ import ( func TestUserByEmail(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -65,7 +65,7 @@ func TestUserByEmail(t *testing.T) { func TestUserByID(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -107,7 +107,7 @@ func TestUserByID(t *testing.T) { func TestUserByPubKey(t *testing.T) { ctx := context.Background() name := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, name, nil) + db, err := test.NewDatabase(ctx, name) if err != nil { t.Fatal(err) } @@ -158,7 +158,7 @@ func TestUserByPubKey(t *testing.T) { func TestUserByStripeID(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -202,7 +202,7 @@ func TestUserByStripeID(t *testing.T) { func TestUserBySub(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -244,7 +244,7 @@ func TestUserBySub(t *testing.T) { func TestUserConfirmEmail(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal("Failed to connect to the DB:", err) } @@ -282,7 +282,7 @@ func TestUserConfirmEmail(t *testing.T) { func TestUserCreate(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -365,7 +365,7 @@ func TestUserCreateEmailConfirmation(t *testing.T) { func TestUserDelete(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -400,7 +400,7 @@ func TestUserDelete(t *testing.T) { func TestUserSave(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -446,7 +446,7 @@ func TestUserSave(t *testing.T) { func TestUserSetStripeID(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -564,7 +564,7 @@ func TestUserPubKey(t *testing.T) { func TestUserSetTier(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -596,7 +596,7 @@ func TestUserSetTier(t *testing.T) { func TestUserStats(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } diff --git a/test/email/sender_test.go b/test/email/sender_test.go index 6056d702..83d30bff 100644 --- a/test/email/sender_test.go +++ b/test/email/sender_test.go @@ -23,7 +23,7 @@ func TestSender(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() dbName := test.DBNameForTest(t.Name()) - db, err := test.NewDatabase(ctx, dbName, nil) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -89,8 +89,7 @@ func TestSender(t *testing.T) { func TestContendingSenders(t *testing.T) { ctx := context.Background() dbName := test.DBNameForTest(t.Name()) - logger := logrus.New() - db, err := test.NewDatabase(ctx, dbName, logger) + db, err := test.NewDatabase(ctx, dbName) if err != nil { t.Fatal(err) } @@ -123,7 +122,7 @@ func TestContendingSenders(t *testing.T) { // messages from the DB and "send" them. It will stop doing that when it // reaches two executions that fail to send any messages. sender := func(serverID string) { - s, err := email.NewSender(ctx, db, logger, &test.DependencySkipSendingEmails{}, test.FauxEmailURI) + s, err := email.NewSender(ctx, db, test.NewDiscardLogger(), &test.DependencySkipSendingEmails{}, test.FauxEmailURI) if err != nil { t.Fatal(err) } diff --git a/test/tester.go b/test/tester.go index 341e1169..550f0fe6 100644 --- a/test/tester.go +++ b/test/tester.go @@ -44,20 +44,6 @@ type ( } ) -// CleanName sanitizes the input for all kinds of unwanted characters and -// replaces those with underscores. -// See https://docs.mongodb.com/manual/reference/limits/#naming-restrictions -func CleanName(s string) string { - re := regexp.MustCompile(`[/\\.\s"$*<>:|?]`) - cleanDBName := re.ReplaceAllString(s, "_") - // 64 characters is MongoDB's limit on database names. - // See https://docs.mongodb.com/manual/reference/limits/#mongodb-limit-Length-of-Database-Names - if len(cleanDBName) > 64 { - cleanDBName = cleanDBName[:64] - } - return cleanDBName -} - // ExtractCookie is a helper method which extracts the login cookie from a // response, so we can use it with future requests while testing. func ExtractCookie(r *http.Response) *http.Cookie { @@ -70,16 +56,15 @@ func ExtractCookie(r *http.Response) *http.Cookie { } // NewDatabase returns a new DB connection based on the passed parameters. -func NewDatabase(ctx context.Context, dbName string, logger *logrus.Logger) (*database.DB, error) { - return database.NewCustomDB(ctx, CleanName(dbName), DBTestCredentials(), logger) +func NewDatabase(ctx context.Context, dbName string) (*database.DB, error) { + return database.NewCustomDB(ctx, SanitizeName(dbName), DBTestCredentials(), NewDiscardLogger()) } // NewAccountsTester creates and starts a new AccountsTester service. // Use the Close method for a graceful shutdown. func NewAccountsTester(dbName string) (*AccountsTester, error) { ctx := context.Background() - logger := logrus.New() - logger.Out = ioutil.Discard + logger := NewDiscardLogger() // Initialise the environment. jwt.PortalName = testPortalAddr @@ -90,7 +75,7 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { } // Connect to the database. - db, err := NewDatabase(ctx, dbName, logger) + db, err := NewDatabase(ctx, dbName) if err != nil { return nil, errors.AddContext(err, "failed to connect to the DB") } @@ -147,6 +132,27 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { return at, nil } +// NewDiscardLogger returns a new logger that sends all output to ioutil.Discard. +func NewDiscardLogger() *logrus.Logger { + logger := logrus.New() + logger.Out = ioutil.Discard + return logger +} + +// SanitizeName sanitizes the input for all kinds of unwanted characters and +// replaces those with underscores. +// See https://docs.mongodb.com/manual/reference/limits/#naming-restrictions +func SanitizeName(s string) string { + re := regexp.MustCompile(`[/\\.\s"$*<>:|?]`) + cleanDBName := re.ReplaceAllString(s, "_") + // 64 characters is MongoDB's limit on database names. + // See https://docs.mongodb.com/manual/reference/limits/#mongodb-limit-Length-of-Database-Names + if len(cleanDBName) > 64 { + cleanDBName = cleanDBName[:64] + } + return cleanDBName +} + // ClearCredentials removes any credentials stored by this tester, such as a // cookie, token, etc. func (at *AccountsTester) ClearCredentials() { @@ -436,7 +442,7 @@ func (at *AccountsTester) UserPOST(emailAddr, password string) (*http.Response, // // NOTE: The Body of the returned response is already read and closed. func (at *AccountsTester) UserPUT(email, password, stipeID string) (api.UserGET, int, error) { - bb, err := json.Marshal(map[string]string{ + b, err := json.Marshal(map[string]string{ "email": email, "password": password, "stripeCustomerId": stipeID, @@ -445,7 +451,7 @@ func (at *AccountsTester) UserPUT(email, password, stipeID string) (api.UserGET, return api.UserGET{}, http.StatusBadRequest, err } var resp api.UserGET - r, err := at.Request(http.MethodPut, "/user", nil, bb, nil, &resp) + r, err := at.Request(http.MethodPut, "/user", nil, b, nil, &resp) return resp, r.StatusCode, err } @@ -485,32 +491,32 @@ func (at *AccountsTester) UserAPIKeysLIST() ([]api.APIKeyResponse, int, error) { // UserAPIKeysPOST performs a `POST /user/apikeys` Request. func (at *AccountsTester) UserAPIKeysPOST(body api.APIKeyPOST) (api.APIKeyResponseWithKey, int, error) { - bb, err := json.Marshal(body) + b, err := json.Marshal(body) if err != nil { return api.APIKeyResponseWithKey{}, http.StatusBadRequest, err } var result api.APIKeyResponseWithKey - r, err := at.Request(http.MethodPost, "/user/apikeys", nil, bb, nil, &result) + r, err := at.Request(http.MethodPost, "/user/apikeys", nil, b, nil, &result) return result, r.StatusCode, err } // UserAPIKeysPUT performs a `PUT /user/apikeys` Request. func (at *AccountsTester) UserAPIKeysPUT(akID primitive.ObjectID, body api.APIKeyPUT) (int, error) { - bb, err := json.Marshal(body) + b, err := json.Marshal(body) if err != nil { return http.StatusBadRequest, err } - r, err := at.Request(http.MethodPut, "/user/apikeys/"+akID.Hex(), nil, bb, nil, nil) + r, err := at.Request(http.MethodPut, "/user/apikeys/"+akID.Hex(), nil, b, nil, nil) return r.StatusCode, err } // UserAPIKeysPATCH performs a `PATCH /user/apikeys` Request. func (at *AccountsTester) UserAPIKeysPATCH(akID primitive.ObjectID, body api.APIKeyPATCH) (int, error) { - bb, err := json.Marshal(body) + b, err := json.Marshal(body) if err != nil { return http.StatusBadRequest, err } - r, err := at.Request(http.MethodPatch, "/user/apikeys/"+akID.Hex(), nil, bb, nil, nil) + r, err := at.Request(http.MethodPatch, "/user/apikeys/"+akID.Hex(), nil, b, nil, nil) return r.StatusCode, err } @@ -563,12 +569,12 @@ func (at *AccountsTester) UserPubkeyRegisterPOST(response, signature []byte) (ap Response: hex.EncodeToString(response), Signature: hex.EncodeToString(signature), } - bb, err := json.Marshal(body) + b, err := json.Marshal(body) if err != nil { return api.UserGET{}, http.StatusBadRequest, err } var result api.UserGET - r, err := at.Request(http.MethodPost, "/user/pubkey/register", nil, bb, nil, &result) + r, err := at.Request(http.MethodPost, "/user/pubkey/register", nil, b, nil, &result) return result, r.StatusCode, err } @@ -630,30 +636,10 @@ func (at *AccountsTester) UploadInfo(sl string) ([]api.UploadInfo, int, error) { if !database.ValidSkylinkHash(sl) { return nil, http.StatusBadRequest, database.ErrInvalidSkylink } - r, b, err := at.request(http.MethodGet, "/uploadinfo/"+sl, nil, nil, nil) - if err != nil { - return nil, r.StatusCode, err - } - if r.StatusCode != http.StatusOK { - return nil, r.StatusCode, errors.New(string(b)) - } var resp []api.UploadInfo - err = json.Unmarshal(b, &resp) + r, err := at.Request(http.MethodGet, "/uploadinfo/"+sl, nil, nil, nil, &resp) if err != nil { - return nil, http.StatusInternalServerError, errors.AddContext(err, "failed to marshal the body JSON") + return nil, r.StatusCode, err } return resp, r.StatusCode, nil } - -// UploadsDELETE performs `DELETE /user/uploads/:skylink` -// TODO Remove this one when we merge the tester refactoring. -func (at *AccountsTester) UploadsDELETE(skylink string) (int, error) { - r, b, err := at.request(http.MethodDelete, "/user/uploads/"+skylink, nil, nil, nil) - if err != nil { - return r.StatusCode, errors.AddContext(err, string(b)) - } - if r.StatusCode != http.StatusNoContent { - return r.StatusCode, errors.New("unexpected status code") - } - return r.StatusCode, nil -} diff --git a/test/tester_test.go b/test/tester_test.go index 176132b7..14b58332 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -2,8 +2,8 @@ package test import "testing" -// TestCleanName ensures that CleanName works as expected. -func TestCleanName(t *testing.T) { +// TestSanitizeName ensures that SanitizeName works as expected. +func TestSanitizeName(t *testing.T) { tests := map[string]struct { input string expected string @@ -21,7 +21,7 @@ func TestCleanName(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - out := CleanName(tt.input) + out := SanitizeName(tt.input) if out != tt.expected { t.Errorf("Test '%s' failed. Expected '%s', got '%s'", name, tt.expected, out) }