Skip to content

Commit

Permalink
Upgrade OAuth2 server to v4
Browse files Browse the repository at this point in the history
This involved renaming the package to use the correct module reference,
as well as finally being able to pass a context through to the database
calls. Yay, we do not need to engineer that solution!
  • Loading branch information
cjslep committed Dec 15, 2020
1 parent 931e745 commit 994f212
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 43 deletions.
2 changes: 1 addition & 1 deletion ap/c2s.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/go-fed/apcore/app"
"github.com/go-fed/apcore/framework/oauth2"
"github.com/go-fed/apcore/util"
oa2 "gopkg.in/oauth2.v3"
oa2 "github.com/go-oauth2/oauth2/v4"
)

var _ pub.SocialProtocol = &SocialBehavior{}
Expand Down
2 changes: 1 addition & 1 deletion ap/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"github.com/go-fed/apcore/framework/oauth2"
"github.com/go-fed/apcore/services"
"github.com/go-fed/apcore/util"
oa2 "gopkg.in/oauth2.v3"
oa2 "github.com/go-oauth2/oauth2/v4"
)

var _ pub.CommonBehavior = &CommonBehavior{}
Expand Down
2 changes: 1 addition & 1 deletion app/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"net/url"

"github.com/go-fed/activity/streams/vocab"
"gopkg.in/oauth2.v3"
"github.com/go-oauth2/oauth2/v4"
)

// Framework provides request-time hooks for use in handlers.
Expand Down
2 changes: 1 addition & 1 deletion framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/go-fed/activity/streams/vocab"
"github.com/go-fed/apcore/app"
"github.com/go-fed/apcore/framework/oauth2"
oa2 "gopkg.in/oauth2.v3"
oa2 "github.com/go-oauth2/oauth2/v4"
)

