diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index b1dd1b68cf0..0f5efc47f8c 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -656,14 +656,13 @@ func run() { ChangeDataConf: Alpha.Conf.GetString("cdc"), } - aclKey, encKey := ee.GetKeys(Alpha.Conf) - if aclKey != nil { - opts.HmacSecret = aclKey - - acl := z.NewSuperFlag(Alpha.Conf.GetString("acl")).MergeAndCheckDefault(ee.AclDefaults) - opts.AccessJwtTtl = acl.GetDuration("access-ttl") - opts.RefreshJwtTtl = acl.GetDuration("refresh-ttl") + keys, err := ee.GetKeys(Alpha.Conf) + x.Check(err) + if keys.AclKey != nil { + opts.HmacSecret = keys.AclKey + opts.AccessJwtTtl = keys.AclAccessTtl + opts.RefreshJwtTtl = keys.AclRefreshTtl glog.Info("ACL secret key loaded successfully.") } @@ -702,7 +701,7 @@ func run() { Raft: raft, WhiteListedIPRanges: ips, StrictMutations: opts.MutationsMode == worker.StrictMutations, - AclEnabled: aclKey != nil, + AclEnabled: keys.AclKey != nil, AbortOlderThan: abortDur, StartTime: startTime, Ludicrous: ludicrous, @@ -723,7 +722,7 @@ func run() { // Set the directory for temporary buffers. z.SetTmpDir(x.WorkerConfig.TmpDir) - x.WorkerConfig.EncryptionKey = encKey + x.WorkerConfig.EncryptionKey = keys.EncKey setupCustomTokenizers() x.Init() diff --git a/dgraph/cmd/bulk/run.go b/dgraph/cmd/bulk/run.go index 1d6f4879bdd..709bbf77bb7 100644 --- a/dgraph/cmd/bulk/run.go +++ b/dgraph/cmd/bulk/run.go @@ -142,9 +142,13 @@ func run() { bopts := badger.DefaultOptions("").FromSuperFlag(BulkBadgerDefaults + cacheDefaults). FromSuperFlag(Bulk.Conf.GetString("badger")) + keys, err := ee.GetKeys(Bulk.Conf) + x.Check(err) + opt := options{ DataFiles: Bulk.Conf.GetString("files"), DataFormat: Bulk.Conf.GetString("format"), + EncryptionKey: keys.EncKey, SchemaFile: Bulk.Conf.GetString("schema"), GqlSchemaFile: Bulk.Conf.GetString("graphql_schema"), Encrypted: Bulk.Conf.GetBool("encrypted"), @@ -177,7 +181,6 @@ func run() { os.Exit(0) } - _, opt.EncryptionKey = ee.GetKeys(Bulk.Conf) if len(opt.EncryptionKey) == 0 { if opt.Encrypted || opt.EncryptedOut { fmt.Fprint(os.Stderr, "Must use --encryption or vault option(s).\n") @@ -187,11 +190,13 @@ func run() { requiredFlags := Bulk.Cmd.Flags().Changed("encrypted") && Bulk.Cmd.Flags().Changed("encrypted_out") if !requiredFlags { - fmt.Fprint(os.Stderr, "Must specify --encrypted and --encrypted_out when providing encryption key.\n") + fmt.Fprint(os.Stderr, + "Must specify --encrypted and --encrypted_out when providing encryption key.\n") os.Exit(1) } if !opt.Encrypted && !opt.EncryptedOut { - fmt.Fprint(os.Stderr, "Must set --encrypted and/or --encrypted_out to true when providing encryption key.\n") + fmt.Fprint(os.Stderr, + "Must set --encrypted and/or --encrypted_out to true when providing encryption key.\n") os.Exit(1) } diff --git a/dgraph/cmd/debug/run.go b/dgraph/cmd/debug/run.go index d4808d51f74..8899c42480b 100644 --- a/dgraph/cmd/debug/run.go +++ b/dgraph/cmd/debug/run.go @@ -885,14 +885,15 @@ func run() { } }() - var err error dir := opt.pdir isWal := false if len(dir) == 0 { dir = opt.wdir isWal = true } - _, opt.key = ee.GetKeys(Debug.Conf) + keys, err := ee.GetKeys(Debug.Conf) + x.Check(err) + opt.key = keys.EncKey if isWal { store, err := raftwal.InitEncrypted(dir, opt.key) diff --git a/dgraph/cmd/decrypt/decrypt.go b/dgraph/cmd/decrypt/decrypt.go index 2b9acddaf84..0db89ab6290 100644 --- a/dgraph/cmd/decrypt/decrypt.go +++ b/dgraph/cmd/decrypt/decrypt.go @@ -18,15 +18,14 @@ package decrypt import ( "compress/gzip" - "fmt" "io" - "log" "os" "strings" "github.com/dgraph-io/dgraph/ee" "github.com/dgraph-io/dgraph/ee/enc" "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" "github.com/spf13/cobra" ) @@ -57,17 +56,21 @@ func init() { ee.RegisterEncFlag(flag) } func run() { - opts := options{ - file: Decrypt.Conf.GetString("file"), - output: Decrypt.Conf.GetString("out"), + keys, err := ee.GetKeys(Decrypt.Conf) + x.Check(err) + if len(keys.EncKey) == 0 { + glog.Fatal("Error while reading encryption key: Key is empty") } - _, opts.keyfile = ee.GetKeys(Decrypt.Conf) - if len(opts.keyfile) == 0 { - log.Fatal("Error while reading encryption key: Key is empty") + + opts := options{ + file: Decrypt.Conf.GetString("file"), + output: Decrypt.Conf.GetString("out"), + keyfile: keys.EncKey, } + f, err := os.Open(opts.file) if err != nil { - log.Fatalf("Error opening file: %v\n", err) + glog.Fatalf("Error opening file: %v\n", err) } defer f.Close() reader, err := enc.GetReader(opts.keyfile, f) @@ -78,14 +81,14 @@ func run() { } outf, err := os.OpenFile(opts.output, os.O_WRONLY|os.O_CREATE, 0644) if err != nil { - log.Fatalf("Error while opening output file: %v\n", err) + glog.Fatalf("Error while opening output file: %v\n", err) } w := gzip.NewWriter(outf) - fmt.Printf("Decrypting %s\n", opts.file) - fmt.Printf("Writing to %v\n", opts.output) + glog.Infof("Decrypting %s\n", opts.file) + glog.Infof("Writing to %v\n", opts.output) _, err = io.Copy(w, reader) if err != nil { - log.Fatalf("Error while writing: %v\n", err) + glog.Fatalf("Error while writing: %v\n", err) } err = w.Flush() x.Check(err) @@ -93,5 +96,5 @@ func run() { x.Check(err) err = outf.Close() x.Check(err) - fmt.Println("Done.") + glog.Infof("Done.") } diff --git a/dgraph/cmd/live/run.go b/dgraph/cmd/live/run.go index a8d30735af5..8bb1d01c165 100644 --- a/dgraph/cmd/live/run.go +++ b/dgraph/cmd/live/run.go @@ -697,8 +697,11 @@ func run() error { } creds := z.NewSuperFlag(Live.Conf.GetString("creds")).MergeAndCheckDefault(x.DefaultCreds) + keys, err := ee.GetKeys(Live.Conf) + if err != nil { + return err + } - var err error x.PrintVersion() opt = options{ dataFiles: Live.Conf.GetString("files"), @@ -717,6 +720,7 @@ func run() error { ludicrousMode: Live.Conf.GetBool("ludicrous"), upsertPredicate: Live.Conf.GetString("upsertPredicate"), tmpDir: Live.Conf.GetString("tmp"), + key: keys.EncKey, } forceNs := Live.Conf.GetInt64("force-namespace") @@ -738,7 +742,6 @@ func run() error { z.SetTmpDir(opt.tmpDir) - _, opt.key = ee.GetKeys(Live.Conf) go func() { if err := http.ListenAndServe(opt.httpAddr, nil); err != nil { glog.Errorf("Error while starting HTTP server: %+v", err) diff --git a/ee/backup/run.go b/ee/backup/run.go index 95e8a841579..bcacfd38ba7 100644 --- a/ee/backup/run.go +++ b/ee/backup/run.go @@ -221,7 +221,11 @@ func (bw *bufWriter) Write(buf *z.Buffer) error { } func runExportBackup() error { - _, opt.key = ee.GetKeys(ExportBackup.Conf) + keys, err := ee.GetKeys(ExportBackup.Conf) + if err != nil { + return err + } + opt.key = keys.EncKey if opt.format != "json" && opt.format != "rdf" { return errors.Errorf("invalid format %s", opt.format) } diff --git a/ee/flags.go b/ee/flags.go index a39b7a4ed77..5c58f99c752 100644 --- a/ee/flags.go +++ b/ee/flags.go @@ -19,11 +19,21 @@ package ee import ( "fmt" "strings" + "time" + "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" "github.com/spf13/pflag" ) +// Keys holds the configuration for ACL and encryption. +type Keys struct { + AclKey x.Sensitive + AclAccessTtl time.Duration + AclRefreshTtl time.Duration + EncKey x.Sensitive +} + const ( flagAcl = "acl" flagAclAccessTtl = "access-ttl" @@ -60,7 +70,7 @@ var ( flagAclAccessTtl, "6h", flagAclRefreshTtl, "30d", flagAclSecretFile, "") - encDefaults = fmt.Sprintf("%s=%s", flagEncKeyFile, "") + EncDefaults = fmt.Sprintf("%s=%s", flagEncKeyFile, "") ) func vaultDefaults(aclEnabled, encEnabled bool) string { @@ -126,12 +136,12 @@ func registerAclFlag(flag *pflag.FlagSet) { } func registerEncFlag(flag *pflag.FlagSet) { - helpText := z.NewSuperFlagHelp(encDefaults). + helpText := z.NewSuperFlagHelp(EncDefaults). Head("[Enterprise Feature] Encryption At Rest options"). Flag("key-file", "The file that stores the symmetric key of length 16, 24, or 32 bytes."+ "The key size determines the chosen AES cipher (AES-128, AES-192, and AES-256 respectively)."). String() - flag.String(flagEnc, encDefaults, helpText) + flag.String(flagEnc, EncDefaults, helpText) } func BuildEncFlag(filename string) string { diff --git a/ee/utils.go b/ee/keys.go similarity index 77% rename from ee/utils.go rename to ee/keys.go index 9cb317144df..8e4abc57f49 100644 --- a/ee/utils.go +++ b/ee/keys.go @@ -19,15 +19,15 @@ package ee import ( - "github.com/dgraph-io/dgraph/x" - "github.com/golang/glog" + "fmt" + "github.com/spf13/viper" ) // GetKeys returns the ACL and encryption keys as configured by the user // through the --acl, --encryption, and --vault flags. On OSS builds, -// this function exits with an error. -func GetKeys(config *viper.Viper) (x.Sensitive, x.Sensitive) { - glog.Exit("flags: acl / encryption is an enterprise-only feature") - return nil, nil +// this function always returns an error. +func GetKeys(config *viper.Viper) (*Keys, error) { + return nil, fmt.Errorf( + "flags: acl / encryption is an enterprise-only feature") } diff --git a/ee/keys_ee.go b/ee/keys_ee.go new file mode 100644 index 00000000000..c3e10f7c951 --- /dev/null +++ b/ee/keys_ee.go @@ -0,0 +1,67 @@ +// +build !oss + +/* + * Copyright 2021 Dgraph Labs, Inc. All rights reserved. + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package ee + +import ( + "fmt" + "io/ioutil" + + "github.com/dgraph-io/ristretto/z" + "github.com/spf13/viper" +) + +// GetKeys returns the ACL and encryption keys as configured by the user +// through the --acl, --encryption, and --vault flags. On OSS builds, +// this function always returns an error. +func GetKeys(config *viper.Viper) (*Keys, error) { + keys := &Keys{} + var err error + + aclSuperFlag := z.NewSuperFlag(config.GetString("acl")).MergeAndCheckDefault(AclDefaults) + encSuperFlag := z.NewSuperFlag(config.GetString("encryption")).MergeAndCheckDefault(EncDefaults) + + // Get AclKey and EncKey from vault / acl / encryption SuperFlags + keys.AclKey, keys.EncKey = vaultGetKeys(config) + aclKeyFile := aclSuperFlag.GetPath(flagAclSecretFile) + if aclKeyFile != "" { + if keys.AclKey != nil { + return nil, fmt.Errorf("flags: ACL secret key set in both vault and acl flags") + } + if keys.AclKey, err = ioutil.ReadFile(aclKeyFile); err != nil { + return nil, fmt.Errorf("error reading ACL secret key from file: %s: %s", aclKeyFile, err) + } + } + if l := len(keys.AclKey); keys.AclKey != nil && l < 32 { + return nil, fmt.Errorf( + "ACL secret key must have length of at least 32 bytes, got %d bytes instead", l) + } + encKeyFile := encSuperFlag.GetPath(flagEncKeyFile) + if encKeyFile != "" { + if keys.EncKey != nil { + return nil, fmt.Errorf("flags: Encryption key set in both vault and encryption flags") + } + if keys.EncKey, err = ioutil.ReadFile(encKeyFile); err != nil { + return nil, fmt.Errorf("error reading encryption key from file: %s: %s", encKeyFile, err) + } + } + if l := len(keys.EncKey); keys.EncKey != nil && l != 16 && l != 32 && l != 64 { + return nil, fmt.Errorf( + "encryption key must have length of 16, 32, or 64 bytes, got %d bytes instead", l) + } + + // Get remaining keys + keys.AclAccessTtl = aclSuperFlag.GetDuration(flagAclAccessTtl) + keys.AclRefreshTtl = aclSuperFlag.GetDuration(flagAclRefreshTtl) + + return keys, nil +} diff --git a/ee/utils_ee.go b/ee/utils_ee.go deleted file mode 100644 index 66e015abffb..00000000000 --- a/ee/utils_ee.go +++ /dev/null @@ -1,60 +0,0 @@ -// +build !oss - -/* - * Copyright 2021 Dgraph Labs, Inc. All rights reserved. - * - * Licensed under the Dgraph Community License (the "License"); you - * may not use this file except in compliance with the License. You - * may obtain a copy of the License at - * - * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt - */ - -package ee - -import ( - "io/ioutil" - - "github.com/dgraph-io/dgraph/x" - "github.com/dgraph-io/ristretto/z" - "github.com/golang/glog" - "github.com/spf13/viper" -) - -// GetKeys returns the ACL and encryption keys as configured by the user -// through the --acl, --encryption, and --vault flags. On OSS builds, -// this function exits with an error. -func GetKeys(config *viper.Viper) (x.Sensitive, x.Sensitive) { - aclSuperFlag := z.NewSuperFlag(config.GetString("acl")) - aclKey, encKey := vaultGetKeys(config) - var err error - - aclKeyFile := aclSuperFlag.GetPath("secret-file") - if aclKeyFile != "" { - if aclKey != nil { - glog.Exit("flags: ACL secret key set in both vault and acl flags") - } - if aclKey, err = ioutil.ReadFile(aclKeyFile); err != nil { - glog.Exitf("error reading ACL secret key from file: %s: %s", aclKeyFile, err) - } - } - if l := len(aclKey); aclKey != nil && l < 32 { - glog.Exitf("ACL secret key must have length of at least 32 bytes, got %d bytes instead", l) - } - - encSuperFlag := z.NewSuperFlag(config.GetString("encryption")).MergeAndCheckDefault(encDefaults) - encKeyFile := encSuperFlag.GetPath("key-file") - if encKeyFile != "" { - if encKey != nil { - glog.Exit("flags: Encryption key set in both vault and encryption") - } - if encKey, err = ioutil.ReadFile(encKeyFile); err != nil { - glog.Exitf("error reading encryption key from file: %s: %s", encKeyFile, err) - } - } - if l := len(encKey); encKey != nil && l != 16 && l != 32 && l != 64 { - glog.Exitf("encryption key must have length of 16, 32, or 64 bytes, got %d bytes instead", l) - } - - return aclKey, encKey -} diff --git a/testutil/backup.go b/testutil/backup.go index f1e9f904eb0..a98b348c878 100644 --- a/testutil/backup.go +++ b/testutil/backup.go @@ -50,12 +50,15 @@ func openDgraph(pdir string) (*badger.DB, error) { return nil, err } config.Set("encryption", ee.BuildEncFlag(KeyFile)) - _, encKey := ee.GetKeys(config) + keys, err := ee.GetKeys(config) + if err != nil { + return nil, err + } opt := badger.DefaultOptions(pdir). WithBlockCacheSize(10 * (1 << 20)). WithIndexCacheSize(10 * (1 << 20)). - WithEncryptionKey(encKey). + WithEncryptionKey(keys.EncKey). WithNamespaceOffset(x.NamespaceOffset) return badger.OpenManaged(opt) } diff --git a/worker/restore_map.go b/worker/restore_map.go index ad0a9d94472..19b5484db99 100644 --- a/worker/restore_map.go +++ b/worker/restore_map.go @@ -549,7 +549,10 @@ func RunMapper(req *pb.RestoreRequest, mapDir string) error { if err != nil { return errors.Wrapf(err, "unable to get encryption config") } - _, encKey := ee.GetKeys(cfg) + keys, err := ee.GetKeys(cfg) + if err != nil { + return err + } mapper := &mapper{ buf: newBuffer(), @@ -598,7 +601,7 @@ func RunMapper(req *pb.RestoreRequest, mapDir string) error { // Only restore the predicates that were assigned to this group at the time // of the last backup. file := filepath.Join(manifest.Path, backupName(manifest.ValidReadTs(), gid)) - br := readerFrom(h, file).WithEncryption(encKey).WithCompression(manifest.Compression) + br := readerFrom(h, file).WithEncryption(keys.EncKey).WithCompression(manifest.Compression) if br.err != nil { return errors.Wrap(br.err, "newBackupReader") }