-
Notifications
You must be signed in to change notification settings - Fork 2
/
handlers_session.go
103 lines (83 loc) · 2.53 KB
/
handlers_session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package main
import (
"context"
"fmt"
"log"
"net/http"
"github.com/gorilla/sessions"
"github.com/mdbot/wiki/config"
)
const (
	// sessionName is the name under which the session is fetched from the store.
	sessionName = "wiki"

	// Keys used in sessions.Session.Values.
	sessionUserKey    = "user"
	sessionSessionKey = "session"
	sessionNoticeKey  = "notice"
	sessionErrorKey   = "error"

	// sessionKeyFormat renders a user's SessionKey; SessionHandler only accepts
	// a session whose sessionSessionKey value equals
	// fmt.Sprintf(sessionKeyFormat, user.SessionKey).
	sessionKeyFormat = "wiki:%x"
)

// contextKey is an unexported type for request-context keys. A dedicated type
// keeps these keys from colliding with plain-string keys (e.g. "user") that
// other packages may store in the same context.
type contextKey string

// Keys under which SessionHandler stores values in the request context.
const (
	contextUserKey    contextKey = "user"
	contextErrorKey   contextKey = "error"
	contextNoticeKey  contextKey = "notice"
	contextSessionKey contextKey = "session"
)
// UserProvider resolves the username stored in a session to a user record.
// SessionHandler treats a nil result as an unknown user and leaves the
// request unauthenticated.
type UserProvider interface {
	User(username string) *config.User
}
// SessionHandler returns middleware that loads the session named sessionName
// from store and exposes it to downstream handlers via the request context:
//
//   - contextUserKey: the *config.User named by the session's sessionUserKey
//     value, but only when up knows that user and the session's
//     sessionSessionKey value equals fmt.Sprintf(sessionKeyFormat, user.SessionKey);
//   - contextErrorKey / contextNoticeKey: the session's error and notice
//     values, when present;
//   - contextSessionKey: the *sessions.Session itself, which putSessionKey and
//     clearSessionKey later modify and save.
func SessionHandler(up UserProvider, store sessions.Store) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
			// The error is deliberately ignored: gorilla/sessions stores return a
			// fresh session alongside the error when an existing cookie cannot be
			// decoded, so the request simply proceeds unauthenticated.
			s, _ := store.Get(request, sessionName)
			if s == nil {
				// Defensive: with no session there is nothing to expose, and
				// getSessionForRequest will report nil downstream.
				next.ServeHTTP(writer, request)
				return
			}

			ctx := request.Context()
			// Comma-ok assertion: a non-string value under sessionUserKey is
			// treated as "no user" instead of panicking the handler.
			if username, ok := s.Values[sessionUserKey].(string); ok {
				if user := up.User(username); user != nil && s.Values[sessionSessionKey] == fmt.Sprintf(sessionKeyFormat, user.SessionKey) {
					ctx = context.WithValue(ctx, contextUserKey, user)
				}
			}
			if e, ok := s.Values[sessionErrorKey]; ok {
				ctx = context.WithValue(ctx, contextErrorKey, e)
			}
			if n, ok := s.Values[sessionNoticeKey]; ok {
				ctx = context.WithValue(ctx, contextNoticeKey, n)
			}
			ctx = context.WithValue(ctx, contextSessionKey, s)

			// Derive the request once rather than once per stored value.
			next.ServeHTTP(writer, request.WithContext(ctx))
		})
	}
}
// putSessionKey stores value under key in the session that SessionHandler
// attached to r, then writes the session to w. It does nothing when r carries
// no session. Save failures are logged rather than returned.
//
// HttpOnly and SameSite=Strict are enforced on every save, not only for new
// sessions. gorilla/sessions' built-in stores do not persist Options in the
// cookie, so an existing session is re-saved with the store's default Options
// and would otherwise lose these flags the next time it is written. MaxAge is
// still applied only to new sessions.
func putSessionKey(w http.ResponseWriter, r *http.Request, key string, value interface{}) {
	s := getSessionForRequest(r)
	if s == nil {
		return
	}
	s.Values[key] = value
	s.Options.HttpOnly = true
	s.Options.SameSite = http.SameSiteStrictMode
	if s.IsNew {
		s.Options.MaxAge = 60 * 60 * 24 * 31 // 31 days, in seconds
	}
	if err := s.Save(r, w); err != nil {
		log.Printf("Unable to save session: %v", err)
	}
}
// getUserForRequest returns the user that SessionHandler authenticated and
// stored in r's context, or nil if there is none.
func getUserForRequest(r *http.Request) *config.User {
	switch user := r.Context().Value(contextUserKey).(type) {
	case *config.User:
		return user
	default:
		return nil
	}
}
// getErrorForRequest returns the error message that SessionHandler copied from
// the session into r's context. It returns "" when the message is absent or is
// not a string.
func getErrorForRequest(r *http.Request) string {
	if msg, ok := r.Context().Value(contextErrorKey).(string); ok {
		return msg
	}
	return ""
}
// getNoticeForRequest returns the notice message that SessionHandler copied
// from the session into r's context. It returns "" when the notice is absent or
// is not a string.
func getNoticeForRequest(r *http.Request) string {
	ctx := r.Context()
	notice, isString := ctx.Value(contextNoticeKey).(string)
	if !isString {
		return ""
	}
	return notice
}
// getSessionForRequest returns the session that SessionHandler stored in r's
// context, or nil if there is none.
func getSessionForRequest(r *http.Request) *sessions.Session {
	session, found := r.Context().Value(contextSessionKey).(*sessions.Session)
	if !found {
		return nil
	}
	return session
}
// clearSessionKey removes key from the session that SessionHandler attached to
// r, then writes the session to w. It does nothing when r carries no session.
// Save failures are logged, matching putSessionKey, instead of being
// discarded.
//
// HttpOnly and SameSite=Strict are enforced on this write too, and new
// sessions get a 31-day MaxAge. gorilla/sessions' built-in stores do not
// persist Options in the cookie, so without this a re-saved session would fall
// back to the store's default cookie attributes.
func clearSessionKey(w http.ResponseWriter, r *http.Request, key string) {
	s := getSessionForRequest(r)
	if s == nil {
		return
	}
	delete(s.Values, key)
	s.Options.HttpOnly = true
	s.Options.SameSite = http.SameSiteStrictMode
	if s.IsNew {
		s.Options.MaxAge = 60 * 60 * 24 * 31 // 31 days, in seconds
	}
	if err := s.Save(r, w); err != nil {
		log.Printf("Unable to save session: %v", err)
	}
}