var _ app.Framework = &Framework{}
Expand Down
2 changes: 1 addition & 1 deletion framework/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func BuildHandler(r *Router,
return
}
if authd {
if err = oauth.RemoveByAccess(t); err != nil {
if err = oauth.RemoveByAccess(util.Context{r.Context()}, t); err != nil {
internalErrorHandler.ServeHTTP(w, r)
return
}
Expand Down
21 changes: 10 additions & 11 deletions framework/oauth2/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ import (
"github.com/go-fed/apcore/framework/web"
"github.com/go-fed/apcore/services"
"github.com/go-fed/apcore/util"
"gopkg.in/oauth2.v3"
"gopkg.in/oauth2.v3/errors"
oaerrors "gopkg.in/oauth2.v3/errors"
"gopkg.in/oauth2.v3/manage"
oaserver "gopkg.in/oauth2.v3/server"
"github.com/go-oauth2/oauth2/v4"
oaerrors "github.com/go-oauth2/oauth2/v4/errors"
"github.com/go-oauth2/oauth2/v4/manage"
oaserver "github.com/go-oauth2/oauth2/v4/server"
)

type Server struct {
Expand Down Expand Up @@ -140,16 +139,16 @@ func NewServer(c *config.Config, a app.Application, d *services.OAuth2, y *servi
}
return
})
srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
re = &errors.Response{
Error: errors.ErrServerError,
srv.SetInternalErrorHandler(func(err error) (re *oaerrors.Response) {
re = &oaerrors.Response{
Error: oaerrors.ErrServerError,
ErrorCode: http.StatusInternalServerError,
Description: "Internal Error",
StatusCode: http.StatusInternalServerError,
}
return
})
srv.SetResponseErrorHandler(func(re *errors.Response) {
srv.SetResponseErrorHandler(func(re *oaerrors.Response) {
util.ErrorLogger.Errorf("oauth2 response error: %s", re.Error.Error())
})
s = &Server{
Expand Down Expand Up @@ -188,6 +187,6 @@ func (o *Server) ValidateOAuth2AccessToken(w http.ResponseWriter, r *http.Reques
return
}

func (o *Server) RemoveByAccess(t oauth2.TokenInfo) error {
return o.m.RemoveAccessToken(t.GetAccess())
func (o *Server) RemoveByAccess(ctx util.Context, t oauth2.TokenInfo) error {
return o.m.RemoveAccessToken(ctx.Context, t.GetAccess())
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ require (
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect
github.com/go-fed/activity v1.0.1-0.20201213224552-472d90163f3a
github.com/go-fed/httpsig v0.1.1-0.20190924171022-f4c36041199d
github.com/go-oauth2/oauth2/v4 v4.1.2
github.com/google/logger v1.0.1
github.com/google/uuid v1.1.2
github.com/gorilla/mux v1.8.0
github.com/gorilla/sessions v1.2.0
github.com/jackc/pgx/v4 v4.9.0
github.com/manifoldco/promptui v0.3.2
github.com/nicksnyder/go-i18n v1.10.1 // indirect
github.com/tidwall/gjson v1.1.3
github.com/tidwall/gjson v1.6.0
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2
gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20191105091915-95d230a53780 // indirect
gopkg.in/ini.v1 v1.44.0
gopkg.in/oauth2.v3 v3.10.0
)
76 changes: 76 additions & 0 deletions go.sum

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion models/client_infos.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"database/sql"

"github.com/go-fed/apcore/util"
"gopkg.in/oauth2.v3"
"github.com/go-oauth2/oauth2/v4"
)

var _ oauth2.ClientInfo = &ClientInfo{}
Expand Down
2 changes: 1 addition & 1 deletion models/test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (
"github.com/go-fed/apcore/framework/db"
"github.com/go-fed/apcore/models"
"github.com/go-fed/apcore/util"
"github.com/go-oauth2/oauth2/v4"
_ "github.com/jackc/pgx/v4/stdlib"
"gopkg.in/oauth2.v3"
)

var dburl = flag.String("db", "", "database url to connect to")
Expand Down
2 changes: 1 addition & 1 deletion models/token_infos.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"time"

"github.com/go-fed/apcore/util"
"gopkg.in/oauth2.v3"
"github.com/go-oauth2/oauth2/v4"
)

var _ oauth2.TokenInfo = &TokenInfo{}
Expand Down
39 changes: 17 additions & 22 deletions services/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

"github.com/go-fed/apcore/models"
"github.com/go-fed/apcore/util"
"gopkg.in/oauth2.v3"
"github.com/go-oauth2/oauth2/v4"
)

var _ oauth2.ClientStore = &OAuth2{}
Expand All @@ -35,65 +35,60 @@ type OAuth2 struct {
Token *models.TokenInfos
}

// TODO: Somehow pass in context instead
func (o *OAuth2) context() util.Context {
return util.Context{context.Background()}
}

func (o *OAuth2) GetByID(id string) (ci oauth2.ClientInfo, err error) {
c := o.context()
func (o *OAuth2) GetByID(ctx context.Context, id string) (ci oauth2.ClientInfo, err error) {
c := util.Context{ctx}
return ci, doInTx(c, o.DB, func(tx *sql.Tx) error {
ci, err = o.Client.GetByID(c, tx, id)
return err
})
}

func (o *OAuth2) Create(info oauth2.TokenInfo) error {
c := o.context()
func (o *OAuth2) Create(ctx context.Context, info oauth2.TokenInfo) error {
c := util.Context{ctx}
return doInTx(c, o.DB, func(tx *sql.Tx) error {
return o.Token.Create(c, tx, info)
})
}

func (o *OAuth2) RemoveByCode(code string) error {
c := o.context()
func (o *OAuth2) RemoveByCode(ctx context.Context, code string) error {
c := util.Context{ctx}
return doInTx(c, o.DB, func(tx *sql.Tx) error {
return o.Token.RemoveByCode(c, tx, code)
})
}

func (o *OAuth2) RemoveByAccess(access string) error {
c := o.context()
func (o *OAuth2) RemoveByAccess(ctx context.Context, access string) error {
c := util.Context{ctx}
return doInTx(c, o.DB, func(tx *sql.Tx) error {
return o.Token.RemoveByAccess(c, tx, access)
})
}

func (o *OAuth2) RemoveByRefresh(refresh string) error {
c := o.context()
func (o *OAuth2) RemoveByRefresh(ctx context.Context, refresh string) error {
c := util.Context{ctx}
return doInTx(c, o.DB, func(tx *sql.Tx) error {
return o.Token.RemoveByRefresh(c, tx, refresh)
})
}

func (o *OAuth2) GetByCode(code string) (ti oauth2.TokenInfo, err error) {
c := o.context()
func (o *OAuth2) GetByCode(ctx context.Context, code string) (ti oauth2.TokenInfo, err error) {
c := util.Context{ctx}
return ti, doInTx(c, o.DB, func(tx *sql.Tx) error {
ti, err = o.Token.GetByCode(c, tx, code)
return err
})
}

func (o *OAuth2) GetByAccess(access string) (ti oauth2.TokenInfo, err error) {
c := o.context()
func (o *OAuth2) GetByAccess(ctx context.Context, access string) (ti oauth2.TokenInfo, err error) {
c := util.Context{ctx}
return ti, doInTx(c, o.DB, func(tx *sql.Tx) error {
ti, err = o.Token.GetByAccess(c, tx, access)
return err
})
}

func (o *OAuth2) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) {
c := o.context()
func (o *OAuth2) GetByRefresh(ctx context.Context, refresh string) (ti oauth2.TokenInfo, err error) {
c := util.Context{ctx}
return ti, doInTx(c, o.DB, func(tx *sql.Tx) error {
ti, err = o.Token.GetByRefresh(c, tx, refresh)
return err
Expand Down

0 comments on commit 994f212

Please sign in to comment.