diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 8cf1ed9fc..52ef1042f 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "strconv" "strings" + "net/http" "github.com/labstack/echo/v4" ) @@ -74,10 +75,13 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { l := len(basic) if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { + // Invalid base64 shouldn't be treated as error + // instead should be treated as invalid client input b, err := base64.StdEncoding.DecodeString(auth[l+1:]) if err != nil { - return err + return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) } + cred := string(b) for i := 0; i < len(cred); i++ { if cred[i] == ':' { diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 76039db0a..4c355aa16 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -58,6 +58,12 @@ func TestBasicAuth(t *testing.T) { assert.Equal(http.StatusUnauthorized, he.Code) assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + // Invalid base64 string + auth = basic + " invalidString" + req.Header.Set(echo.HeaderAuthorization, auth) + he = h(c).(*echo.HTTPError) + assert.Equal(http.StatusBadRequest, he.Code) + // Missing Authorization header req.Header.Del(echo.HeaderAuthorization) he = h(c).(*echo.HTTPError)