Skip to content

Commit

Permalink
Merge pull request #7 from getmeemaw/clientEncryption
Browse files Browse the repository at this point in the history
Encrypt server shares with client stored key
  • Loading branch information
marceaul authored Jun 14, 2024
2 parents 2fd30f3 + 09693a6 commit 4464bd6
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 125 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/rs/cors v1.7.0
github.com/tabbed/pqtype v0.0.0-00010101000000-000000000000
golang.org/x/crypto v0.23.0
golang.org/x/mobile v0.0.0-20230922142353-e2f452493d57
google.golang.org/protobuf v1.34.1
Expand Down
14 changes: 6 additions & 8 deletions server/database/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 24 additions & 27 deletions server/database/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 16 additions & 17 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"nhooyr.io/websocket"
)

type ContextKey string

// identityMiddleware is a middleware used to get the userId from auth provider based on a generic bearer token provided by the client
// used by /identify and /authorize
func (server *Server) identityMiddleware(next http.Handler) http.Handler {
Expand Down Expand Up @@ -64,22 +62,23 @@ func (server *Server) identityMiddleware(next http.Handler) http.Handler {

// Store userId in context for next request in the stack

ctx = context.WithValue(ctx, ContextKey("userId"), userId)
ctx = context.WithValue(ctx, types.ContextKey("userId"), userId)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

// ServeWasm is responsible for serving the wasm module
func (server *Server) ServeWasm(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/wasm")
w.Header().Set("Access-Control-Allow-Origin", server._config.ClientOrigin)
w.Write(server._wasm)
}

// IdentifyHandler is responsible for getting a unique identifier of a user from the auth provider
// It uses identityMiddleware to get the userId from auth provider based on a generic bearer token provided by the client, then returns it
func (server *Server) IdentifyHandler(w http.ResponseWriter, r *http.Request) {
// Get userId from context
userId, ok := r.Context().Value(ContextKey("userId")).(string)
userId, ok := r.Context().Value(types.ContextKey("userId")).(string)
if !ok {
log.Println("Authorization info not found")
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
Expand All @@ -100,14 +99,14 @@ type tokenParameters struct {
// It then creates an access token linked to that userId, stores it in cache and returns it
func (server *Server) AuthorizeHandler(w http.ResponseWriter, r *http.Request) {
// Get userId from context
userId, ok := r.Context().Value(ContextKey("userId")).(string)
userId, ok := r.Context().Value(types.ContextKey("userId")).(string)
if !ok {
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
return
}

// Get metadata from context
metadata, ok := r.Context().Value(ContextKey("metadata")).(string)
metadata, ok := r.Context().Value(types.ContextKey("metadata")).(string)
if !ok {
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
return
Expand Down Expand Up @@ -161,9 +160,9 @@ func (server *Server) authMiddleware(next http.Handler) http.Handler {

// Add the userId and token to the context
ctx := r.Context()
ctx = context.WithValue(ctx, ContextKey("userId"), tokenParams.userId)
ctx = context.WithValue(ctx, ContextKey("metadata"), tokenParams.metadata)
ctx = context.WithValue(ctx, ContextKey("token"), tokenParam[0])
ctx = context.WithValue(ctx, types.ContextKey("userId"), tokenParams.userId)
ctx = context.WithValue(ctx, types.ContextKey("metadata"), tokenParams.metadata)
ctx = context.WithValue(ctx, types.ContextKey("token"), tokenParam[0])
next.ServeHTTP(w, r.WithContext(ctx))
})
}
Expand All @@ -173,13 +172,13 @@ func (server *Server) authMiddleware(next http.Handler) http.Handler {
// stores the result of dkg in DB (new wallet)
func (server *Server) DkgHandler(w http.ResponseWriter, r *http.Request) {
// Get userId and access token from context
userId, ok := r.Context().Value(ContextKey("userId")).(string)
userId, ok := r.Context().Value(types.ContextKey("userId")).(string)
if !ok {
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
return
}

token, ok := r.Context().Value(ContextKey("token")).(string)
token, ok := r.Context().Value(types.ContextKey("token")).(string)
if !ok {
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
return
Expand Down Expand Up @@ -299,7 +298,7 @@ func (server *Server) DkgHandler(w http.ResponseWriter, r *http.Request) {
// In the future: verifies that client has been able to store wallet; if not, remove in DB
// Idea: "validated" status for the wallet, which becomes True after calling DkgTwo; if False, can be overwritten
func (server *Server) DkgTwoHandler(w http.ResponseWriter, r *http.Request) {
token, ok := r.Context().Value(ContextKey("token")).(string)
token, ok := r.Context().Value(types.ContextKey("token")).(string)
if !ok {
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
return
Expand Down Expand Up @@ -328,14 +327,14 @@ func (server *Server) DkgTwoHandler(w http.ResponseWriter, r *http.Request) {
// requires a hex-encoded message to be signed (provided in URL parameter)
func (server *Server) SignHandler(w http.ResponseWriter, r *http.Request) {
// Get userId and access token from context
userId, ok := r.Context().Value(ContextKey("userId")).(string)
userId, ok := r.Context().Value(types.ContextKey("userId")).(string)
if !ok {
// If there's no userID in the context, report an error and return.
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
return
}

token, ok := r.Context().Value(ContextKey("token")).(string)
token, ok := r.Context().Value(types.ContextKey("token")).(string)
if !ok {
// If there's no token in the context, report an error and return.
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
Expand Down Expand Up @@ -439,14 +438,14 @@ func (server *Server) RecoverHandler(w http.ResponseWriter, r *http.Request) {
}

// Get userId and access token from context
userId, ok := r.Context().Value(ContextKey("userId")).(string)
userId, ok := r.Context().Value(types.ContextKey("userId")).(string)
if !ok {
// If there's no userID in the context, report an error and return.
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
return
}

token, ok := r.Context().Value(ContextKey("token")).(string)
token, ok := r.Context().Value(types.ContextKey("token")).(string)
if !ok {
// If there's no token in the context, report an error and return.
http.Error(w, "Authorization info not found", http.StatusUnauthorized)
Expand Down Expand Up @@ -569,7 +568,7 @@ func (server *Server) headerMiddleware(next http.Handler) http.Handler {
// Remove the prefix and use the remaining part as the context key
key := strings.ToLower(strings.TrimPrefix(name, headerPrefix))
// Add the first header value to the context
ctx = context.WithValue(ctx, ContextKey(key), values[0])
ctx = context.WithValue(ctx, types.ContextKey(key), values[0])
}
}

Expand Down
27 changes: 16 additions & 11 deletions server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/patrickmn/go-cache"
"github.com/rs/cors"
)

type Server struct {
Expand Down Expand Up @@ -54,22 +53,13 @@ func NewServer(vault Vault, config *Config, wasmBinary []byte, logging bool) *Se

// Router

_cors := cors.New(cors.Options{
AllowedOrigins: []string{server._config.ClientOrigin},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
// AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: 300,
})

r := chi.NewRouter()

// global middlewares
if logging {
r.Use(middleware.Logger)
}
r.Use(_cors.Handler)
r.Use(server.corsMiddleware)
// r.Use(cors.Default().Handler)
r.Use(server.headerMiddleware)

Expand Down Expand Up @@ -154,3 +144,18 @@ type Config struct {
SupabaseUrl string
SupabaseApiKey string
}

func (server *Server) corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", server._config.ClientOrigin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, M-METADATA")

if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}

next.ServeHTTP(w, r)
})
}
4 changes: 2 additions & 2 deletions server/sqlc/query.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ ON CONFLICT DO NOTHING
RETURNING *;

-- name: AddWallet :one
INSERT INTO wallets (user_id, public_address, share, params)
VALUES (sqlc.arg('UserId'), sqlc.arg('PublicAddress'), sqlc.arg('Share'), sqlc.arg('Params'))
INSERT INTO wallets (user_id, public_address, encrypted_dkg_results, nonce)
VALUES (sqlc.arg('UserId'), sqlc.arg('PublicAddress'), sqlc.arg('EncryptedDkgResults'), sqlc.arg('Nonce'))
ON CONFLICT DO NOTHING
RETURNING *;

Expand Down
4 changes: 2 additions & 2 deletions server/sqlc/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ CREATE TABLE wallets (
id BIGSERIAL PRIMARY KEY,
user_id bigint NOT NULL REFERENCES users(id) ON DELETE RESTRICT ON UPDATE RESTRICT,
public_address text NOT NULL DEFAULT '',
share text NOT NULL DEFAULT '',
params jsonb NOT NULL
encrypted_dkg_results bytea NOT NULL DEFAULT E'\\x',
nonce bytea NOT NULL DEFAULT E'\\x'
);

-- CREATE UNIQUE INDEX wallet_identifier ON public.wallets USING btree (user_id, public_address);
Expand Down
Loading

0 comments on commit 4464bd6

Please sign in to comment.