Skip to content

Commit

Permalink
fix: organization tokens (#660)
Browse files Browse the repository at this point in the history
Organization tokens have been broken since they were refactored in
104013d. The JWT contains the _organization token_ _ID_ but the
middleware was checking the _organization_ _name_.

Might fix #658.
  • Loading branch information
leg100 authored Dec 7, 2023
1 parent ed9b1fd commit be82c55
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 47 deletions.
57 changes: 57 additions & 0 deletions internal/integration/organization_token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package integration

import (
"testing"

"github.com/leg100/otf/internal"
"github.com/leg100/otf/internal/api"
"github.com/leg100/otf/internal/organization"
"github.com/leg100/otf/internal/workspace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestIntegration_OrganizationTokens demonstrates the use of an organization
// token to authenticate via the API.
func TestIntegration_OrganizationTokens(t *testing.T) {
integrationTest(t)

daemon, org, ctx := setup(t, nil)

ot, token, err := daemon.CreateOrganizationToken(ctx, organization.CreateOrganizationTokenOptions{
Organization: org.Name,
})
require.NoError(t, err)
assert.Equal(t, org.Name, ot.Organization)

apiClient, err := api.NewClient(api.Config{
Address: daemon.Hostname(),
Token: string(token),
})
require.NoError(t, err)

// create some workspaces and attempt to list them using client
// authenticating with an organization token
daemon.createWorkspace(t, ctx, org)
daemon.createWorkspace(t, ctx, org)
daemon.createWorkspace(t, ctx, org)

wsClient := &workspace.Client{Client: apiClient}
got, err := wsClient.ListWorkspaces(ctx, workspace.ListOptions{
Organization: internal.String(org.Name),
})
require.NoError(t, err)
assert.Equal(t, 3, len(got.Items))

// re-generate token
_, _, err = daemon.CreateOrganizationToken(ctx, organization.CreateOrganizationTokenOptions{
Organization: org.Name,
})
require.NoError(t, err)

// access with previous token should now be refused
_, err = wsClient.ListWorkspaces(ctx, workspace.ListOptions{
Organization: internal.String(org.Name),
})
require.Equal(t, internal.ErrUnauthorized, err)
}
44 changes: 32 additions & 12 deletions internal/organization/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,26 @@ func (db *pgdb) delete(ctx context.Context, name string) error {
// Organization tokens
//

// tokenRow is the row result of a database query for organization tokens
type tokenRow struct {
OrganizationTokenID pgtype.Text `json:"organization_token_id"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
OrganizationName pgtype.Text `json:"organization_name"`
Expiry pgtype.Timestamptz `json:"expiry"`
}

func (result tokenRow) toToken() *OrganizationToken {
ot := &OrganizationToken{
ID: result.OrganizationTokenID.String,
CreatedAt: result.CreatedAt.Time.UTC(),
Organization: result.OrganizationName.String,
}
if result.Expiry.Status == pgtype.Present {
ot.Expiry = internal.Time(result.Expiry.Time.UTC())
}
return ot
}

func (db *pgdb) upsertOrganizationToken(ctx context.Context, token *OrganizationToken) error {
_, err := db.Conn(ctx).UpsertOrganizationToken(ctx, pggen.UpsertOrganizationTokenParams{
OrganizationTokenID: sql.String(token.ID),
Expand All @@ -189,23 +209,23 @@ func (db *pgdb) upsertOrganizationToken(ctx context.Context, token *Organization
}

func (db *pgdb) getOrganizationTokenByName(ctx context.Context, organization string) (*OrganizationToken, error) {
// query only returns 0 or 1 tokens
result, err := db.Conn(ctx).FindOrganizationTokensByName(ctx, sql.String(organization))
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, nil
return nil, sql.Error(err)
}
ot := &OrganizationToken{
ID: result[0].OrganizationTokenID.String,
CreatedAt: result[0].CreatedAt.Time.UTC(),
Organization: result[0].OrganizationName.String,
return tokenRow(result).toToken(), nil
}

func (db *pgdb) listOrganizationTokens(ctx context.Context, organization string) ([]*OrganizationToken, error) {
result, err := db.Conn(ctx).FindOrganizationTokens(ctx, sql.String(organization))
if err != nil {
return nil, sql.Error(err)
}
if result[0].Expiry.Status == pgtype.Present {
ot.Expiry = internal.Time(result[0].Expiry.Time.UTC())
items := make([]*OrganizationToken, len(result))
for i, r := range result {
items[i] = tokenRow(r).toToken()
}
return ot, nil
return items, nil
}

func (db *pgdb) getOrganizationTokenByID(ctx context.Context, tokenID string) (*OrganizationToken, error) {
Expand Down
38 changes: 30 additions & 8 deletions internal/organization/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type (
// GetOrganizationToken gets the organization token. If a token does not
// exist, then nil is returned without an error.
GetOrganizationToken(ctx context.Context, organization string) (*OrganizationToken, error)
ListOrganizationTokens(ctx context.Context, organization string) ([]*OrganizationToken, error)
DeleteOrganizationToken(ctx context.Context, organization string) error
WatchOrganizations(context.Context) (<-chan pubsub.Event[*Organization], func())
getOrganizationTokenByID(ctx context.Context, tokenID string) (*OrganizationToken, error)
Expand Down Expand Up @@ -114,10 +115,9 @@ func NewService(opts Options) *service {
opts.Responder.Register(tfeapi.IncludeOrganization, svc.tfeapi.include)
// Register with auth middleware the organization token and a means of
// retrieving organization corresponding to token.
opts.TokensService.RegisterKind(OrganizationTokenKind, func(ctx context.Context, organization string) (internal.Subject, error) {
return svc.GetOrganizationToken(ctx, organization)
opts.TokensService.RegisterKind(OrganizationTokenKind, func(ctx context.Context, tokenID string) (internal.Subject, error) {
return svc.getOrganizationTokenByID(ctx, tokenID)
})

return &svc
}

Expand Down Expand Up @@ -305,7 +305,33 @@ func (s *service) CreateOrganizationToken(ctx context.Context, opts CreateOrgani
}

func (s *service) GetOrganizationToken(ctx context.Context, organization string) (*OrganizationToken, error) {
return s.db.getOrganizationTokenByName(ctx, organization)
ot, err := s.db.getOrganizationTokenByName(ctx, organization)
if err != nil {
s.Error(err, "retrieving organization token", "organization", organization)
return nil, err
}
s.V(0).Info("retrieved organization token", "organization", organization)
return ot, nil
}

func (s *service) getOrganizationTokenByID(ctx context.Context, tokenID string) (*OrganizationToken, error) {
ot, err := s.db.getOrganizationTokenByID(ctx, tokenID)
if err != nil {
s.Error(err, "retrieving organization token", "token_id", tokenID)
return nil, err
}
s.V(0).Info("retrieved organization token", "token_id", tokenID, "organization", ot.Organization)
return ot, nil
}

func (s *service) ListOrganizationTokens(ctx context.Context, organization string) ([]*OrganizationToken, error) {
tokens, err := s.db.listOrganizationTokens(ctx, organization)
if err != nil {
s.Error(err, "listing organization tokens", "organization", organization)
return nil, err
}
s.V(0).Info("listed organization tokens", "organization", organization, "count", len(tokens))
return tokens, nil
}

func (s *service) DeleteOrganizationToken(ctx context.Context, organization string) error {
Expand All @@ -323,7 +349,3 @@ func (s *service) DeleteOrganizationToken(ctx context.Context, organization stri

return nil
}

func (s *service) getOrganizationTokenByID(ctx context.Context, tokenID string) (*OrganizationToken, error) {
return s.db.getOrganizationTokenByID(ctx, tokenID)
}
8 changes: 6 additions & 2 deletions internal/organization/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,16 @@ func (a *web) organizationToken(w http.ResponseWriter, r *http.Request) {
a.Error(w, err.Error(), http.StatusUnprocessableEntity)
return
}
token, err := a.svc.GetOrganizationToken(r.Context(), org)
// ListOrganizationTokens should only ever return either 0 or 1 token
tokens, err := a.svc.ListOrganizationTokens(r.Context(), org)
if err != nil {
a.Error(w, err.Error(), http.StatusInternalServerError)
return
}

var token *OrganizationToken
if len(tokens) > 0 {
token = tokens[0]
}
a.Render("organization_token.tmpl", w, struct {
OrganizationPage
Token *OrganizationToken
Expand Down
14 changes: 12 additions & 2 deletions internal/sql/pggen/agent.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

79 changes: 58 additions & 21 deletions internal/sql/pggen/organization_token.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion internal/sql/queries/organization_token.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ INSERT INTO organization_tokens (
organization_token_id = pggen.arg('organization_token_id'),
expiry = pggen.arg('expiry');

-- name: FindOrganizationTokensByName :many
-- name: FindOrganizationTokens :many
SELECT *
FROM organization_tokens
WHERE organization_name = pggen.arg('organization_name');

-- name: FindOrganizationTokensByName :one
SELECT *
FROM organization_tokens
WHERE organization_name = pggen.arg('organization_name');
Expand Down
3 changes: 3 additions & 0 deletions internal/tokens/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"strings"

"github.com/go-logr/logr"
"github.com/gorilla/mux"
"github.com/leg100/otf/internal"
otfapi "github.com/leg100/otf/internal/api"
Expand Down Expand Up @@ -35,6 +36,7 @@ var AuthenticatedPrefixes = []string{
type (
middlewareOptions struct {
GoogleIAPConfig
logr.Logger

key jwk.Key

Expand Down Expand Up @@ -94,6 +96,7 @@ func newMiddleware(opts middlewareOptions) mux.MiddlewareFunc {
} else if bearer := r.Header.Get("Authorization"); bearer != "" {
subject, err = mw.validateBearer(ctx, bearer)
if err != nil {
mw.Error(err, "validating bearer token")
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
Expand Down
Loading

0 comments on commit be82c55

Please sign in to comment.