diff --git a/backend/services/Makefile b/backend/services/Makefile
index c3c0f33b66a..1af216d4971 100644
--- a/backend/services/Makefile
+++ b/backend/services/Makefile
@@ -162,11 +162,11 @@ run_host_service_dev: build
 
 run_scaling_service: build
 	$(info running $(BUILD_FOLDER)/scaling-service with localdev configuration...)
-	APP_ENV=localdev $(BUILD_FOLDER)/scaling-service
+	APP_ENV=localdev $(BUILD_FOLDER)/scaling-service -nocleanup
 
 run_scaling_service_localdevwithdb: build
 	$(info running $(BUILD_FOLDER)/scaling-service with localdevwithdb configuration...)
-	APP_ENV=localdevwithdb $(BUILD_FOLDER)/scaling-service
+	APP_ENV=localdevwithdb $(BUILD_FOLDER)/scaling-service -nocleanup
 
 run_scaling_service_dev: build
 	$(info running $(BUILD_FOLDER)/scaling-service with dev (deployed) configuration...)
diff --git a/backend/services/scaling-service/cleanup.go b/backend/services/scaling-service/cleanup.go
new file mode 100644
index 00000000000..f0e7a56ee65
--- /dev/null
+++ b/backend/services/scaling-service/cleanup.go
@@ -0,0 +1,133 @@
+// Copyright (c) 2022 Whist Technologies, Inc.
+
+package main
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/whisthq/whist/backend/services/scaling-service/dbclient"
+	"github.com/whisthq/whist/backend/services/scaling-service/hosts"
+	"github.com/whisthq/whist/backend/services/subscriptions"
+	logger "github.com/whisthq/whist/backend/services/whistlogger"
+)
+
+var db dbclient.DBClient
+
+// CleanRegion starts the unresponsive instance cleanup thread for a particular
+// region and returns a function that can be used to stop it. Call CleanRegion
+// at most once per region: there is no reason to run more than one cleanup
+// thread for the same region.
+//
+// The stop function blocks until all in-progress cleaning operations have
+// completed. Consider calling it in its own goroutine like so:
+//
+//	stop := CleanRegion(client, h, time.Minute)
+//	var wg sync.WaitGroup
+//
+//	wg.Add(1)
+//
+//	go func() {
+//		defer wg.Done()
+//		stop()
+//	}()
+//
+//	wg.Wait()
+func CleanRegion(client subscriptions.WhistGraphQLClient, h hosts.HostHandler, d time.Duration) func() {
+	stop := make(chan struct{})
+	ticker := time.NewTicker(d)
+	var wg sync.WaitGroup
+
+	// Don't bother adding this goroutine to the cleaner's wait group. It will
+	// finish as soon as the stop channel is closed.
+	go func() {
+		for {
+			select {
+			case <-ticker.C:
+				wg.Add(1)
+
+				go func() {
+					defer wg.Done()
+
+					// TODO: Make the deadline more configurable.
+					deadline := time.Now().Add(5 * time.Minute)
+					ctx, cancel := context.WithDeadline(context.Background(), deadline)
+					defer cancel()
+
+					do(ctx, client, h)
+				}()
+			case <-stop:
+				return
+			}
+		}
+	}()
+
+	return func() {
+		ticker.Stop()
+		close(stop)
+		wg.Wait()
+	}
+}
+
+// do marks all unresponsive instances as TERMINATING in the database before
+// subsequently terminating them and finally removing them from the database
+// altogether.
+func do(ctx context.Context, client subscriptions.WhistGraphQLClient, h hosts.HostHandler) {
+	region := h.GetRegion()
+	maxAge := time.Now().Add(-150 * time.Second)
+	ids, err := db.LockBrokenInstances(ctx, client, region, maxAge)
+
+	if err != nil {
+		logger.Errorf("failed to mark unresponsive instances as TERMINATING in "+
+			"region %s: %s", region, err)
+		return
+	} else if len(ids) < 1 {
+		logger.Debugf("Didn't find any unresponsive instances in region %s",
+			region)
+		return
+	}
+
+	if err := h.SpinDownInstances(ctx, ids); err != nil {
+		logger.Errorf("failed to terminate unresponsive instances in region %s:"+
+			"%s", region, err)
+		logger.Errorf("please verify that the instances in %s with the following "+
+			"instance IDs have been terminated and then remove them from the "+
+			"database: %s", region, ids)
+		return
+	}
+
+	deleted, err := db.TerminateLockedInstances(ctx, client, region, ids)
+
+	if err != nil {
+		logger.Errorf("failed to remove unresponsive instances in %s from the "+
+			"database: %s", region, err)
+		return
+	}
+
+	if !equal(deleted, ids) {
+		logger.Errorf("some %s instance rows were not deleted: requested "+
+			"%v, got %v", region, ids, deleted)
+	} else {
+		logger.Infof("Successfully removed the following unresponsive instance "+
+			"rows from the database for %s: %v", region, deleted)
+	}
+}
+
+// equal reports whether u and v contain the same set of instance IDs.
+func equal(u, v []string) bool {
+	if len(u) != len(v) {
+		return false
+	}
+
+	set := make(map[string]struct{}, len(u))
+	for _, s := range u {
+		set[s] = struct{}{}
+	}
+	for _, s := range v {
+		if _, ok := set[s]; !ok {
+			return false
+		}
+	}
+	return true
+}
diff --git a/backend/services/scaling-service/dbclient/cleanup.go b/backend/services/scaling-service/dbclient/cleanup.go
new file mode 100644
index 00000000000..910a769d05f
--- /dev/null
+++ b/backend/services/scaling-service/dbclient/cleanup.go
@@ -0,0 +1,66 @@
+package dbclient
+
+import (
+	"context"
+	"time"
+
+	"github.com/hasura/go-graphql-client"
+	"github.com/whisthq/whist/backend/services/subscriptions"
+)
+
+// LockBrokenInstances sets the status column of each row of the database
+// corresponding to an instance that hasn't been updated since maxAge to
+// TERMINATING. It returns the instance ID of every such instance.
+func (c *DBClient) LockBrokenInstances(ctx context.Context, client subscriptions.WhistGraphQLClient, region string, maxAge time.Time) ([]string, error) {
+	var m subscriptions.LockBrokenInstances
+	vars := map[string]interface{}{
+		"maxAge": timestamptz(maxAge),
+		"region": graphql.String(region),
+	}
+
+	if err := client.Mutate(ctx, &m, vars); err != nil {
+		return nil, err
+	}
+
+	ids := make([]string, 0, m.Response.Count)
+
+	for _, host := range m.Response.Hosts {
+		ids = append(ids, string(host.ID))
+	}
+
+	return ids, nil
+}
+
+// TerminateLockedInstances removes the requested rows, all of whose status
+// columns should have TERMINATING, corresponding to unresponsive instances from
+// the whist.instances table of the database. It also deletes all rows from the
+// whist.mandelboxes table that are foreign keyed to a whist.instances row whose
+// deletion has been requested.
+func (c *DBClient) TerminateLockedInstances(ctx context.Context, client subscriptions.WhistGraphQLClient, region string, ids []string) ([]string, error) {
+	var m subscriptions.TerminateLockedInstances
+
+	// We need to pass the instance IDs as a slice of graphql String type
+	// instances, not just a normal string slice.
+ _ids := make([]graphql.String, 0, len(ids)) + + for _, id := range ids { + _ids = append(_ids, graphql.String(id)) + } + + vars := map[string]interface{}{ + "ids": _ids, + "region": graphql.String(region), + } + + if err := client.Mutate(ctx, &m, vars); err != nil { + return nil, err + } + + terminated := make([]string, 0, m.InstancesResponse.Count) + + for _, host := range m.InstancesResponse.Hosts { + terminated = append(terminated, string(host.ID)) + } + + return terminated, nil +} diff --git a/backend/services/scaling-service/dbclient/cleanup_test.go b/backend/services/scaling-service/dbclient/cleanup_test.go new file mode 100644 index 00000000000..1310c301ff5 --- /dev/null +++ b/backend/services/scaling-service/dbclient/cleanup_test.go @@ -0,0 +1,155 @@ +// Copyright (c) 2022 Whist Technologies, Inc. + +package dbclient + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + "github.com/whisthq/whist/backend/services/subscriptions" +) + +var pkg DBClient + +// mockResponse generates a mock GraphQL mutation response. +func mockResponse(field string, ids []string) ([]byte, error) { + // The response contains a list of instance IDs. + type host struct { + ID string `json:"id"` + } + + count := len(ids) + hosts := make([]host, 0, count) + + for _, id := range ids { + hosts = append(hosts, host{id}) + } + + data := struct { + Count int `json:"affected_rows"` + Hosts []host `json:"returning"` + }{count, hosts} + + // Dynamically construct an auxiliary type to serialize the mock response + // data. + t := reflect.StructOf([]reflect.StructField{ + { + Name: "Response", + Type: reflect.TypeOf(data), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s"`, field)), + }, + }) + v := reflect.New(t) + + v.Elem().FieldByName("Response").Set(reflect.ValueOf(data)) + + return json.Marshal(v.Interface()) +} + +// testClient provides mock responses to GraphQL mutations. 
+type testClient struct { + // ids contains mock instance IDs of unresponsive instances. + ids []string +} + +// Initialize is part of the subscriptions.WhistGraphQLClient interface. +func (*testClient) Initialize(bool) error { + return nil +} + +// Query is part of the subscriptions.WhistGraphQLClient interface. +func (*testClient) Query(context.Context, subscriptions.GraphQLQuery, map[string]interface{}) error { + return nil +} + +// Mutate is part of the subscriptions.WhistGraphQLClient interface. This +// implementation populates the mutation struct with mock host data. +func (c *testClient) Mutate(_ context.Context, q subscriptions.GraphQLQuery, v map[string]interface{}) error { + // Depending what kind of mock response we are providing, we select a list of + // instance IDs to return. + var ids []string + var field string + + switch q.(type) { + case *subscriptions.LockBrokenInstances: + field = "update_whist_instances" + + // We are providing a mock response to markForTermination, so we use the + // list of instance IDs stored in the "database" (i.e. struct) + ids = c.ids + case *subscriptions.TerminateLockedInstances: + field = "delete_whist_instances" + + // We are providing a mock response to finalizeTermination, so we use the + // value of the instance IDs input variable. + tmp1, ok := v["ids"] + + if !ok { + return errors.New("missing instance IDs variable") + } + + tmp2, ok := tmp1.([]graphql.String) + + if !ok { + return errors.New("instance IDs input variable should be a slice of " + + "graphql String types.") + } + + ids = make([]string, 0, len(tmp2)) + + for _, id := range tmp2 { + ids = append(ids, string(id)) + } + default: + t := reflect.TypeOf(q) + return fmt.Errorf("unrecognized mutation '%s'", t) + } + + data, err := mockResponse(field, ids) + + if err != nil { + return err + } + + // Deserialize the mock data into the response struct. 
+	if err := graphql.UnmarshalGraphQL(data, q); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// TestLockBrokenInstances tests that LockBrokenInstances converts the GraphQL
+// mutation response from Hasura to a slice of instance IDs.
+func TestLockBrokenInstances(t *testing.T) {
+	var tt time.Time
+	ids := []string{"instance-0", "instance-1", "instance-2"}
+	client := &testClient{ids: ids}
+	res, err := pkg.LockBrokenInstances(context.TODO(), client, "us-east-1", tt)
+
+	if err != nil {
+		t.Error("markForTermination:", err)
+	} else if !reflect.DeepEqual(res, ids) {
+		t.Errorf("Expected %v, got %v", ids, res)
+	}
+}
+
+// TestTerminateLockedInstances tests that TerminateLockedInstances converts
+// the GraphQL mutation response from Hasura to a slice of instance IDs.
+func TestTerminateLockedInstances(t *testing.T) {
+	client := &testClient{}
+	ids := []string{"instance-0", "instance-1", "instance-2"}
+	res, err := pkg.TerminateLockedInstances(context.TODO(), client, "us-east-1", ids)
+
+	if err != nil {
+		t.Error("finalizeTermination:", err)
+	} else if !reflect.DeepEqual(res, ids) {
+		t.Errorf("Expected %v, got %v", ids, res)
+	}
+}
diff --git a/backend/services/scaling-service/dbclient/client.go b/backend/services/scaling-service/dbclient/client.go
index 34d954905c7..62d49c59b84 100644
--- a/backend/services/scaling-service/dbclient/client.go
+++ b/backend/services/scaling-service/dbclient/client.go
@@ -11,6 +11,7 @@ package dbclient
 
 import (
 	"context"
+	"time"
 
 	"github.com/whisthq/whist/backend/services/subscriptions"
 )
@@ -37,6 +38,9 @@ type WhistDBClient interface {
 	QueryUserMandelboxes(context.Context, subscriptions.WhistGraphQLClient, string) ([]subscriptions.Mandelbox, error)
 	InsertMandelboxes(context.Context, subscriptions.WhistGraphQLClient, []subscriptions.Mandelbox) (int, error)
 	UpdateMandelbox(context.Context, subscriptions.WhistGraphQLClient, subscriptions.Mandelbox) (int, error)
+
+	LockBrokenInstances(context.Context,
subscriptions.WhistGraphQLClient, string, time.Time) ([]string, error)
+	TerminateLockedInstances(context.Context, subscriptions.WhistGraphQLClient, string, []string) ([]string, error)
 }
 
 // DBClient implements `WhistDBClient`, it is the default database
diff --git a/backend/services/scaling-service/event_handler.go b/backend/services/scaling-service/event_handler.go
index af9216a5c86..1875c4106f0 100644
--- a/backend/services/scaling-service/event_handler.go
+++ b/backend/services/scaling-service/event_handler.go
@@ -28,6 +28,7 @@ package main
 import (
 	"context"
 	"encoding/json"
+	"flag"
 	"os"
 	"os/signal"
 	"strings"
@@ -42,6 +43,7 @@ import (
 	"github.com/whisthq/whist/backend/services/metadata"
 	"github.com/whisthq/whist/backend/services/scaling-service/config"
 	"github.com/whisthq/whist/backend/services/scaling-service/dbclient"
+	hosts "github.com/whisthq/whist/backend/services/scaling-service/hosts/aws"
 	algos "github.com/whisthq/whist/backend/services/scaling-service/scaling_algorithms/default" // Import as algos, short for scaling_algorithms
 	"github.com/whisthq/whist/backend/services/subscriptions"
 	"github.com/whisthq/whist/backend/services/utils"
@@ -49,6 +51,17 @@
 )
 
 func main() {
+	var (
+		cleanupPeriod time.Duration
+		noCleanup     bool
+	)
+
+	flag.DurationVar(&cleanupPeriod, "cleanup", time.Minute,
+		"the amount of time between when each cleanup thread runs")
+	flag.BoolVar(&noCleanup, "nocleanup", false, "disable asynchronous cleanup "+
+		"of unresponsive instances")
+	flag.Parse()
+
 	globalCtx, globalCancel := context.WithCancel(context.Background())
 
 	goroutineTracker := &sync.WaitGroup{}
@@ -117,28 +130,57 @@
 	// Use a sync map since we only write the keys once but will be reading multiple
 	// times by different goroutines.
 	algorithmByRegionMap := &sync.Map{}
+	regions := config.GetEnabledRegions()
+	stopFuncs := make([]func(), 0, len(regions))
+
+	// Load and instantiate default scaling algorithm for all enabled regions.
+ for _, region := range regions { + name := utils.Sprintf("default-sa-%s", region) + handler := &hosts.AWSHost{} + + if err := handler.Initialize(region); err != nil { + logger.Errorf("Failed to initialize host handler for region '%s'", region) + continue + } + + algo := &algos.DefaultScalingAlgorithm{Host: handler, Region: region} + + algo.CreateEventChans() + algo.CreateGraphQLClient(graphqlClient) + algo.CreateDBClient(dbClient) + algo.ProcessEvents(globalCtx, goroutineTracker) + + if noCleanup { + logger.Infof("Cleanup disabled. Not starting cleanup threads.") + } else { + stop := CleanRegion(graphqlClient, handler, cleanupPeriod) + stopFuncs = append(stopFuncs, stop) + logger.Infof("Unresponsive instances will be pruned every %s.", + cleanupPeriod) + } + + algorithmByRegionMap.Store(name, algo) - // Load default scaling algorithm for all enabled regions. - for _, region := range config.GetEnabledRegions() { logger.Infof("There should be as close as possible to %d unassigned "+ "Mandelboxes available at all times in %s", config.GetTargetFreeMandelboxes(region), region) - name := utils.Sprintf("default-sa-%s", region) - algorithmByRegionMap.Store(name, &algos.DefaultScalingAlgorithm{ - Region: region, - }) } - // Instantiate scaling algorithms on allowed regions - algorithmByRegionMap.Range(func(key, value interface{}) bool { - scalingAlgorithm := value.(algos.ScalingAlgorithm) - scalingAlgorithm.CreateEventChans() - scalingAlgorithm.CreateGraphQLClient(graphqlClient) - scalingAlgorithm.CreateDBClient(dbClient) - scalingAlgorithm.ProcessEvents(globalCtx, goroutineTracker) + // Wait for each of our cleanup threads to finish before we exit. 
+ defer func() { + var wg sync.WaitGroup - return true - }) + for j := range stopFuncs { + wg.Add(1) + + go func(i int) { + defer wg.Done() + stopFuncs[i]() + }(j) + } + + wg.Wait() + }() // Start main event loop go eventLoop(globalCtx, globalCancel, serverEvents, subscriptionEvents, scheduledEvents, algorithmByRegionMap, configClient) diff --git a/backend/services/scaling-service/scaling_algorithms/default/actions_test.go b/backend/services/scaling-service/scaling_algorithms/default/actions_test.go index 2eee2be7e43..c64b293dcc2 100644 --- a/backend/services/scaling-service/scaling_algorithms/default/actions_test.go +++ b/backend/services/scaling-service/scaling_algorithms/default/actions_test.go @@ -5,6 +5,7 @@ package scaling_algorithms import ( "context" + "errors" "os" "sync" "testing" @@ -193,6 +194,14 @@ func (db *mockDBClient) UpdateMandelbox(_ context.Context, _ subscriptions.Whist return affected, nil } +func (db *mockDBClient) LockBrokenInstances(context.Context, subscriptions.WhistGraphQLClient, string, time.Time) ([]string, error) { + return nil, errors.New("Not implemented") +} + +func (db *mockDBClient) TerminateLockedInstances(context.Context, subscriptions.WhistGraphQLClient, string, []string) ([]string, error) { + return nil, errors.New("Not implemented") +} + // mockHostHandler is used to test all interactions with cloud providers type mockHostHandler struct{} diff --git a/backend/services/scaling-service/scaling_algorithms/default/default.go b/backend/services/scaling-service/scaling_algorithms/default/default.go index b5ba311f97e..b7bd2bae676 100644 --- a/backend/services/scaling-service/scaling_algorithms/default/default.go +++ b/backend/services/scaling-service/scaling_algorithms/default/default.go @@ -27,7 +27,6 @@ import ( "github.com/whisthq/whist/backend/services/scaling-service/dbclient" "github.com/whisthq/whist/backend/services/scaling-service/hosts" - aws "github.com/whisthq/whist/backend/services/scaling-service/hosts/aws" 
"github.com/whisthq/whist/backend/services/subscriptions" "github.com/whisthq/whist/backend/services/utils" logger "github.com/whisthq/whist/backend/services/whistlogger" @@ -117,19 +116,6 @@ func (s *DefaultScalingAlgorithm) CreateDBClient(dbClient dbclient.WhistDBClient // events and executing the appropiate scaling actions. This function is specific for each region // scaling algorithm to be able to implement different strategies on each region. func (s *DefaultScalingAlgorithm) ProcessEvents(globalCtx context.Context, goroutineTracker *sync.WaitGroup) { - if s.Host == nil { - // TODO when multi-cloud support is introduced, figure out a way to - // decide which host to use. For now default to AWS. - handler := &aws.AWSHost{} - err := handler.Initialize(s.Region) - - if err != nil { - logger.Errorf("error starting host in %s: %s", s.Region, err) - } - - s.Host = handler - } - // Start algorithm main event loop // Track this goroutine so we can wait for it to // finish if the global context gets cancelled. diff --git a/backend/services/subscriptions/mutations.go b/backend/services/subscriptions/mutations.go index d17dad7e864..243e84ba690 100644 --- a/backend/services/subscriptions/mutations.go +++ b/backend/services/subscriptions/mutations.go @@ -2,6 +2,43 @@ package subscriptions import "github.com/hasura/go-graphql-client" +type ( + + // LockBrokenInstances takes two arguments: region and maxAge. The region + // argument matches values from the region column of the whist.instances + // table. The maxAge argument is a timestamptz that specifies the time after + // which, if an instance has not updated itself in the database, it will be + // considered broken. This query gives the status column of all rows of the + // database that have an updated_at timestamp that is older than maxAge the + // TERMINATING status. It returns the instance_ids of the affected rows. 
+ LockBrokenInstances struct { + Response struct { + Count graphql.Int `graphql:"affected_rows"` + Hosts []struct { + ID graphql.String `graphql:"id"` + } `graphql:"returning"` + } `graphql:"update_whist_instances(where: {region: {_eq: $region}, updated_at: {_lt: $maxAge}, status: {_in: [\"ACTIVE\", \"TERMINATING\"]}}, _set: {status: \"TERMINATING\"})"` + } + + // TerminateLockedInstances takes two arguments: region and ids. The region + // argument matches values from the region column of the whist.instances + // table. The ids argument specifies the instance_ids of the instances to + // terminate. All of these instances should already be locked, i.e. the status + // column of each of their rows in the whist.instances database should have + // TERMINATING. + TerminateLockedInstances struct { + MandelboxesResponse struct { + Count graphql.Int `graphql:"affected_rows"` + } `graphql:"delete_whist_mandelboxes(where: {instance: {region: {_eq: $region}, id: {_in: $ids}, status: {_eq: \"TERMINATING\"}}})"` + InstancesResponse struct { + Count graphql.Int `graphql:"affected_rows"` + Hosts []struct { + ID graphql.String `graphql:"id"` + } `graphql:"returning"` + } `graphql:"delete_whist_instances(where: {region: {_eq: $region}, id: {_in: $ids}, status: {_eq: \"TERMINATING\"}})"` + } +) + var ( // InsertInstances inserts multiple instances to the database.