diff --git a/graphql/admin/restore_status.go b/graphql/admin/restore_status.go index 265645e9708..9f095f90c51 100644 --- a/graphql/admin/restore_status.go +++ b/graphql/admin/restore_status.go @@ -19,6 +19,7 @@ package admin import ( "context" "encoding/json" + "fmt" "github.com/dgraph-io/dgraph/graphql/resolve" "github.com/dgraph-io/dgraph/graphql/schema" @@ -40,9 +41,25 @@ func unknownStatus(q schema.Query, err error) *resolve.Resolved { } } +func getRestoreStatusInput(q schema.Query) (int64, error) { + restoreId := q.ArgValue("restoreId") + switch v := restoreId.(type) { + case int64: + return v, nil + case json.Number: + return v.Int64() + default: + return -1, fmt.Errorf("Invalid value of restoreId") + } + +} + func resolveRestoreStatus(ctx context.Context, q schema.Query) *resolve.Resolved { - restoreId := int(q.ArgValue("restoreId").(int64)) - status, err := worker.ProcessRestoreStatus(ctx, restoreId) + restoreId, err := getRestoreStatusInput(q) + if err != nil { + return unknownStatus(q, err) + } + status, err := worker.ProcessRestoreStatus(ctx, int(restoreId)) if err != nil { return unknownStatus(q, err) } diff --git a/graphql/admin/restore_status_test.go b/graphql/admin/restore_status_test.go new file mode 100644 index 00000000000..94d6fa5c1fa --- /dev/null +++ b/graphql/admin/restore_status_test.go @@ -0,0 +1,38 @@ +package admin + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/graphql/test" + "github.com/stretchr/testify/require" +) + +func TestRestoreStatus(t *testing.T) { + gqlSchema := test.LoadSchema(t, graphqlAdminSchema) + Query := `query restoreStatus($restoreId: Int!) { + restoreStatus(restoreId: $restoreId) { + status + errors + } + }` + variables := `{"restoreId": 2 }` + vars := make(map[string]interface{}) + d := json.NewDecoder(strings.NewReader(variables)) + d.UseNumber() + err := d.Decode(&vars) + require.NoError(t, err) + + op, err := gqlSchema.Operation( + &schema.Request{ + Query: Query, + Variables: vars, + }) + require.NoError(t, err) + gqlQuery := test.GetQuery(t, op) + v, err := getRestoreStatusInput(gqlQuery) + require.NoError(t, err) + require.IsType(t, int64(2), v, nil) +}