From 18f65b96f1e2ae0de0b5ca3677f485c0696608a0 Mon Sep 17 00:00:00 2001 From: fmoor Date: Tue, 1 Dec 2020 11:51:47 -0700 Subject: [PATCH] support similar connect options as other clients This change adopts a connect API similar to the java script client accepting a dsn string and an options object. fixes https://github.com/edgedb/edgedb-go/issues/22 --- .github/workflows/tests.yml | 4 +- connect.go | 12 +- connect_test.go | 11 +- connutils.go | 391 ++++++++++++++++++++++++++++++++++ connutils_test.go | 411 ++++++++++++++++++++++++++++++++++++ credentials.go | 107 ++++++++++ credentials1.json | 6 + credentials_test.go | 80 +++++++ edgedb.go | 55 +++-- error.go | 33 +-- fallthrough.go | 10 +- granular_flow.go | 4 +- main_test.go | 42 ++-- options.go | 122 +++++------ options_test.go | 84 -------- pool.go | 44 ++-- pool_test.go | 59 +++--- query_test.go | 4 +- tutorial_test.go | 5 +- unix.go | 21 ++ windows.go | 21 ++ 21 files changed, 1273 insertions(+), 253 deletions(-) create mode 100644 connutils.go create mode 100644 connutils_test.go create mode 100644 credentials.go create mode 100644 credentials1.json create mode 100644 credentials_test.go delete mode 100644 options_test.go create mode 100644 unix.go create mode 100644 windows.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a07cd14..8d0b6d60 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,7 @@ jobs: - name: Install EdgeDB env: OS_NAME: ${{ matrix.os }} - SLOT: 1-alpha4 + SLOT: 1-alpha7-dev5249 run: | curl https://packages.edgedb.com/keys/edgedb.asc \ | sudo apt-key add - @@ -55,6 +55,6 @@ jobs: - name: Test env: - EDGEDB_SLOT: 1-alpha4 + EDGEDB_SLOT: 1-alpha7-dev5249 run: | go test -race ./... diff --git a/connect.go b/connect.go index 6bc246cf..2d014972 100644 --- a/connect.go +++ b/connect.go @@ -25,16 +25,16 @@ import ( "github.com/xdg/scram" ) -func (c *baseConn) connect(ctx context.Context, opts *Options) error { +func (c *baseConn) connect(ctx context.Context, cfg *connConfig) error { buf := buff.New(nil) buf.BeginMessage(message.ClientHandshake) buf.PushUint16(0) // major version buf.PushUint16(8) // minor version buf.PushUint16(2) // number of parameters buf.PushString("database") - buf.PushString(opts.Database) + buf.PushString(cfg.database) buf.PushString("user") - buf.PushString(opts.User) + buf.PushString(cfg.user) buf.PushUint16(0) // no extensions buf.EndMessage() @@ -78,7 +78,7 @@ func (c *baseConn) connect(ctx context.Context, opts *Options) error { buf.PopBytes() } - if err := c.authenticate(ctx, opts); err != nil { + if err := c.authenticate(ctx, cfg); err != nil { return err } case message.ErrorResponse: @@ -90,8 +90,8 @@ func (c *baseConn) connect(ctx context.Context, opts *Options) error { return nil } -func (c *baseConn) authenticate(ctx context.Context, opts *Options) error { - client, err := scram.SHA256.NewClient(opts.User, opts.Password, "") +func (c *baseConn) authenticate(ctx context.Context, cfg *connConfig) error { + client, err := scram.SHA256.NewClient(cfg.user, cfg.password, "") if err != nil { return err } diff --git a/connect_test.go b/connect_test.go index 57ba88a6..bfec7b45 100644 --- a/connect_test.go +++ b/connect_test.go @@ -26,17 +26,10 @@ import ( ) func TestAuth(t *testing.T) { - var host string - if opts.admin { - host = "localhost" - } else { - host = opts.Host - } - ctx := context.Background() conn, err := ConnectOne(ctx, Options{ - Host: host, - Port: opts.Port, + Hosts: opts.Hosts, + Ports: opts.Ports, User: "user_with_password", Password: "secret", Database: opts.Database, diff --git a/connutils.go b/connutils.go new file mode 100644 index 00000000..6ea06f63 --- /dev/null +++ b/connutils.go @@ -0,0 +1,391 @@ +// This source file is part of the EdgeDB open source project. +// +// Copyright 2020-present EdgeDB Inc. and the EdgeDB authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package edgedb + +import ( + "fmt" + "net/url" + "os" + usr "os/user" + "path" + "regexp" + "strconv" + "strings" + "time" +) + +const edgedbPort = 5656 + +type connConfig struct { + addrs []dialArgs + user string + password string + database string + connectTimeout time.Duration + serverSettings map[string]string +} + +type dialArgs struct { + network string + address string +} + +func validatePortSpec(hosts []string, ports []int) ([]int, error) { + var result []int + if len(ports) > 1 { + if len(ports) != len(hosts) { + return nil, fmt.Errorf( + "could not match %v port numbers to %v hosts%w", + len(ports), len(hosts), ErrInterfaceViolation, + ) + } + + result = ports + } else { + result = make([]int, len(hosts)) + for i := 0; i < len(hosts); i++ { + result[i] = ports[0] + } + } + + return result, nil +} + +func parsePortSpec(spec string) ([]int, error) { + ports := make([]int, 0, strings.Count(spec, ",")) + + for _, p := range strings.Split(spec, ",") { + port, err := strconv.Atoi(p) + if err != nil { + return nil, fmt.Errorf( + "invalid port %q found in %q: %v%w", + p, spec, err, ErrBadConfig, + ) + } + + ports = append(ports, port) + } + + return ports, nil +} + +func parseHostList(hostList string, ports []int) ([]string, []int, error) { + hostSpecs := strings.Split(hostList, ",") + + var ( + err error + defaultPorts []int + hostListPorts []int + ) + + if len(ports) == 0 { + if portSpec := os.Getenv("EDGEDB_PORT"); portSpec != "" { + defaultPorts, err = parsePortSpec(portSpec) + if err != nil { + return nil, nil, err + } + } else { + defaultPorts = []int{edgedbPort} + } + + defaultPorts, err = validatePortSpec(hostSpecs, defaultPorts) + if err != nil { + return nil, nil, err + } + } else { + ports, err = validatePortSpec(hostSpecs, ports) + if err != nil { + return nil, nil, err + } + } + + hosts := make([]string, 0, len(hostSpecs)) + for i, hostSpec := range hostSpecs { + addr, hostSpecPort := partition(hostSpec, ":") + hosts = append(hosts, addr) + + if len(ports) == 0 { + if hostSpecPort != "" { + port, err := strconv.Atoi(hostSpecPort) + if err != nil { + return nil, nil, fmt.Errorf( + "invalid port %q found in %q: %v%w", + hostSpecPort, hostSpec, err, ErrBadConfig, + ) + } + hostListPorts = append(hostListPorts, port) + } else { + hostListPorts = append(hostListPorts, defaultPorts[i]) + } + } + } + + if len(ports) == 0 { + ports = hostListPorts + } + + return hosts, ports, nil +} + +func partition(s, sep string) (string, string) { + list := strings.SplitN(s, sep, 2) + switch len(list) { + case 2: + return list[0], list[1] + case 1: + return list[0], "" + default: + return "", "" + } +} + +func pop(m map[string]string, key string) string { + v, ok := m[key] + if ok { + delete(m, key) + } + return v +} + +func parseConnectDSNAndArgs( + dsn string, + opts *Options, +) (*connConfig, error) { + usingCredentials := false + hosts := opts.Hosts + ports := opts.Ports + user := opts.User + password := opts.Password + database := opts.Database + + serverSettings := make(map[string]string, len(opts.ServerSettings)) + for k, v := range opts.ServerSettings { + serverSettings[k] = v + } + + if dsn != "" && strings.HasPrefix(dsn, "edgedb://") { + parsed, err := url.Parse(dsn) + if err != nil { + return nil, fmt.Errorf( + "could not parse %q: %v%w", dsn, err, ErrBadConfig) + } + + if parsed.Scheme != "edgedb" { + return nil, fmt.Errorf( + `invalid DSN: scheme is expected to be "edgedb", got %q%w`, + dsn, ErrBadConfig) + } + + if len(hosts) == 0 && parsed.Host != "" { + hosts, ports, err = parseHostList(parsed.Host, ports) + if err != nil { + return nil, err + } + } + + if database == "" { + database = strings.TrimLeft(parsed.Path, "/") + } + + if user == "" { + user = parsed.User.Username() + } + + if password == "" { + password, _ = parsed.User.Password() + } + + if parsed.RawQuery != "" { + q, err := url.ParseQuery(parsed.RawQuery) + if err != nil { + return nil, fmt.Errorf( + "invalid DSN %q: %v%w", dsn, err, ErrBadConfig) + } + + query := make(map[string]string, len(q)) + for key, val := range q { + query[key] = val[len(val)-1] + } + + if val := pop(query, "port"); val != "" && len(ports) == 0 { + ports, err = parsePortSpec(val) + if err != nil { + return nil, err + } + } + + if val := pop(query, "host"); val != "" && len(hosts) == 0 { + hosts, ports, err = parseHostList(val, ports) + if err != nil { + return nil, err + } + } + + if val := pop(query, "dbname"); database == "" { + database = val + } + + if val := pop(query, "database"); database == "" { + database = val + } + + if val := pop(query, "user"); user == "" { + user = val + } + + if val := pop(query, "password"); password == "" { + password = val + } + + for k, v := range query { + serverSettings[k] = v + } + } + } else if dsn != "" { + isIdentifier := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`) + if !isIdentifier.Match([]byte(dsn)) { + return nil, fmt.Errorf( + "dsn %q is neither a edgedb:// URI nor valid instance name%w", + dsn, ErrBadConfig, + ) + } + + usingCredentials = true + + u, err := usr.Current() + if err != nil { + return nil, err + } + + file := path.Join(u.HomeDir, ".edgedb", "credentials", dsn+".json") + creds, err := readCredentials(file) + if err != nil { + return nil, fmt.Errorf( + "cannot read credentials of instance %q: %v%w", + dsn, err, ErrClientFault, + ) + } + + if len(ports) == 0 { + ports = []int{creds.port} + } + + if user == "" { + user = creds.user + } + + if len(hosts) == 0 && creds.host != "" { + hosts = []string{creds.host} + } + + if password == "" { + password = creds.password + } + + if database == "" { + database = creds.database + } + } + + var err error + + if spec := os.Getenv("EDGEDB_HOST"); len(hosts) == 0 && spec != "" { + hosts, ports, err = parseHostList(spec, ports) + if err != nil { + return nil, err + } + } + + if len(hosts) == 0 { + if !usingCredentials { + hosts = append(hosts, defaultHosts...) + } + hosts = append(hosts, "localhost") + } + + if len(ports) == 0 { + if portSpec := os.Getenv("EDGEDB_PORT"); portSpec != "" { + ports, err = parsePortSpec(portSpec) + if err != nil { + return nil, err + } + } else { + ports = []int{edgedbPort} + } + } + + ports, err = validatePortSpec(hosts, ports) + if err != nil { + return nil, err + } + + if user == "" { + user = os.Getenv("EDGEDB_USER") + } + + if user == "" { + user = "edgedb" + } + + if password == "" { + password = os.Getenv("EDGEDB_PASSWORD") + } + + if database == "" { + database = os.Getenv("EDGEDB_DATABASE") + } + + if database == "" { + database = "edgedb" + } + + var addrs []dialArgs + for i := 0; i < len(hosts); i++ { + h := hosts[i] + p := ports[i] + + if strings.HasPrefix(h, "/") { + if !strings.Contains(h, ".s.EDGEDB.") { + h = path.Join(h, fmt.Sprintf(".s.EDGEDB.%v", p)) + } + addrs = append(addrs, dialArgs{"unix", h}) + } else { + addrs = append(addrs, dialArgs{ + "tcp", + fmt.Sprintf("%v:%v", h, p), + }) + } + } + + if len(addrs) == 0 { + return nil, fmt.Errorf( + "could not determine the database address to connect to%w", + ErrBadConfig, // TODO evaluate error type + ) + } + + cfg := &connConfig{ + addrs: addrs, + user: user, + password: password, + database: database, + connectTimeout: opts.ConnectTimeout, + serverSettings: serverSettings, + } + + return cfg, nil +} diff --git a/connutils_test.go b/connutils_test.go new file mode 100644 index 00000000..ea1cc1f4 --- /dev/null +++ b/connutils_test.go @@ -0,0 +1,411 @@ +// This source file is part of the EdgeDB open source project. +// +// Copyright 2020-present EdgeDB Inc. and the EdgeDB authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package edgedb + +import ( + "errors" + "os" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setenvmap(m map[string]string) func() { + funcs := make([]func(), 0, len(m)) + for key, val := range m { + funcs = append(funcs, setenv(key, val)) + } + + return func() { + for _, fn := range funcs { + fn() + } + } +} + +func setenv(key, val string) func() { + old, ok := os.LookupEnv(key) + + err := os.Setenv(key, val) + if err != nil { + panic(err) + } + + if ok { + return func() { + err = os.Setenv(key, old) + if err != nil { + panic(err) + } + } + } + + return func() { + err = os.Unsetenv(key) + if err != nil { + panic(err) + } + } +} + +func TestConUtils(t *testing.T) { + type Result struct { + cfg connConfig + err error + errMessage string + } + + tests := []struct { + name string + env map[string]string + dsn string + opts Options + expected Result + }{ + { + name: "host and user options", + opts: Options{ + User: "user", + Hosts: []string{"localhost"}, + }, + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{{"tcp", "localhost:5656"}}, + user: "user", + database: "edgedb", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "all environment variables", + env: map[string]string{ + "EDGEDB_USER": "user", + "EDGEDB_DATABASE": "testdb", + "EDGEDB_PASSWORD": "passw", + "EDGEDB_HOST": "host", + "EDGEDB_PORT": "123", + }, + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{{"tcp", "host:123"}}, + user: "user", + password: "passw", + database: "testdb", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "options are used before environment variables", + env: map[string]string{ + "EDGEDB_USER": "user", + "EDGEDB_DATABASE": "testdb", + "EDGEDB_PASSWORD": "passw", + "EDGEDB_HOST": "host", + "EDGEDB_PORT": "123", + }, + opts: Options{ + Hosts: []string{"host2"}, + Ports: []int{456}, + User: "user2", + Password: "passw2", + Database: "db2", + }, + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{{"tcp", "host2:456"}}, + user: "user2", + password: "passw2", + database: "db2", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "options are used before DSN string", + env: map[string]string{ + "EDGEDB_USER": "user", + "EDGEDB_DATABASE": "testdb", + "EDGEDB_PASSWORD": "passw", + "EDGEDB_HOST": "host", + "EDGEDB_PORT": "123", + "PGSSLMODE": "prefer", + }, + dsn: "edgedb://user3:123123@localhost/abcdef", + opts: Options{ + Hosts: []string{"host2"}, + Ports: []int{456}, + User: "user2", + Password: "passw2", + Database: "db2", + ServerSettings: map[string]string{"ssl": "False"}, + }, + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{{"tcp", "host2:456"}}, + user: "user2", + password: "passw2", + database: "db2", + serverSettings: map[string]string{"ssl": "False"}, + }, + }, + }, + { + name: "DSN is used before environment variables", + env: map[string]string{ + "EDGEDB_USER": "user", + "EDGEDB_DATABASE": "testdb", + "EDGEDB_PASSWORD": "passw", + "EDGEDB_HOST": "host", + "EDGEDB_PORT": "123", + }, + dsn: "edgedb://user3:123123@localhost:5555/abcdef", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{{"tcp", "localhost:5555"}}, + user: "user3", + password: "123123", + database: "abcdef", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "DSN only", + dsn: "edgedb://user3:123123@localhost:5555/abcdef", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{{"tcp", "localhost:5555"}}, + user: "user3", + password: "123123", + database: "abcdef", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "DSN with multiple hosts", + dsn: "edgedb://user@host1,host2/db", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"tcp", "host1:5656"}, + {"tcp", "host2:5656"}, + }, + user: "user", + database: "db", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "DSN with multiple hosts and ports", + dsn: "edgedb://user@host1:1111,host2:2222/db", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"tcp", "host1:1111"}, + {"tcp", "host2:2222"}, + }, + database: "db", + user: "user", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "environment variables with multiple hosts and ports", + env: map[string]string{ + "EDGEDB_HOST": "host1:1111,host2:2222", + "EDGEDB_USER": "foo", + }, + dsn: "edgedb:///db", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"tcp", "host1:1111"}, + {"tcp", "host2:2222"}, + }, + database: "db", + user: "foo", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "query parameters with multiple hosts and ports", + env: map[string]string{ + "EDGEDB_USER": "foo", + }, + dsn: "edgedb:///db?host=host1:1111,host2:2222", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"tcp", "host1:1111"}, + {"tcp", "host2:2222"}, + }, + database: "db", + user: "foo", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "options with multiple hosts", + env: map[string]string{ + "EDGEDB_USER": "foo", + }, + dsn: "edgedb:///db", + opts: Options{ + Hosts: []string{"host1", "host2"}, + }, + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"tcp", "host1:5656"}, + {"tcp", "host2:5656"}, + }, + user: "foo", + database: "db", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "DSN with server settings", + dsn: "edgedb://user3:123123@localhost:5555/" + + "abcdef?param=sss¶m=123&host=testhost&user=testuser" + + "&port=2222&database=testdb", + opts: Options{ + Hosts: []string{"127.0.0.1"}, + Ports: []int{888}, + User: "me", + Password: "ask", + Database: "db", + }, + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"tcp", "127.0.0.1:888"}, + }, + serverSettings: map[string]string{"param": "123"}, + user: "me", + password: "ask", + database: "db", + }, + }, + }, + { + name: "DSN and options server settings are merged", + dsn: "edgedb://user3:123123@localhost:5555/" + + "abcdef?param=sss¶m=123&host=testhost&user=testuser" + + "&port=2222&database=testdb", + opts: Options{ + Hosts: []string{"127.0.0.1"}, + Ports: []int{888}, + User: "me", + Password: "ask", + Database: "db", + ServerSettings: map[string]string{"aa": "bb"}, + }, + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"tcp", "127.0.0.1:888"}, + }, + serverSettings: map[string]string{ + "aa": "bb", + "param": "123", + }, + user: "me", + password: "ask", + database: "db", + }, + }, + }, + { + name: "DSN with unix socket", + dsn: "edgedb:///dbname?host=/unix_sock/test&user=spam", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{{ + "unix", path.Join("/unix_sock/test", ".s.EDGEDB.5656"), + }}, + user: "spam", + database: "dbname", + serverSettings: map[string]string{}, + }, + }, + }, + { + name: "DSN requires edgedb scheme", + dsn: "pq:///dbname?host=/unix_sock/test&user=spam", + expected: Result{ + err: ErrBadConfig, + errMessage: "dsn " + + `"pq:///dbname?host=/unix_sock/test&user=spam" ` + + "is neither a edgedb:// URI nor valid instance name", + }, + }, + { + name: "host count must match port count", + dsn: "edgedb://host1,host2,host3/db", + opts: Options{ + Ports: []int{111, 222}, + }, + expected: Result{ + err: ErrInterfaceViolation, + errMessage: "could not match 2 port numbers to 3 hosts", + }, + }, + { + name: "DSN query parameter with unix socket", + dsn: "edgedb://user@?port=56226&host=%2Ftmp", + expected: Result{ + cfg: connConfig{ + addrs: []dialArgs{ + {"unix", path.Join("/tmp", ".s.EDGEDB.56226")}, + }, + user: "user", + database: "edgedb", + serverSettings: map[string]string{}, + }, + }, + }, + } + + for _, c := range tests { + t.Run(c.name, func(t *testing.T) { + cleanup := setenvmap(c.env) + defer cleanup() + + config, err := parseConnectDSNAndArgs(c.dsn, &c.opts) + + if c.expected.err != nil { + require.EqualError(t, err, c.expected.errMessage) + require.True(t, errors.Is(err, c.expected.err)) + assert.Nil(t, config) + } else { + require.Nil(t, err, "encountered err") + assert.Equal(t, c.expected.cfg, *config) + } + }) + } +} diff --git a/credentials.go b/credentials.go new file mode 100644 index 00000000..a4e9c7b4 --- /dev/null +++ b/credentials.go @@ -0,0 +1,107 @@ +// This source file is part of the EdgeDB open source project. +// +// Copyright 2020-present EdgeDB Inc. and the EdgeDB authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package edgedb + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "math" +) + +type credentials struct { + host string + port int + user string + database string + password string +} + +func readCredentials(path string) (*credentials, error) { + data, err := ioutil.ReadFile(path) + if err != nil { + return nil, fmt.Errorf( + "cannot read credentials at %q: %v%w", path, err, ErrBadConfig, + ) + } + + var values map[string]interface{} + if e := json.Unmarshal(data, &values); e != nil { + return nil, fmt.Errorf( + "cannot read credentials at %q: %v%w", path, e, ErrBadConfig, + ) + } + + creds, err := validateCredentials(values) + if err != nil { + return nil, fmt.Errorf( + "cannot read credentials at %q: %v%w", path, err, ErrBadConfig, + ) + } + + return creds, nil +} + +func validateCredentials(data map[string]interface{}) (*credentials, error) { + result := &credentials{} + + if val, ok := data["port"]; ok { + port, ok := val.(float64) + if !ok || port != math.Trunc(port) || port < 1 || port > 65535 { + return nil, fmt.Errorf("invalid `port` value%w", ErrBadConfig) + } + result.port = int(port) + } else { + result.port = 5656 + } + + user, ok := data["user"] + if !ok { + return nil, fmt.Errorf("`user` key is required%w", ErrBadConfig) + } + result.user, ok = user.(string) + if !ok { + return nil, fmt.Errorf("`user` must be a string%w", ErrBadConfig) + } + + if host, ok := data["host"]; ok { + result.host, ok = host.(string) + if !ok { + return nil, fmt.Errorf("`host` must be a string%w", ErrBadConfig) + } + } + + if database, ok := data["database"]; ok { + result.database, ok = database.(string) + if !ok { + return nil, fmt.Errorf( + "`database` must be a string%w", ErrBadConfig, + ) + } + } + + if password, ok := data["password"]; ok { + result.password, ok = password.(string) + if !ok { + return nil, fmt.Errorf( + "`password` must be a string%w", ErrBadConfig, + ) + } + } + + return result, nil +} diff --git a/credentials1.json b/credentials1.json new file mode 100644 index 00000000..eff948b4 --- /dev/null +++ b/credentials1.json @@ -0,0 +1,6 @@ +{ + "port": 10702, + "user": "test3n", + "password": "lZTBy1RVCfOpBAOwSCwIyBIR", + "database": "test3n" +} diff --git a/credentials_test.go b/credentials_test.go new file mode 100644 index 00000000..b3f51b06 --- /dev/null +++ b/credentials_test.go @@ -0,0 +1,80 @@ +// This source file is part of the EdgeDB open source project. +// +// Copyright 2020-present EdgeDB Inc. and the EdgeDB authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package edgedb + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCredentialsRead(t *testing.T) { + creds, err := readCredentials("credentials1.json") + require.Nil(t, err) + + expected := &credentials{ + database: "test3n", + password: "lZTBy1RVCfOpBAOwSCwIyBIR", + port: 10702, + user: "test3n", + } + + assert.Equal(t, expected, creds) +} + +func TestCredentialsEmpty(t *testing.T) { + creds, err := validateCredentials(map[string]interface{}{}) + assert.EqualError(t, err, "`user` key is required") + assert.True(t, errors.Is(err, ErrBadConfig)) + assert.Nil(t, creds) +} + +func TestCredentialsPort(t *testing.T) { + creds, err := validateCredentials(map[string]interface{}{ + "user": "u1", + "port": "1234", + }) + assert.EqualError(t, err, "invalid `port` value") + assert.True(t, errors.Is(err, ErrBadConfig)) + assert.Nil(t, creds) + + creds, err = validateCredentials(map[string]interface{}{ + "user": "u1", + "port": 0, + }) + assert.EqualError(t, err, "invalid `port` value") + assert.True(t, errors.Is(err, ErrBadConfig)) + assert.Nil(t, creds) + + creds, err = validateCredentials(map[string]interface{}{ + "user": "u1", + "port": -1, + }) + assert.EqualError(t, err, "invalid `port` value") + assert.True(t, errors.Is(err, ErrBadConfig)) + assert.Nil(t, creds) + + creds, err = validateCredentials(map[string]interface{}{ + "user": "u1", + "port": 65536, + }) + assert.EqualError(t, err, "invalid `port` value") + assert.True(t, errors.Is(err, ErrBadConfig)) + assert.Nil(t, creds) +} diff --git a/edgedb.go b/edgedb.go index 91ebe427..958e3e07 100644 --- a/edgedb.go +++ b/edgedb.go @@ -18,6 +18,7 @@ package edgedb import ( "context" + "log" "net" "github.com/edgedb/edgedb-go/cache" @@ -39,13 +40,27 @@ type baseConn struct { // ConnectOne establishes a connection to an EdgeDB server. func ConnectOne(ctx context.Context, opts Options) (*Conn, error) { // nolint:gocritic,lll + return ConnectOneDSN(ctx, "", opts) +} + +// ConnectOneDSN establishes a connection to an EdgeDB server. +func ConnectOneDSN( + ctx context.Context, + dsn string, + opts Options, // nolint:gocritic +) (*Conn, error) { conn := &baseConn{ typeIDCache: cache.New(1_000), inCodecCache: cache.New(1_000), outCodecCache: cache.New(1_000), } - if err := connectOne(ctx, &opts, conn); err != nil { + config, err := parseConnectDSNAndArgs(dsn, &opts) + if err != nil { + return nil, err + } + + if err := connectOne(ctx, config, conn); err != nil { return nil, err } @@ -53,20 +68,32 @@ func ConnectOne(ctx context.Context, opts Options) (*Conn, error) { // nolint:go } // connectOne expectes a singleConn that has a nil net.Conn. -func connectOne(ctx context.Context, opts *Options, conn *baseConn) error { - var d net.Dialer - netConn, err := d.DialContext(ctx, opts.network(), opts.address()) - if err != nil { - return err +func connectOne(ctx context.Context, cfg *connConfig, conn *baseConn) error { + var ( + d net.Dialer + err error + ) + + for _, addr := range cfg.addrs { // nolint:gocritic + // todo do error values need to be checked? + conn.conn, err = d.DialContext(ctx, addr.network, addr.address) + if err != nil { + log.Printf("while attempting connection %+v: %+v", addr, err) + continue + } + + err = conn.connect(ctx, cfg) + if err != nil { + _ = conn.conn.Close() + log.Printf("while attempting connection %+v: %+v", addr, err) + continue + } + + return nil } - conn.conn = netConn - err = conn.connect(ctx, opts) - if err != nil { - return err - } - - return nil + conn.conn = nil + return err } // Close the db connection @@ -190,7 +217,7 @@ func (c *baseConn) QueryOneJSON( } if len(*out) == 0 { - return ErrorZeroResults + return ErrZeroResults } return nil diff --git a/error.go b/error.go index 9578a2c9..0bb147f9 100644 --- a/error.go +++ b/error.go @@ -24,29 +24,32 @@ import ( ) var ( - // Error is wrapped by all errors returned from the server. - Error = errors.New("") - // todo error API (hierarchy and wrap all returned errors) + // Error is wrapped by all edgedb errors. + Error error = errors.New("") // ErrReleasedTwice is returned if a PoolConn is released more than once. ErrReleasedTwice = fmt.Errorf( "connection released more than once%w", Error, ) - // ErrorZeroResults is returned when a query has no results. - ErrorZeroResults = fmt.Errorf("zero results%w", Error) + // ErrZeroResults is returned when a query has no results. + ErrZeroResults = fmt.Errorf("zero results%w", Error) - // ErrorPoolClosed is returned by operations on closed pools. - ErrorPoolClosed error = fmt.Errorf("pool closed%w", Error) + // ErrPoolClosed is returned by operations on closed pools. + ErrPoolClosed error = fmt.Errorf("pool closed%w", Error) - // ErrorConnsInUse is returned when all connects are in use. - ErrorConnsInUse error = fmt.Errorf("all connections in use%w", Error) + // ErrContextExpired is returned when an expired context is used. + ErrContextExpired error = fmt.Errorf("context expired%w", Error) - // ErrorContextExpired is returned when an expired context is used. - ErrorContextExpired error = fmt.Errorf("context expired%w", Error) + // ErrBadConfig is wrapped + // when a function returning Options encounters an error. + ErrBadConfig error = fmt.Errorf("%w", Error) - // ErrorConfiguration is returned when invalid configuration is received. - ErrorConfiguration error = fmt.Errorf("%w", Error) + // ErrClientFault ... + ErrClientFault error = fmt.Errorf("%w", Error) + + // ErrInterfaceViolation ... + ErrInterfaceViolation error = fmt.Errorf("%w", ErrClientFault) ) func decodeError(buf *buff.Buff) error { @@ -85,6 +88,10 @@ func wrapAll(errs ...error) error { return nil } + if len(err.wrapped) == 1 { + return err.wrapped[0] + } + err.msg = err.wrapped[0].Error() for _, e := range err.wrapped[1:] { err.msg += "; " + e.Error() diff --git a/fallthrough.go b/fallthrough.go index 0d928ad8..4a425385 100644 --- a/fallthrough.go +++ b/fallthrough.go @@ -24,6 +24,13 @@ import ( "github.com/edgedb/edgedb-go/protocol/message" ) +var logMsgSeverityLookup map[uint8]string = map[uint8]string{ + 0x14: "DEBUG", + 0x28: "INFO", + 0x3c: "NOTICE", + 0x50: "WARNING", +} + func (c *baseConn) fallThrough(buf *buff.Buff) error { switch buf.MsgType { case message.ParameterStatus: @@ -31,9 +38,10 @@ func (c *baseConn) fallThrough(buf *buff.Buff) error { value := buf.PopString() c.serverSettings[name] = value case message.LogMessage: - severity := string([]byte{buf.PopUint8()}) + severity := logMsgSeverityLookup[buf.PopUint8()] code := buf.PopUint32() message := buf.PopString() + buf.Discard(2) // number of headers, assume 0 log.Println("SERVER MESSAGE", severity, code, message) default: return fmt.Errorf("unexpected message type: 0x%x", buf.MsgType) diff --git a/granular_flow.go b/granular_flow.go index 471888cf..45fef912 100644 --- a/granular_flow.go +++ b/granular_flow.go @@ -231,7 +231,7 @@ func (c *baseConn) execute( } tmp := out - err = ErrorZeroResults + err = ErrZeroResults for buf.Next() { switch buf.MsgType { case message.Data: @@ -296,7 +296,7 @@ func (c *baseConn) optimistic( } tmp := out - err = ErrorZeroResults + err = ErrZeroResults for buf.Next() { switch buf.MsgType { case message.Data: diff --git a/main_test.go b/main_test.go index 99a51d1c..39f4c712 100644 --- a/main_test.go +++ b/main_test.go @@ -22,7 +22,6 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" "log" "math/rand" "os" @@ -53,16 +52,17 @@ func getLocalServer() error { return errors.New("credentials not found") } - data, err := ioutil.ReadFile(credFileName) + creds, err := readCredentials(credFileName) if err != nil { log.Printf("failed to read credentials file: %q", credFileName) return errors.New("credentials not found") } - err = json.Unmarshal(data, &opts) - if err != nil { - log.Printf("failed to parse credentials file: %q", credFileName) - return errors.New("credentials not found") + opts = Options{ + Ports: []int{creds.port}, + User: creds.user, + Password: creds.password, + Database: creds.database, } log.Print("using existing server") @@ -77,14 +77,20 @@ func startServer() (err error) { cmdName = fmt.Sprintf("%v-%v", cmdName, slot) } - cmd := exec.Command( - cmdName, + cmdArgs := []string{ "--temp-dir", "--testmode", "--echo-runtime-info", "--port=auto", "--auto-shutdown", - ) + `--bootstrap-command=` + + `CREATE SUPERUSER ROLE test { SET password := "shhh" }`, + } + + log.Println(cmdName, strings.Join(cmdArgs, " ")) + + cmd := exec.Command(cmdName, cmdArgs...) + cmd.Stderr = os.Stderr stdout, err := cmd.StdoutPipe() if err != nil { log.Fatal(err) @@ -99,6 +105,7 @@ func startServer() (err error) { scanner := bufio.NewScanner(stdout) for scanner.Scan() { text = scanner.Text() + fmt.Println(text) if strings.HasPrefix(text, "EDGEDB_SERVER_DATA:") { break } @@ -120,11 +127,11 @@ func startServer() (err error) { } opts = Options{ - Host: data.Host, - Port: data.Port, - User: "edgedb", + Hosts: []string{data.Host}, + Ports: []int{data.Port}, + User: "test", + Password: "shhh", Database: "edgedb", - admin: true, } log.Print("server started") @@ -135,6 +142,10 @@ func TestMain(m *testing.M) { var err error = nil code := 1 defer func() { + if p := recover(); p != nil { + log.Println(p) + } + if err != nil { log.Println("error while cleaning up: ", err) } @@ -154,7 +165,6 @@ func TestMain(m *testing.M) { if err != nil { panic(err) } - defer func() { e := conn.Close() if e != nil { @@ -172,7 +182,8 @@ func TestMain(m *testing.M) { var name string err = conn.QueryOne(ctx, query, &name) - if errors.Is(err, ErrorZeroResults) { + if errors.Is(err, ErrZeroResults) { + log.Println("setting up test db") executeOrPanic(` START MIGRATION TO { module default { @@ -206,6 +217,7 @@ func TestMain(m *testing.M) { user := {'user_with_password'} } `) + err = nil } rand.Seed(time.Now().Unix()) diff --git a/options.go b/options.go index 1f81926f..0bd64b83 100644 --- a/options.go +++ b/options.go @@ -17,80 +17,74 @@ package edgedb import ( - "fmt" - "net/url" - "strconv" - "strings" + "time" ) // Options for connecting to an EdgeDB server type Options struct { - Host string `json:"host"` - Port int `json:"port"` - User string `json:"user"` - Database string `json:"database"` - Password string `json:"password"` - admin bool + // Hosts is a slice of database host addresses as one of the following + // + // - an IP address or domain name + // + // - an absolute path to the directory + // containing the database server Unix-domain socket + // (not supported on Windows) + // + // If the slice is empty, the following will be tried, in order: + // + // - host address(es) parsed from the dsn argument + // + // - the value of the EDGEDB_HOST environment variable + // + // - on Unix, common directories used for EdgeDB Unix-domain sockets: + // "/run/edgedb" and "/var/run/edgedb" + // + // - "localhost" + Hosts []string - MaxConns int - MinConns int -} - -func (o *Options) network() string { - if o.admin { - return "unix" - } - return "tcp" -} + // Ports is a slice of port numbers to connect to at the server host + // (or Unix-domain socket file extension). + // + // Ports may either be: + // + // - the same length ans Hosts + // + // - a single port to be used all specified hosts + // + // - empty indicating the value parsed from the dsn argument + // should be used, or the value of the EDGEDB_PORT environment variable, + // or 5656 if neither is specified. + Ports []int -func (o *Options) address() string { - if o.admin { - return fmt.Sprintf("%v/.s.EDGEDB.admin.%v", o.Host, o.Port) - } + // User is the name of the database role used for authentication. + // If not specified, the value parsed from the dsn argument is used, + // or the value of the EDGEDB_USER environment variable, + // or the operating system name of the user running the application. + User string - host := o.Host - if host == "" { - host = "localhost" - } + // Database is the name of the database to connect to. + // If not specified, the value parsed from the dsn argument is used, + // or the value of the EDGEDB_DATABASE environment variable, + // or the operating system name of the user running the application. + Database string - port := o.Port - if port == 0 { - port = 5656 - } + // Password to be used for authentication, + // if the server requires one. If not specified, + // the value parsed from the dsn argument is used, + // or the value of the EDGEDB_PASSWORD environment variable. + // Note that the use of the environment variable is discouraged + // as other users and applications may be able to read it + // without needing specific privileges. + Password string - return fmt.Sprintf("%v:%v", host, port) -} - -// DSN parses a URI string into an Options struct -func DSN(dsn string) (opts Options, err error) { - parsed, err := url.Parse(dsn) - if err != nil { - return opts, err - } - - if parsed.Scheme != "edgedb" { - return opts, fmt.Errorf("dsn %q is not an edgedb:// URI", dsn) - } + // ConnectTimeout is used when establishing connections in the background. + ConnectTimeout time.Duration - var port int - if parsed.Port() == "" { - port = 5656 - } else { - port, err = strconv.Atoi(parsed.Port()) - if err != nil { - return opts, err - } - } + // MinConns determines the minimum number of connections. + MinConns int - host := strings.Split(parsed.Host, ":")[0] - db := strings.TrimLeft(parsed.Path, "/") - password, _ := parsed.User.Password() + // MaxConns determines the maximum number of connections. + MaxConns int - return Options{ - Host: host, - Port: port, - User: parsed.User.Username(), - Database: db, - Password: password, - }, nil + ServerSettings map[string]string } diff --git a/options_test.go b/options_test.go deleted file mode 100644 index 3f324d06..00000000 --- a/options_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// This source file is part of the EdgeDB open source project. -// -// Copyright 2020-present EdgeDB Inc. and the EdgeDB authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package edgedb - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseHost(t *testing.T) { - o, err := DSN("edgedb://me@localhost:5656/somedb") - require.Nil(t, err) - assert.Equal(t, "localhost", o.Host) -} - -func TestParsePort(t *testing.T) { - o, err := DSN("edgedb://me@localhost:5656/somedb") - require.Nil(t, err) - assert.Equal(t, 5656, o.Port) -} - -func TestParseUser(t *testing.T) { - o, err := DSN("edgedb://me@localhost:5656/somedb") - require.Nil(t, err) - assert.Equal(t, "me", o.User) -} - -func TestParseDatabase(t *testing.T) { - o, err := DSN("edgedb://me@localhost:5656/somedb") - require.Nil(t, err) - assert.Equal(t, "somedb", o.Database) -} - -func TestParsePassword(t *testing.T) { - o, err := DSN("edgedb://me:secret@localhost:5656/somedb") - require.Nil(t, err) - assert.Equal(t, "secret", o.Password) -} - -func TestMissingPort(t *testing.T) { - o, err := DSN("edgedb://me@localhost/somedb") - require.Nil(t, err) - assert.Equal(t, 5656, o.Port) -} - -func TestDialHost(t *testing.T) { - o := Options{Host: "some.com", Port: 1234} - assert.Equal(t, "some.com:1234", o.address()) - - o = Options{Port: 1234} - assert.Equal(t, "localhost:1234", o.address()) - - o = Options{Host: "some.com"} - assert.Equal(t, "some.com:5656", o.address()) - - o = Options{} - assert.Equal(t, "localhost:5656", o.address()) -} - -func TestWrongScheme(t *testing.T) { - _, err := DSN("http://localhost") - assert.Equal( - t, - errors.New(`dsn "http://localhost" is not an edgedb:// URI`), - err, - ) -} diff --git a/pool.go b/pool.go index e04d4820..f297a01b 100644 --- a/pool.go +++ b/pool.go @@ -36,31 +36,49 @@ type Pool struct { // A buffered channel of structs representing unconnected capacity. potentialConns chan struct{} - opts *Options + maxConns int + minConns int + + cfg *connConfig + typeIDCache *cache.Cache inCodecCache *cache.Cache outCodecCache *cache.Cache } +// todo check connect tests in other clients + +// todo add connectDSN funcs + // Connect a pool of connections to a server. func Connect(ctx context.Context, opts Options) (*Pool, error) { // nolint:gocritic,lll + return ConnectDSN(ctx, "", opts) +} + +// ConnectDSN a pool of connections to a server. +func ConnectDSN(ctx context.Context, dsn string, opts Options) (*Pool, error) { // nolint:gocritic,lll if opts.MinConns < 1 { return nil, fmt.Errorf( "MinConns may not be less than 1, got: %v%w", - opts.MinConns, - ErrorConfiguration, + opts.MinConns, ErrBadConfig, ) } if opts.MaxConns < opts.MinConns { return nil, fmt.Errorf( - "MaxConns may not be less than MinConns%w", - ErrorConfiguration, + "MaxConns may not be less than MinConns%w", ErrBadConfig, ) } + cfg, err := parseConnectDSNAndArgs(dsn, &opts) + if err != nil { + return nil, err + } + pool := &Pool{ - opts: &opts, + maxConns: opts.MaxConns, + minConns: opts.MinConns, + cfg: cfg, freeConns: make(chan *baseConn, opts.MinConns), potentialConns: make(chan struct{}, opts.MaxConns), @@ -104,7 +122,7 @@ func (p *Pool) newConn(ctx context.Context) (*baseConn, error) { outCodecCache: p.outCodecCache, } - if err := connectOne(ctx, p.opts, conn); err != nil { + if err := connectOne(ctx, p.cfg, conn); err != nil { return nil, err } @@ -116,13 +134,13 @@ func (p *Pool) acquire(ctx context.Context) (*baseConn, error) { defer p.mu.RUnlock() if p.isClosed { - return nil, ErrorPoolClosed + return nil, ErrPoolClosed } // force do nothing if context is expired select { case <-ctx.Done(): - return nil, ErrorContextExpired + return nil, ErrContextExpired default: } @@ -144,7 +162,7 @@ func (p *Pool) acquire(ctx context.Context) (*baseConn, error) { } return conn, nil case <-ctx.Done(): - return nil, ErrorContextExpired + return nil, ErrContextExpired } } @@ -198,13 +216,13 @@ func (p *Pool) Close() error { defer p.mu.Unlock() if p.isClosed { - return ErrorPoolClosed + return ErrPoolClosed } p.isClosed = true wg := sync.WaitGroup{} - errs := make([]error, p.opts.MaxConns) - for i := 0; i < p.opts.MaxConns; i++ { + errs := make([]error, p.maxConns) + for i := 0; i < p.maxConns; i++ { select { case conn := <-p.freeConns: wg.Add(1) diff --git a/pool_test.go b/pool_test.go index 39afc867..a77ef8e5 100644 --- a/pool_test.go +++ b/pool_test.go @@ -19,7 +19,6 @@ package edgedb import ( "context" "errors" - "os" "testing" "time" @@ -61,7 +60,7 @@ func TestClosePoolConcurently(t *testing.T) { go func() { errs <- pool.Close() }() assert.Nil(t, <-errs) - assert.Equal(t, ErrorPoolClosed, <-errs) + assert.Equal(t, ErrPoolClosed, <-errs) } func TestConnectPoolMinConnGteZero(t *testing.T) { @@ -71,7 +70,7 @@ func TestConnectPoolMinConnGteZero(t *testing.T) { o := Options{MinConns: 0, MaxConns: 10} _, err := Connect(ctx, o) assert.EqualError(t, err, "MinConns may not be less than 1, got: 0") - assert.True(t, errors.Is(err, ErrorConfiguration)) + assert.True(t, errors.Is(err, ErrBadConfig)) } func TestConnectPoolMinConnLteMaxConn(t *testing.T) { @@ -81,7 +80,7 @@ func TestConnectPoolMinConnLteMaxConn(t *testing.T) { o := Options{MinConns: 5, MaxConns: 1} _, err := Connect(ctx, o) assert.EqualError(t, err, "MaxConns may not be less than MinConns") - assert.True(t, errors.Is(err, ErrorConfiguration)) + assert.True(t, errors.Is(err, ErrBadConfig)) } func TestAcquireFromClosedPool(t *testing.T) { @@ -92,7 +91,7 @@ func TestAcquireFromClosedPool(t *testing.T) { } conn, err := pool.Acquire(context.TODO()) - require.Equal(t, err, ErrorPoolClosed) + require.Equal(t, err, ErrPoolClosed) assert.Nil(t, conn) } @@ -107,14 +106,14 @@ func TestAcquireFreeConnFromPool(t *testing.T) { } func BenchmarkPoolAcquireRelease(b *testing.B) { - opts := &Options{MinConns: 2, MaxConns: 2} pool := &Pool{ - opts: opts, - freeConns: make(chan *baseConn, opts.MaxConns), - potentialConns: make(chan struct{}, opts.MaxConns), + maxConns: 2, + minConns: 2, + freeConns: make(chan *baseConn, 2), + potentialConns: make(chan struct{}, 2), } - for i := 0; i < opts.MaxConns; i++ { + for i := 0; i < pool.maxConns; i++ { pool.freeConns <- &baseConn{} } @@ -129,18 +128,26 @@ func BenchmarkPoolAcquireRelease(b *testing.B) { } func TestAcquirePotentialConnFromPool(t *testing.T) { - pool := &Pool{ - potentialConns: make(chan struct{}, 1), - opts: &opts, - } - pool.potentialConns <- struct{}{} + o := opts + o.MaxConns = 2 + o.MinConns = 1 + pool, err := Connect(context.TODO(), o) + require.Nil(t, err) + defer func() { + assert.Nil(t, pool.Close()) + }() - deadline := time.Now().Add(10 * time.Millisecond) - ctx, cancel := context.WithDeadline(context.Background(), deadline) - conn, err := pool.Acquire(ctx) - assert.True(t, errors.Is(err, os.ErrDeadlineExceeded)) - assert.Nil(t, conn) - cancel() + // free connection + a, err := pool.Acquire(context.TODO()) + require.Nil(t, err) + require.NotNil(t, a) + defer func() { assert.Nil(t, a.Release()) }() + + // potential connection + b, err := pool.Acquire(context.TODO()) + require.Nil(t, err) + require.NotNil(t, b) + defer func() { assert.Nil(t, b.Release()) }() } func TestPoolAcquireExpiredContext(t *testing.T) { @@ -155,7 +162,7 @@ func TestPoolAcquireExpiredContext(t *testing.T) { cancel() conn, err := pool.Acquire(ctx) - assert.Equal(t, err, ErrorContextExpired) + assert.Equal(t, err, ErrContextExpired) assert.Nil(t, conn) } @@ -165,20 +172,22 @@ func TestPoolAcquireThenContextExpires(t *testing.T) { deadline := time.Now().Add(10 * time.Millisecond) ctx, cancel := context.WithDeadline(context.Background(), deadline) conn, err := pool.Acquire(ctx) - assert.Equal(t, err, ErrorContextExpired) + assert.Equal(t, err, ErrContextExpired) assert.Nil(t, conn) cancel() } func TestClosePool(t *testing.T) { pool := &Pool{ + maxConns: 0, + minConns: 0, freeConns: make(chan *baseConn), potentialConns: make(chan struct{}), - opts: &Options{MaxConns: 0, MinConns: 0}, } + err := pool.Close() assert.Nil(t, err) err = pool.Close() - assert.Equal(t, err, ErrorPoolClosed) + assert.Equal(t, err, ErrPoolClosed) } diff --git a/query_test.go b/query_test.go index c6cd85d0..31e1c897 100644 --- a/query_test.go +++ b/query_test.go @@ -105,7 +105,7 @@ func TestQueryOneJSONZeroResults(t *testing.T) { var result []byte err := conn.QueryOneJSON(ctx, "SELECT {}", &result) - require.Equal(t, err, ErrorZeroResults) + require.Equal(t, err, ErrZeroResults) assert.Equal(t, []byte(nil), result) } @@ -123,7 +123,7 @@ func TestQueryOneZeroResults(t *testing.T) { var result int64 err := conn.QueryOne(ctx, "SELECT {}", &result) - assert.Equal(t, ErrorZeroResults, err) + assert.Equal(t, ErrZeroResults, err) } func TestError(t *testing.T) { diff --git a/tutorial_test.go b/tutorial_test.go index ff4672d5..c8a53e45 100644 --- a/tutorial_test.go +++ b/tutorial_test.go @@ -53,12 +53,11 @@ func TestTutorial(t *testing.T) { edb, err := ConnectOne( ctx, Options{ - Host: opts.Host, - Port: opts.Port, + Hosts: opts.Hosts, + Ports: opts.Ports, User: opts.User, Password: opts.Password, Database: dbName, - admin: opts.admin, }, ) if err != nil { diff --git a/unix.go b/unix.go new file mode 100644 index 00000000..d0eed059 --- /dev/null +++ b/unix.go @@ -0,0 +1,21 @@ +// This source file is part of the EdgeDB open source project. +// +// Copyright 2020-present EdgeDB Inc. and the EdgeDB authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package edgedb + +var defaultHosts = []string{"/run/edgedb", "/var/run/edgedb"} diff --git a/windows.go b/windows.go new file mode 100644 index 00000000..7e3ed2aa --- /dev/null +++ b/windows.go @@ -0,0 +1,21 @@ +// This source file is part of the EdgeDB open source project. +// +// Copyright 2020-present EdgeDB Inc. and the EdgeDB authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build windows + +package edgedb + +var defaultHosts = []string{"localhost"}