diff --git a/cmd/rekor-server/app/root.go b/cmd/rekor-server/app/root.go index 4d60922e9..954e451c3 100644 --- a/cmd/rekor-server/app/root.go +++ b/cmd/rekor-server/app/root.go @@ -64,7 +64,7 @@ func init() { rootCmd.PersistentFlags().String("trillian_log_server.address", "127.0.0.1", "Trillian log server address") rootCmd.PersistentFlags().Uint16("trillian_log_server.port", 8090, "Trillian log server port") rootCmd.PersistentFlags().Uint("trillian_log_server.tlog_id", 0, "Trillian tree id") - rootCmd.PersistentFlags().String("trillian_log_server.sharding_config", "", "path to config file for inactive shards") + rootCmd.PersistentFlags().String("trillian_log_server.sharding_config", "", "path to config file for inactive shards, in JSON or YAML") hostname, err := os.Hostname() if err != nil { diff --git a/pkg/sharding/ranges.go b/pkg/sharding/ranges.go index 5b2772251..70485ec5d 100644 --- a/pkg/sharding/ranges.go +++ b/pkg/sharding/ranges.go @@ -18,6 +18,7 @@ package sharding import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" "io/ioutil" @@ -82,6 +83,10 @@ func logRangesFromPath(path string) (Ranges, error) { return Ranges{}, nil } if err := yaml.Unmarshal(contents, &ranges); err != nil { + // Try to use JSON + if jerr := json.Unmarshal(contents, &ranges); jerr == nil { + return ranges, nil + } return Ranges{}, err } return ranges, nil diff --git a/pkg/sharding/ranges_test.go b/pkg/sharding/ranges_test.go index 6e560bc4e..ae9d1123b 100644 --- a/pkg/sharding/ranges_test.go +++ b/pkg/sharding/ranges_test.go @@ -96,6 +96,32 @@ func TestLogRangesFromPath(t *testing.T) { } } +func TestLogRangesFromPathJSON(t *testing.T) { + contents := `[{"treeID": 0001, "treeLength": 3, "encodedPublicKey":"c2hhcmRpbmcK"}, {"treeID": 0002, "treeLength": 4}]` + file := filepath.Join(t.TempDir(), "sharding-config") + if err := ioutil.WriteFile(file, []byte(contents), 0644); err != nil { + t.Fatal(err) + } + expected := Ranges{ + { + TreeID: 1, + TreeLength: 3, + EncodedPublicKey: "c2hhcmRpbmcK", + }, { + TreeID: 2, + TreeLength: 4, + }, + } + + got, err := logRangesFromPath(file) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(expected, got) { + t.Fatalf("expected %v got %v", expected, got) + } +} + func TestLogRanges_ResolveVirtualIndex(t *testing.T) { lrs := LogRanges{ inactive: []LogRange{