-
Notifications
You must be signed in to change notification settings - Fork 31
/
config.go
140 lines (127 loc) · 3.94 KB
/
config.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package main
import (
"errors"
"fmt"
"os"
"regexp"
"strconv"
"strings"
"github.com/imkira/gcp-iap-auth/jwt"
"github.com/namsral/flag"
)
// flagEnvPrefix is assigned to flag.EnvironmentPrefix in initConfig.
// NOTE(review): presumably this lets every flag also be supplied through a
// GCP_IAP_AUTH_* environment variable — confirm against namsral/flag.
const flagEnvPrefix = "GCP_IAP_AUTH"
// Package-level configuration: the JWT settings plus one variable per
// command-line flag. The flag pointers are only meaningful once initConfig
// has called flag.Parse.
var (
// cfg holds the JWT validation settings; initAudiences, initDomains and
// initPublicKeys fill in MatchAudiences, MatchDomains and PublicKeys.
cfg = &jwt.Config{}
listenAddr = flag.String("listen-addr", "0.0.0.0", "Listen address")
// listenPort defaults to empty; initServerPort then substitutes "443" when a
// TLS certificate or key path is set and "80" otherwise.
listenPort = flag.String("listen-port", "", "Listen port (default: 80 for HTTP or 443 for HTTPS)")
// audiences is mandatory: initConfig fails when it is empty.
audiences = flag.String("audiences", "", "Comma-separated list of JWT Audiences (elements can be paths like \"/projects/PROJECT_NUMBER/apps/PROJECT_ID\" or regular expressions like \"/^\\/projects\\/PROJECT_NUMBER/.*\" if you enclose them in slashes)")
domains = flag.String("domains", "", "Comma-separated list of allowed JWT Hosted Domains (hd) (optional)")
// parsedDomains lists the domains taken from --domains, in the order given;
// it is filled by initDomains.
parsedDomains = []string{}
// publicKeysPath, when empty, makes initPublicKeys call jwt.FetchPublicKeys
// instead of reading a file.
publicKeysPath = flag.String("public-keys", "", "Path to public keys file (optional)")
tlsCertPath = flag.String("tls-cert", "", "Path to TLS server's, intermediate's and CA's PEM certificate (optional)")
tlsKeyPath = flag.String("tls-key", "", "Path to TLS server's PEM key file (optional)")
backend = flag.String("backend", "", "Proxy authenticated requests to the specified URL (optional)")
emailHeader = flag.String("email-header", "X-WEBAUTH-USER", "In proxy mode, set the authenticated email address in the specified header")
)
// initConfig parses the command-line flags and then populates the global
// configuration in a fixed order: listen port, audiences, hosted domains and
// finally the public keys. It stops at, and returns, the first error.
func initConfig() error {
	flag.EnvironmentPrefix = flagEnvPrefix
	flag.CommandLine.Init(os.Args[0], flag.ExitOnError)
	flag.Parse()

	if portErr := initServerPort(); portErr != nil {
		return portErr
	}

	// --audiences has no usable default, so an empty value is rejected here
	// before any parsing is attempted.
	if *audiences == "" {
		return errors.New("You must specify --audiences")
	}
	if audErr := initAudiences(*audiences); audErr != nil {
		return audErr
	}

	initDomains(*domains)

	// Loading the keys is the last step; its result is the overall result.
	return initPublicKeys(*publicKeysPath)
}
// initServerPort fills in a default for --listen-port and validates it.
//
// When no port was given, "443" is used if a TLS certificate or key path is
// configured and "80" otherwise. The resulting value must be a decimal
// integer in the range 0-65535; anything else yields an error. Port 0 is
// deliberately accepted because Go's net.Listen treats it as "let the system
// choose a port".
func initServerPort() error {
	if *listenPort == "" {
		if *tlsCertPath != "" || *tlsKeyPath != "" {
			*listenPort = "443"
		} else {
			*listenPort = "80"
		}
	}

	// Atoi alone would accept values such as "-1" or "70000", which can never
	// be bound; reject them here so the failure is reported as a config error.
	port, err := strconv.Atoi(*listenPort)
	if err != nil || port < 0 || port > 65535 {
		return fmt.Errorf("Invalid listen port %q", *listenPort)
	}
	return nil
}
// initAudiences converts the comma-separated audiences string into a single
// regular expression, compiles it and stores the result in
// cfg.MatchAudiences. It returns an error when an element cannot be
// converted or when the combined expression does not compile.
func initAudiences(audiences string) error {
	pattern, err := extractAudiencesRegexp(audiences)
	if err != nil {
		return err
	}

	compiled, err := regexp.Compile(pattern)
	if err == nil {
		cfg.MatchAudiences = compiled
		return nil
	}
	return fmt.Errorf("Invalid audiences regular expression %q (%v)", pattern, err)
}
// extractAudiencesRegexp turns a comma-separated list of audiences into one
// regular expression that matches any of them.
//
// Each element is converted with extractAudienceRegexp; the first conversion
// error aborts the whole operation. A single element is returned unchanged.
// With several elements, each one is wrapped in a non-capturing group before
// being joined with "|", so that an inline flag group such as "(?i)" inside
// one element cannot spill over into the alternatives that follow it.
//
// The input is split on every comma, so a regular expression that itself
// contains a comma (for example a "{1,3}" repetition) cannot be expressed.
func extractAudiencesRegexp(audiences string) (string, error) {
	elems := strings.Split(audiences, ",")
	patterns := make([]string, 0, len(elems))
	for _, elem := range elems {
		pattern, err := extractAudienceRegexp(elem)
		if err != nil {
			return "", err
		}
		patterns = append(patterns, pattern)
	}

	// Nothing to isolate when there is only one alternative; keep it verbatim.
	if len(patterns) == 1 {
		return patterns[0], nil
	}

	for i, pattern := range patterns {
		patterns[i] = "(?:" + pattern + ")"
	}
	return strings.Join(patterns, "|"), nil
}
// extractAudienceRegexp converts one audience element into a regular
// expression string.
//
// An element that both starts and ends with "/" is taken to be a regular
// expression and is returned with the enclosing slashes removed; it must
// contain at least one character between them. Any other element is passed
// to parseRawAudience.
func extractAudienceRegexp(audience string) (string, error) {
	slashDelimited := strings.HasPrefix(audience, "/") && strings.HasSuffix(audience, "/")
	if !slashDelimited {
		return parseRawAudience(audience)
	}

	// "/" and "//" would leave an empty expression once the slashes are cut.
	if len(audience) <= 2 {
		return "", fmt.Errorf("Invalid audiences regular expression %q", audience)
	}

	last := len(audience) - 1
	return audience[1:last], nil
}
// parseRawAudience validates a literal audience with jwt.ParseAudience and
// returns a regular expression that matches exactly that audience: the
// parsed value is escaped with regexp.QuoteMeta and anchored with "^" and
// "$". A parse failure is reported as an error that quotes the offending
// input.
func parseRawAudience(audience string) (string, error) {
	parsed, err := jwt.ParseAudience(audience)
	if err == nil {
		literal := regexp.QuoteMeta(string(*parsed))
		return "^" + literal + "$", nil
	}
	return "", fmt.Errorf("Invalid audience %q (%v)", audience, err)
}
// initDomains rebuilds the allowed hosted-domain configuration from a
// comma-separated list.
//
// Both cfg.MatchDomains and parsedDomains are reset on every call so the two
// always describe the same set. Surrounding whitespace is stripped from each
// entry (so "a.com, b.com" works as expected), empty entries are skipped, and
// a domain that appears more than once is recorded only the first time.
// parsedDomains keeps the domains in the order they were given.
func initDomains(domains string) {
	cfg.MatchDomains = map[string]bool{}
	parsedDomains = []string{}

	for _, domain := range strings.Split(domains, ",") {
		domain = strings.TrimSpace(domain)
		// Skip blanks and anything already registered by an earlier entry.
		if domain == "" || cfg.MatchDomains[domain] {
			continue
		}
		cfg.MatchDomains[domain] = true
		parsedDomains = append(parsedDomains, domain)
	}
}
func initPublicKeys(filePath string) error {
var err error
if len(filePath) != 0 {
cfg.PublicKeys, err = loadPublicKeysFromFile(filePath)
} else {
cfg.PublicKeys, err = jwt.FetchPublicKeys()
}
if err != nil {
return err
}
return cfg.Validate()
}
// loadPublicKeysFromFile opens the file at filePath and decodes its contents
// with jwt.DecodePublicKeys. The file is closed before the function returns.
// An error from opening the file or from decoding is returned unchanged.
func loadPublicKeysFromFile(filePath string) (map[string]jwt.PublicKey, error) {
	file, openErr := os.Open(filePath)
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()

	keys, decodeErr := jwt.DecodePublicKeys(file)
	return keys, decodeErr
}