diff --git a/conn.go b/conn.go index 1d9e4a22..c629019f 100644 --- a/conn.go +++ b/conn.go @@ -9,6 +9,7 @@ import ( "database/sql" "database/sql/driver" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -1143,6 +1144,10 @@ func isDriverSetting(key string) bool { return true case "password": return true + case "oauth_token": + return true + case "oauth_token_file": + return true case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": return true case "fallback_application_name": @@ -1290,59 +1295,135 @@ func (cn *conn) auth(r *readBuf, o values) { // from the server.. case 10: - sc := scram.NewClient(sha256.New, o["user"], o["password"]) - sc.Step(nil) - if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + switch saslMethod := r.string(); saslMethod { + case "SCRAM-SHA-256": + cn.saslScram(o) + case "OAUTHBEARER": + cn.saslOAuth(o) } - scOut := sc.Out() - w := cn.writeBuf('p') - w.string("SCRAM-SHA-256") - w.int32(len(scOut)) - w.bytes(scOut) - cn.send(w) + default: + errorf("unknown authentication response: %d", code) + } +} - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } +func (cn *conn) saslScram(o values) { + sc := scram.NewClient(sha256.New, o["user"], o["password"]) + sc.Step(nil) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + scOut := sc.Out() - if r.int32() != 11 { - errorf("unexpected authentication response: %q", t) - } + w := cn.writeBuf('p') + w.string("SCRAM-SHA-256") + w.int32(len(scOut)) + w.bytes(scOut) + cn.send(w) - nextStep := r.next(len(*r)) - sc.Step(nextStep) - if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) - } + t, r := cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } - scOut = sc.Out() - w = cn.writeBuf('p') - w.bytes(scOut) - cn.send(w) + if r.int32() != 11 { + errorf("unexpected authentication response: %q", t) + } - t, r = cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } + nextStep := r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } - if r.int32() != 12 { - errorf("unexpected authentication response: %q", t) + scOut = sc.Out() + w = cn.writeBuf('p') + w.bytes(scOut) + cn.send(w) + + t, r = cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 12 { + errorf("unexpected authentication response: %q", t) + } + + nextStep = r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } +} + +func (cn *conn) saslOAuth(o values) { + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1 + w := cn.writeBuf('p') + w.string("OAUTHBEARER") + + token, err := getOAuthToken(o) + if err != nil { + errorf("failed to obtain oauth token: %s", err) + } + initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01") + w.int32(len(initialResponse)) + w.bytes(initialResponse) + cn.send(w) + + t, r := cn.recv() + if t != 'R' { + errorf("unexpected oauth response: %q", t) + } + + if code := r.int32(); code != 0 { + // usually on an authentication error we should get a + // AuthenticationSASLContinue message + if code != 11 { + errorf("unexpected oauth response: %q %d", t, code) } - nextStep = r.next(len(*r)) - sc.Step(nextStep) - if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + // the AuthenticationSASLContinue does have an error payload + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2 + errResponse := struct { + Status string `json:"status"` + Scope string `json:"scope"` + OpenIDConfiguration string `json:"openid-configuration"` + }{} + err := json.Unmarshal(*r, &errResponse) + if err != nil { + errorf("invalid oauth error response") } - default: - errorf("unknown authentication response: %d", code) + errorf("oauth authentication failed '%s'", errResponse.Status) + + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3 + // we deliberately don't complete the error messaging sequence as described + // in 3.2.3 as we're going to close the connection either way + // w = cn.writeBuf('p') + // w.int32(1) + // w.bytes([]byte{0x01}) + // cn.send(w) } } +func getOAuthToken(o values) (string, error) { + if token, ok := o["oauth_token"]; ok { + return token, nil + } + + if tokenFile, ok := o["oauth_token_file"]; ok { + rawToken, err := os.ReadFile(tokenFile) + if err != nil { + return "", err + } + rawToken = bytes.TrimSuffix(rawToken, []byte("\n")) + return string(rawToken), nil + } + + return "", fmt.Errorf("no oauth token configured") +} + type format int const formatText format = 0