diff --git a/dgraph/cmd/counter/increment.go b/dgraph/cmd/counter/increment.go index 8b18c11759a..31a6f606c56 100644 --- a/dgraph/cmd/counter/increment.go +++ b/dgraph/cmd/counter/increment.go @@ -31,7 +31,6 @@ import ( "github.com/dgraph-io/dgraph/x" "github.com/spf13/cobra" "github.com/spf13/viper" - "google.golang.org/grpc" ) var Increment x.SubCommand @@ -49,12 +48,15 @@ func init() { flag := Increment.Cmd.Flags() flag.String("addr", "localhost:9080", "Address of Dgraph alpha.") flag.Int("num", 1, "How many times to run.") - flag.Bool("ro", false, "Only read the counter value, don't update it.") - flag.Bool("be", false, "Read counter value without retrieving timestamp from Zero.") flag.Duration("wait", 0*time.Second, "How long to wait.") - flag.String("pred", "counter.val", "Predicate to use for storing the counter.") flag.String("user", "", "Username if login is required.") flag.String("password", "", "Password of the user.") + flag.String("pred", "counter.val", + "Predicate to use for storing the counter.") + flag.Bool("ro", false, + "Read-only. Read the counter value without updating it.") + flag.Bool("be", false, + "Best-effort. Read counter value without retrieving timestamp from Zero.") // TLS configuration x.RegisterClientTLSFlags(flag) } @@ -135,14 +137,17 @@ func run(conf *viper.Viper) { waitDur := conf.GetDuration("wait") num := conf.GetInt("num") - conn, err := grpc.Dial(addr, grpc.WithInsecure()) + tlsCfg, err := x.LoadClientTLSConfig(conf) + x.CheckfNoTrace(err) + + conn, err := x.SetupConnection(addr, tlsCfg, false) if err != nil { log.Fatal(err) } dc := api.NewDgraphClient(conn) dg := dgo.NewDgraphClient(dc) if user := conf.GetString("user"); len(user) > 0 { - x.Check(dg.Login(context.Background(), user, conf.GetString("password"))) + x.CheckfNoTrace(dg.Login(context.Background(), user, conf.GetString("password"))) } for num > 0 { diff --git a/tlstest/acl/acl_over_tls_test.go b/tlstest/acl/acl_over_tls_test.go index f9da21acf16..3bd3a4382ef 100644 --- a/tlstest/acl/acl_over_tls_test.go +++ b/tlstest/acl/acl_over_tls_test.go @@ -9,6 +9,7 @@ import ( "github.com/dgraph-io/dgo" "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/z" "github.com/golang/glog" "github.com/spf13/viper" "google.golang.org/grpc" @@ -98,7 +99,7 @@ func ExampleLoginOverTLS() { conf.Set("tls_cacert", "../tls/ca.crt") conf.Set("tls_server_name", "node") - dg, err := dgraphClientWithCerts(":9180", conf) + dg, err := dgraphClientWithCerts(z.SockAddr, conf) if err != nil { glog.Fatalf("Unable to get dgraph client: %v", err) } diff --git a/tlstest/certrequest/certrequest_test.go b/tlstest/certrequest/certrequest_test.go index 9602337892f..a715d2500a8 100644 --- a/tlstest/certrequest/certrequest_test.go +++ b/tlstest/certrequest/certrequest_test.go @@ -11,7 +11,7 @@ import ( ) func TestAccessOverPlaintext(t *testing.T) { - dg := z.DgraphClient(":9180") + dg := z.DgraphClient(z.SockAddr) err := dg.Alter(context.Background(), &api.Operation{DropAll: true}) require.Error(t, err, "The authentication handshake should have failed") } @@ -21,7 +21,7 @@ func TestAccessWithCaCert(t *testing.T) { conf.Set("tls_cacert", "../tls/ca.crt") conf.Set("tls_server_name", "node") - dg, err := z.DgraphClientWithCerts(":9180", conf) + dg, err := z.DgraphClientWithCerts(z.SockAddr, conf) require.NoError(t, err, "Unable to get dgraph client: %v", err) err = dg.Alter(context.Background(), &api.Operation{DropAll: true}) require.NoError(t, err, "Unable to perform dropall: %v", err) diff --git a/tlstest/certrequireandverify/certrequireandverify_test.go b/tlstest/certrequireandverify/certrequireandverify_test.go index 6041af9dbed..e8a7c5f9579 100644 --- a/tlstest/certrequireandverify/certrequireandverify_test.go +++ b/tlstest/certrequireandverify/certrequireandverify_test.go @@ -15,7 +15,7 @@ func TestAccessWithoutClientCert(t *testing.T) { conf.Set("tls_cacert", "../tls/ca.crt") conf.Set("tls_server_name", "node") - dg, err := z.DgraphClientWithCerts(":9180", conf) + dg, err := z.DgraphClientWithCerts(z.SockAddr, conf) require.NoError(t, err, "Unable to get dgraph client: %v", err) err = dg.Alter(context.Background(), &api.Operation{DropAll: true}) require.Error(t, err, "The authentication handshake should have failed") @@ -28,7 +28,7 @@ func TestAccessWithClientCert(t *testing.T) { conf.Set("tls_cert", "../tls/client.acl.crt") conf.Set("tls_key", "../tls/client.acl.key") - dg, err := z.DgraphClientWithCerts(":9180", conf) + dg, err := z.DgraphClientWithCerts(z.SockAddr, conf) require.NoError(t, err, "Unable to get dgraph client: %v", err) err = dg.Alter(context.Background(), &api.Operation{DropAll: true}) require.NoError(t, err, "Unable to perform dropall: %v", err) diff --git a/tlstest/certverifyifgiven/certverifyifgiven_test.go b/tlstest/certverifyifgiven/certverifyifgiven_test.go index 6b38fffcc98..886a7fb960e 100644 --- a/tlstest/certverifyifgiven/certverifyifgiven_test.go +++ b/tlstest/certverifyifgiven/certverifyifgiven_test.go @@ -15,7 +15,7 @@ func TestAccessWithoutClientCert(t *testing.T) { conf.Set("tls_cacert", "../tls/ca.crt") conf.Set("tls_server_name", "node") - dg, err := z.DgraphClientWithCerts(":9180", conf) + dg, err := z.DgraphClientWithCerts(z.SockAddr, conf) require.NoError(t, err, "Unable to get dgraph client: %v", err) err = dg.Alter(context.Background(), &api.Operation{DropAll: true}) require.NoError(t, err, "Unable to perform dropall: %v", err) @@ -28,7 +28,7 @@ func TestAccessWithClientCert(t *testing.T) { conf.Set("tls_cert", "../tls/client.acl.crt") conf.Set("tls_key", "../tls/client.acl.key") - dg, err := z.DgraphClientWithCerts(":9180", conf) + dg, err := z.DgraphClientWithCerts(z.SockAddr, conf) require.NoError(t, err, "Unable to get dgraph client: %v", err) err = dg.Alter(context.Background(), &api.Operation{DropAll: true}) require.NoError(t, err, "Unable to perform dropall: %v", err) diff --git a/x/tls_helper.go b/x/tls_helper.go index 584220867a8..78e4d2359ae 100644 --- a/x/tls_helper.go +++ b/x/tls_helper.go @@ -53,7 +53,8 @@ type TLSHelperConfig struct { } func RegisterClientTLSFlags(flag *pflag.FlagSet) { - flag.String("tls_cacert", "", "The CA Cert file used to verify server certificates.") + flag.String("tls_cacert", "", + "The CA Cert file used to verify server certificates. Required for enabling TLS.") flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.") flag.String("tls_server_name", "", "Used to verify the server hostname.") flag.String("tls_cert", "", "(optional) The Cert file provided by the client to the server.") @@ -107,6 +108,14 @@ func LoadClientTLSConfig(v *viper.Viper) (*tls.Config, error) { } return &tlsCfg, nil + } else + // Attempt to determine if user specified *any* TLS option. Unfortunately and contrary to + // Viper's own documentation, there's no way to tell whether an option value came from a + // command-line option or a built-it default. + if v.GetString("tls_server_name") != "" || + v.GetString("tls_cert") != "" || + v.GetString("tls_key") != "" { + return nil, fmt.Errorf("--tls_cacert is required for enabling TLS") } return nil, nil }