Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TLS support to increment command #3257

Merged
merged 9 commits into from
Apr 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions dgraph/cmd/counter/increment.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/dgraph-io/dgraph/x"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"google.golang.org/grpc"
)

var Increment x.SubCommand
Expand All @@ -49,12 +48,15 @@ func init() {
flag := Increment.Cmd.Flags()
flag.String("addr", "localhost:9080", "Address of Dgraph alpha.")
flag.Int("num", 1, "How many times to run.")
flag.Bool("ro", false, "Only read the counter value, don't update it.")
flag.Bool("be", false, "Read counter value without retrieving timestamp from Zero.")
flag.Duration("wait", 0*time.Second, "How long to wait.")
flag.String("pred", "counter.val", "Predicate to use for storing the counter.")
flag.String("user", "", "Username if login is required.")
flag.String("password", "", "Password of the user.")
flag.String("pred", "counter.val",
"Predicate to use for storing the counter.")
flag.Bool("ro", false,
"Read-only. Read the counter value without updating it.")
flag.Bool("be", false,
"Best-effort. Read counter value without retrieving timestamp from Zero.")
// TLS configuration
x.RegisterClientTLSFlags(flag)
}
Expand Down Expand Up @@ -135,14 +137,17 @@ func run(conf *viper.Viper) {
waitDur := conf.GetDuration("wait")
num := conf.GetInt("num")

conn, err := grpc.Dial(addr, grpc.WithInsecure())
tlsCfg, err := x.LoadClientTLSConfig(conf)
x.CheckfNoTrace(err)

conn, err := x.SetupConnection(addr, tlsCfg, false)
if err != nil {
log.Fatal(err)
}
dc := api.NewDgraphClient(conn)
dg := dgo.NewDgraphClient(dc)
if user := conf.GetString("user"); len(user) > 0 {
x.Check(dg.Login(context.Background(), user, conf.GetString("password")))
x.CheckfNoTrace(dg.Login(context.Background(), user, conf.GetString("password")))
}

for num > 0 {
Expand Down
3 changes: 2 additions & 1 deletion tlstest/acl/acl_over_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/dgraph-io/dgo"
"github.com/dgraph-io/dgo/protos/api"
"github.com/dgraph-io/dgraph/z"
"github.com/golang/glog"
"github.com/spf13/viper"
"google.golang.org/grpc"
Expand Down Expand Up @@ -98,7 +99,7 @@ func ExampleLoginOverTLS() {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := dgraphClientWithCerts(":9180", conf)
dg, err := dgraphClientWithCerts(z.SockAddr, conf)
if err != nil {
glog.Fatalf("Unable to get dgraph client: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions tlstest/certrequest/certrequest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestAccessOverPlaintext(t *testing.T) {
dg := z.DgraphClient(":9180")
dg := z.DgraphClient(z.SockAddr)
err := dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.Error(t, err, "The authentication handshake should have failed")
}
Expand All @@ -21,7 +21,7 @@ func TestAccessWithCaCert(t *testing.T) {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions tlstest/certrequireandverify/certrequireandverify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestAccessWithoutClientCert(t *testing.T) {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.Error(t, err, "The authentication handshake should have failed")
Expand All @@ -28,7 +28,7 @@ func TestAccessWithClientCert(t *testing.T) {
conf.Set("tls_cert", "../tls/client.acl.crt")
conf.Set("tls_key", "../tls/client.acl.key")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions tlstest/certverifyifgiven/certverifyifgiven_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestAccessWithoutClientCert(t *testing.T) {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand All @@ -28,7 +28,7 @@ func TestAccessWithClientCert(t *testing.T) {
conf.Set("tls_cert", "../tls/client.acl.crt")
conf.Set("tls_key", "../tls/client.acl.key")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand Down
11 changes: 10 additions & 1 deletion x/tls_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ type TLSHelperConfig struct {
}

func RegisterClientTLSFlags(flag *pflag.FlagSet) {
flag.String("tls_cacert", "", "The CA Cert file used to verify server certificates.")
flag.String("tls_cacert", "",
"The CA Cert file used to verify server certificates. Required for enabling TLS.")
flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
flag.String("tls_server_name", "", "Used to verify the server hostname.")
flag.String("tls_cert", "", "(optional) The Cert file provided by the client to the server.")
Expand Down Expand Up @@ -107,6 +108,14 @@ func LoadClientTLSConfig(v *viper.Viper) (*tls.Config, error) {
}

return &tlsCfg, nil
} else
// Attempt to determine if user specified *any* TLS option. Unfortunately and contrary to
// Viper's own documentation, there's no way to tell whether an option value came from a
// command-line option or a built-it default.
if v.GetString("tls_server_name") != "" ||
v.GetString("tls_cert") != "" ||
v.GetString("tls_key") != "" {
return nil, fmt.Errorf("--tls_cacert is required for enabling TLS")
}
return nil, nil
}
Expand Down