Skip to content

Commit

Permalink
fix(ee): GetKeys should return an error (#7713) (#7797)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajeetdsouza authored May 10, 2021
1 parent bb0358e commit 7cc134a
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 104 deletions.
17 changes: 8 additions & 9 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,14 +656,13 @@ func run() {
ChangeDataConf: Alpha.Conf.GetString("cdc"),
}

aclKey, encKey := ee.GetKeys(Alpha.Conf)
if aclKey != nil {
opts.HmacSecret = aclKey

acl := z.NewSuperFlag(Alpha.Conf.GetString("acl")).MergeAndCheckDefault(ee.AclDefaults)
opts.AccessJwtTtl = acl.GetDuration("access-ttl")
opts.RefreshJwtTtl = acl.GetDuration("refresh-ttl")
keys, err := ee.GetKeys(Alpha.Conf)
x.Check(err)

if keys.AclKey != nil {
opts.HmacSecret = keys.AclKey
opts.AccessJwtTtl = keys.AclAccessTtl
opts.RefreshJwtTtl = keys.AclRefreshTtl
glog.Info("ACL secret key loaded successfully.")
}

Expand Down Expand Up @@ -702,7 +701,7 @@ func run() {
Raft: raft,
WhiteListedIPRanges: ips,
StrictMutations: opts.MutationsMode == worker.StrictMutations,
AclEnabled: aclKey != nil,
AclEnabled: keys.AclKey != nil,
AbortOlderThan: abortDur,
StartTime: startTime,
Ludicrous: ludicrous,
Expand All @@ -723,7 +722,7 @@ func run() {
// Set the directory for temporary buffers.
z.SetTmpDir(x.WorkerConfig.TmpDir)

x.WorkerConfig.EncryptionKey = encKey
x.WorkerConfig.EncryptionKey = keys.EncKey

setupCustomTokenizers()
x.Init()
Expand Down
11 changes: 8 additions & 3 deletions dgraph/cmd/bulk/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,13 @@ func run() {

bopts := badger.DefaultOptions("").FromSuperFlag(BulkBadgerDefaults + cacheDefaults).
FromSuperFlag(Bulk.Conf.GetString("badger"))
keys, err := ee.GetKeys(Bulk.Conf)
x.Check(err)

opt := options{
DataFiles: Bulk.Conf.GetString("files"),
DataFormat: Bulk.Conf.GetString("format"),
EncryptionKey: keys.EncKey,
SchemaFile: Bulk.Conf.GetString("schema"),
GqlSchemaFile: Bulk.Conf.GetString("graphql_schema"),
Encrypted: Bulk.Conf.GetBool("encrypted"),
Expand Down Expand Up @@ -177,7 +181,6 @@ func run() {
os.Exit(0)
}

_, opt.EncryptionKey = ee.GetKeys(Bulk.Conf)
if len(opt.EncryptionKey) == 0 {
if opt.Encrypted || opt.EncryptedOut {
fmt.Fprint(os.Stderr, "Must use --encryption or vault option(s).\n")
Expand All @@ -187,11 +190,13 @@ func run() {
requiredFlags := Bulk.Cmd.Flags().Changed("encrypted") &&
Bulk.Cmd.Flags().Changed("encrypted_out")
if !requiredFlags {
fmt.Fprint(os.Stderr, "Must specify --encrypted and --encrypted_out when providing encryption key.\n")
fmt.Fprint(os.Stderr,
"Must specify --encrypted and --encrypted_out when providing encryption key.\n")
os.Exit(1)
}
if !opt.Encrypted && !opt.EncryptedOut {
fmt.Fprint(os.Stderr, "Must set --encrypted and/or --encrypted_out to true when providing encryption key.\n")
fmt.Fprint(os.Stderr,
"Must set --encrypted and/or --encrypted_out to true when providing encryption key.\n")
os.Exit(1)
}

Expand Down
5 changes: 3 additions & 2 deletions dgraph/cmd/debug/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -885,14 +885,15 @@ func run() {
}
}()

var err error
dir := opt.pdir
isWal := false
if len(dir) == 0 {
dir = opt.wdir
isWal = true
}
_, opt.key = ee.GetKeys(Debug.Conf)
keys, err := ee.GetKeys(Debug.Conf)
x.Check(err)
opt.key = keys.EncKey

if isWal {
store, err := raftwal.InitEncrypted(dir, opt.key)
Expand Down
31 changes: 17 additions & 14 deletions dgraph/cmd/decrypt/decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ package decrypt

import (
"compress/gzip"
"fmt"
"io"
"log"
"os"
"strings"

"github.com/dgraph-io/dgraph/ee"
"github.com/dgraph-io/dgraph/ee/enc"
"github.com/dgraph-io/dgraph/x"
"github.com/golang/glog"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -57,17 +56,21 @@ func init() {
ee.RegisterEncFlag(flag)
}
func run() {
opts := options{
file: Decrypt.Conf.GetString("file"),
output: Decrypt.Conf.GetString("out"),
keys, err := ee.GetKeys(Decrypt.Conf)
x.Check(err)
if len(keys.EncKey) == 0 {
glog.Fatal("Error while reading encryption key: Key is empty")
}
_, opts.keyfile = ee.GetKeys(Decrypt.Conf)
if len(opts.keyfile) == 0 {
log.Fatal("Error while reading encryption key: Key is empty")

opts := options{
file: Decrypt.Conf.GetString("file"),
output: Decrypt.Conf.GetString("out"),
keyfile: keys.EncKey,
}

f, err := os.Open(opts.file)
if err != nil {
log.Fatalf("Error opening file: %v\n", err)
glog.Fatalf("Error opening file: %v\n", err)
}
defer f.Close()
reader, err := enc.GetReader(opts.keyfile, f)
Expand All @@ -78,20 +81,20 @@ func run() {
}
outf, err := os.OpenFile(opts.output, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
log.Fatalf("Error while opening output file: %v\n", err)
glog.Fatalf("Error while opening output file: %v\n", err)
}
w := gzip.NewWriter(outf)
fmt.Printf("Decrypting %s\n", opts.file)
fmt.Printf("Writing to %v\n", opts.output)
glog.Infof("Decrypting %s\n", opts.file)
glog.Infof("Writing to %v\n", opts.output)
_, err = io.Copy(w, reader)
if err != nil {
log.Fatalf("Error while writing: %v\n", err)
glog.Fatalf("Error while writing: %v\n", err)
}
err = w.Flush()
x.Check(err)
err = w.Close()
x.Check(err)
err = outf.Close()
x.Check(err)
fmt.Println("Done.")
glog.Infof("Done.")
}
7 changes: 5 additions & 2 deletions dgraph/cmd/live/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,11 @@ func run() error {
}

creds := z.NewSuperFlag(Live.Conf.GetString("creds")).MergeAndCheckDefault(x.DefaultCreds)
keys, err := ee.GetKeys(Live.Conf)
if err != nil {
return err
}

var err error
x.PrintVersion()
opt = options{
dataFiles: Live.Conf.GetString("files"),
Expand All @@ -717,6 +720,7 @@ func run() error {
ludicrousMode: Live.Conf.GetBool("ludicrous"),
upsertPredicate: Live.Conf.GetString("upsertPredicate"),
tmpDir: Live.Conf.GetString("tmp"),
key: keys.EncKey,
}

forceNs := Live.Conf.GetInt64("force-namespace")
Expand All @@ -738,7 +742,6 @@ func run() error {

z.SetTmpDir(opt.tmpDir)

_, opt.key = ee.GetKeys(Live.Conf)
go func() {
if err := http.ListenAndServe(opt.httpAddr, nil); err != nil {
glog.Errorf("Error while starting HTTP server: %+v", err)
Expand Down
6 changes: 5 additions & 1 deletion ee/backup/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,11 @@ func (bw *bufWriter) Write(buf *z.Buffer) error {
}

func runExportBackup() error {
_, opt.key = ee.GetKeys(ExportBackup.Conf)
keys, err := ee.GetKeys(ExportBackup.Conf)
if err != nil {
return err
}
opt.key = keys.EncKey
if opt.format != "json" && opt.format != "rdf" {
return errors.Errorf("invalid format %s", opt.format)
}
Expand Down
16 changes: 13 additions & 3 deletions ee/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,21 @@ package ee
import (
"fmt"
"strings"
"time"

"github.com/dgraph-io/dgraph/x"
"github.com/dgraph-io/ristretto/z"
"github.com/spf13/pflag"
)

// Keys holds the configuration for ACL and encryption.
type Keys struct {
AclKey x.Sensitive
AclAccessTtl time.Duration
AclRefreshTtl time.Duration
EncKey x.Sensitive
}

const (
flagAcl = "acl"
flagAclAccessTtl = "access-ttl"
Expand Down Expand Up @@ -60,7 +70,7 @@ var (
flagAclAccessTtl, "6h",
flagAclRefreshTtl, "30d",
flagAclSecretFile, "")
encDefaults = fmt.Sprintf("%s=%s", flagEncKeyFile, "")
EncDefaults = fmt.Sprintf("%s=%s", flagEncKeyFile, "")
)

func vaultDefaults(aclEnabled, encEnabled bool) string {
Expand Down Expand Up @@ -126,12 +136,12 @@ func registerAclFlag(flag *pflag.FlagSet) {
}

func registerEncFlag(flag *pflag.FlagSet) {
helpText := z.NewSuperFlagHelp(encDefaults).
helpText := z.NewSuperFlagHelp(EncDefaults).
Head("[Enterprise Feature] Encryption At Rest options").
Flag("key-file", "The file that stores the symmetric key of length 16, 24, or 32 bytes."+
"The key size determines the chosen AES cipher (AES-128, AES-192, and AES-256 respectively).").
String()
flag.String(flagEnc, encDefaults, helpText)
flag.String(flagEnc, EncDefaults, helpText)
}

func BuildEncFlag(filename string) string {
Expand Down
12 changes: 6 additions & 6 deletions ee/utils.go → ee/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
package ee

import (
"github.com/dgraph-io/dgraph/x"
"github.com/golang/glog"
"fmt"

"github.com/spf13/viper"
)

// GetKeys returns the ACL and encryption keys as configured by the user
// through the --acl, --encryption, and --vault flags. On OSS builds,
// this function exits with an error.
func GetKeys(config *viper.Viper) (x.Sensitive, x.Sensitive) {
glog.Exit("flags: acl / encryption is an enterprise-only feature")
return nil, nil
// this function always returns an error.
func GetKeys(config *viper.Viper) (*Keys, error) {
return nil, fmt.Errorf(
"flags: acl / encryption is an enterprise-only feature")
}
67 changes: 67 additions & 0 deletions ee/keys_ee.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// +build !oss

/*
* Copyright 2021 Dgraph Labs, Inc. All rights reserved.
*
* Licensed under the Dgraph Community License (the "License"); you
* may not use this file except in compliance with the License. You
* may obtain a copy of the License at
*
* https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt
*/

package ee

import (
"fmt"
"io/ioutil"

"github.com/dgraph-io/ristretto/z"
"github.com/spf13/viper"
)

// GetKeys returns the ACL and encryption keys as configured by the user
// through the --acl, --encryption, and --vault flags. On OSS builds,
// this function always returns an error.
func GetKeys(config *viper.Viper) (*Keys, error) {
keys := &Keys{}
var err error

aclSuperFlag := z.NewSuperFlag(config.GetString("acl")).MergeAndCheckDefault(AclDefaults)
encSuperFlag := z.NewSuperFlag(config.GetString("encryption")).MergeAndCheckDefault(EncDefaults)

// Get AclKey and EncKey from vault / acl / encryption SuperFlags
keys.AclKey, keys.EncKey = vaultGetKeys(config)
aclKeyFile := aclSuperFlag.GetPath(flagAclSecretFile)
if aclKeyFile != "" {
if keys.AclKey != nil {
return nil, fmt.Errorf("flags: ACL secret key set in both vault and acl flags")
}
if keys.AclKey, err = ioutil.ReadFile(aclKeyFile); err != nil {
return nil, fmt.Errorf("error reading ACL secret key from file: %s: %s", aclKeyFile, err)
}
}
if l := len(keys.AclKey); keys.AclKey != nil && l < 32 {
return nil, fmt.Errorf(
"ACL secret key must have length of at least 32 bytes, got %d bytes instead", l)
}
encKeyFile := encSuperFlag.GetPath(flagEncKeyFile)
if encKeyFile != "" {
if keys.EncKey != nil {
return nil, fmt.Errorf("flags: Encryption key set in both vault and encryption flags")
}
if keys.EncKey, err = ioutil.ReadFile(encKeyFile); err != nil {
return nil, fmt.Errorf("error reading encryption key from file: %s: %s", encKeyFile, err)
}
}
if l := len(keys.EncKey); keys.EncKey != nil && l != 16 && l != 32 && l != 64 {
return nil, fmt.Errorf(
"encryption key must have length of 16, 32, or 64 bytes, got %d bytes instead", l)
}

// Get remaining keys
keys.AclAccessTtl = aclSuperFlag.GetDuration(flagAclAccessTtl)
keys.AclRefreshTtl = aclSuperFlag.GetDuration(flagAclRefreshTtl)

return keys, nil
}
Loading

0 comments on commit 7cc134a

Please sign in to comment.