diff --git a/cmd/manager/Dockerfile b/cmd/manager/Dockerfile index a25a0310d07..67197a08f33 100644 --- a/cmd/manager/Dockerfile +++ b/cmd/manager/Dockerfile @@ -6,6 +6,9 @@ RUN go build -o vizier-manager FROM alpine:3.7 WORKDIR /app +RUN GRPC_HEALTH_PROBE_VERSION=v0.2.0 && \ + wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64 && \ + chmod +x /bin/grpc_health_probe COPY --from=build-env /go/src/github.com/kubeflow/katib/cmd/manager/vizier-manager /app/ COPY --from=build-env /go/src/github.com/kubeflow/katib/pkg/manager/visualise / ENTRYPOINT ["./vizier-manager"] diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 11d4da50e62..fe7d388f091 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -4,11 +4,13 @@ import ( "context" "errors" "flag" + "fmt" "log" "net" "time" - pb "github.com/kubeflow/katib/pkg/api" + api_pb "github.com/kubeflow/katib/pkg/api" + health_pb "github.com/kubeflow/katib/pkg/api/health" kdb "github.com/kubeflow/katib/pkg/db" "github.com/kubeflow/katib/pkg/manager/modelstore" @@ -26,114 +28,114 @@ type server struct { msIf modelstore.ModelStore } -func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*pb.CreateStudyReply, error) { +func (s *server) CreateStudy(ctx context.Context, in *api_pb.CreateStudyRequest) (*api_pb.CreateStudyReply, error) { if in == nil || in.StudyConfig == nil { - return &pb.CreateStudyReply{}, errors.New("StudyConfig is missing.") + return &api_pb.CreateStudyReply{}, errors.New("StudyConfig is missing.") } studyID, err := dbIf.CreateStudy(in.StudyConfig) if err != nil { - return &pb.CreateStudyReply{}, err + return &api_pb.CreateStudyReply{}, err } - s.SaveStudy(ctx, &pb.SaveStudyRequest{ + s.SaveStudy(ctx, &api_pb.SaveStudyRequest{ StudyName: in.StudyConfig.Name, Owner: in.StudyConfig.Owner, Description: "StudyID: " + studyID, }) - return &pb.CreateStudyReply{StudyId: 
studyID}, nil + return &api_pb.CreateStudyReply{StudyId: studyID}, nil } -func (s *server) GetStudy(ctx context.Context, in *pb.GetStudyRequest) (*pb.GetStudyReply, error) { +func (s *server) GetStudy(ctx context.Context, in *api_pb.GetStudyRequest) (*api_pb.GetStudyReply, error) { sc, err := dbIf.GetStudyConfig(in.StudyId) - return &pb.GetStudyReply{StudyConfig: sc}, err + return &api_pb.GetStudyReply{StudyConfig: sc}, err } -func (s *server) GetStudyList(ctx context.Context, in *pb.GetStudyListRequest) (*pb.GetStudyListReply, error) { +func (s *server) GetStudyList(ctx context.Context, in *api_pb.GetStudyListRequest) (*api_pb.GetStudyListReply, error) { sl, err := dbIf.GetStudyList() if err != nil { - return &pb.GetStudyListReply{}, err + return &api_pb.GetStudyListReply{}, err } - result := make([]*pb.StudyOverview, len(sl)) + result := make([]*api_pb.StudyOverview, len(sl)) for i, id := range sl { sc, err := dbIf.GetStudyConfig(id) if err != nil { - return &pb.GetStudyListReply{}, err + return &api_pb.GetStudyListReply{}, err } - result[i] = &pb.StudyOverview{ + result[i] = &api_pb.StudyOverview{ Name: sc.Name, Owner: sc.Owner, Id: id, } } - return &pb.GetStudyListReply{StudyOverviews: result}, err + return &api_pb.GetStudyListReply{StudyOverviews: result}, err } -func (s *server) CreateTrial(ctx context.Context, in *pb.CreateTrialRequest) (*pb.CreateTrialReply, error) { +func (s *server) CreateTrial(ctx context.Context, in *api_pb.CreateTrialRequest) (*api_pb.CreateTrialReply, error) { err := dbIf.CreateTrial(in.Trial) - return &pb.CreateTrialReply{TrialId: in.Trial.TrialId}, err + return &api_pb.CreateTrialReply{TrialId: in.Trial.TrialId}, err } -func (s *server) GetTrials(ctx context.Context, in *pb.GetTrialsRequest) (*pb.GetTrialsReply, error) { +func (s *server) GetTrials(ctx context.Context, in *api_pb.GetTrialsRequest) (*api_pb.GetTrialsReply, error) { tl, err := dbIf.GetTrialList(in.StudyId) - return &pb.GetTrialsReply{Trials: tl}, err + return 
&api_pb.GetTrialsReply{Trials: tl}, err } -func (s *server) GetSuggestions(ctx context.Context, in *pb.GetSuggestionsRequest) (*pb.GetSuggestionsReply, error) { +func (s *server) GetSuggestions(ctx context.Context, in *api_pb.GetSuggestionsRequest) (*api_pb.GetSuggestionsReply, error) { if in.SuggestionAlgorithm == "" { - return &pb.GetSuggestionsReply{Trials: []*pb.Trial{}}, errors.New("No suggest algorithm specified") + return &api_pb.GetSuggestionsReply{Trials: []*api_pb.Trial{}}, errors.New("No suggest algorithm specified") } conn, err := grpc.Dial("vizier-suggestion-"+in.SuggestionAlgorithm+":6789", grpc.WithInsecure()) if err != nil { - return &pb.GetSuggestionsReply{Trials: []*pb.Trial{}}, err + return &api_pb.GetSuggestionsReply{Trials: []*api_pb.Trial{}}, err } defer conn.Close() - c := pb.NewSuggestionClient(conn) + c := api_pb.NewSuggestionClient(conn) r, err := c.GetSuggestions(ctx, in) if err != nil { - return &pb.GetSuggestionsReply{Trials: []*pb.Trial{}}, err + return &api_pb.GetSuggestionsReply{Trials: []*api_pb.Trial{}}, err } return r, nil } -func (s *server) RegisterWorker(ctx context.Context, in *pb.RegisterWorkerRequest) (*pb.RegisterWorkerReply, error) { +func (s *server) RegisterWorker(ctx context.Context, in *api_pb.RegisterWorkerRequest) (*api_pb.RegisterWorkerReply, error) { wid, err := dbIf.CreateWorker(in.Worker) - return &pb.RegisterWorkerReply{WorkerId: wid}, err + return &api_pb.RegisterWorkerReply{WorkerId: wid}, err } -func (s *server) GetWorkers(ctx context.Context, in *pb.GetWorkersRequest) (*pb.GetWorkersReply, error) { - var ws []*pb.Worker +func (s *server) GetWorkers(ctx context.Context, in *api_pb.GetWorkersRequest) (*api_pb.GetWorkersReply, error) { + var ws []*api_pb.Worker var err error if in.WorkerId == "" { ws, err = dbIf.GetWorkerList(in.StudyId, in.TrialId) } else { - var w *pb.Worker + var w *api_pb.Worker w, err = dbIf.GetWorker(in.WorkerId) ws = append(ws, w) } - return &pb.GetWorkersReply{Workers: ws}, err + return 
&api_pb.GetWorkersReply{Workers: ws}, err } -func (s *server) GetShouldStopWorkers(ctx context.Context, in *pb.GetShouldStopWorkersRequest) (*pb.GetShouldStopWorkersReply, error) { +func (s *server) GetShouldStopWorkers(ctx context.Context, in *api_pb.GetShouldStopWorkersRequest) (*api_pb.GetShouldStopWorkersReply, error) { if in.EarlyStoppingAlgorithm == "" { - return &pb.GetShouldStopWorkersReply{}, errors.New("No EarlyStopping Algorithm specified") + return &api_pb.GetShouldStopWorkersReply{}, errors.New("No EarlyStopping Algorithm specified") } conn, err := grpc.Dial("vizier-earlystopping-"+in.EarlyStoppingAlgorithm+":6789", grpc.WithInsecure()) if err != nil { - return &pb.GetShouldStopWorkersReply{}, err + return &api_pb.GetShouldStopWorkersReply{}, err } defer conn.Close() - c := pb.NewEarlyStoppingClient(conn) + c := api_pb.NewEarlyStoppingClient(conn) return c.GetShouldStopWorkers(context.Background(), in) } -func (s *server) GetMetrics(ctx context.Context, in *pb.GetMetricsRequest) (*pb.GetMetricsReply, error) { +func (s *server) GetMetrics(ctx context.Context, in *api_pb.GetMetricsRequest) (*api_pb.GetMetricsReply, error) { var mNames []string if in.StudyId == "" { - return &pb.GetMetricsReply{}, errors.New("StudyId should be set") + return &api_pb.GetMetricsReply{}, errors.New("StudyId should be set") } sc, err := dbIf.GetStudyConfig(in.StudyId) if err != nil { - return &pb.GetMetricsReply{}, err + return &api_pb.GetMetricsReply{}, err } if len(in.MetricsNames) > 0 { mNames = in.MetricsNames @@ -141,62 +143,62 @@ func (s *server) GetMetrics(ctx context.Context, in *pb.GetMetricsRequest) (*pb. 
mNames = sc.Metrics } if err != nil { - return &pb.GetMetricsReply{}, err + return &api_pb.GetMetricsReply{}, err } - mls := make([]*pb.MetricsLogSet, len(in.WorkerIds)) + mls := make([]*api_pb.MetricsLogSet, len(in.WorkerIds)) for i, w := range in.WorkerIds { - wr, err := s.GetWorkers(ctx, &pb.GetWorkersRequest{ + wr, err := s.GetWorkers(ctx, &api_pb.GetWorkersRequest{ StudyId: in.StudyId, WorkerId: w, }) if err != nil { - return &pb.GetMetricsReply{}, err + return &api_pb.GetMetricsReply{}, err } - mls[i] = &pb.MetricsLogSet{ + mls[i] = &api_pb.MetricsLogSet{ WorkerId: w, - MetricsLogs: make([]*pb.MetricsLog, len(mNames)), + MetricsLogs: make([]*api_pb.MetricsLog, len(mNames)), WorkerStatus: wr.Workers[0].Status, } for j, m := range mNames { ls, err := dbIf.GetWorkerLogs(w, &kdb.GetWorkerLogOpts{Name: m}) if err != nil { - return &pb.GetMetricsReply{}, err + return &api_pb.GetMetricsReply{}, err } - mls[i].MetricsLogs[j] = &pb.MetricsLog{ + mls[i].MetricsLogs[j] = &api_pb.MetricsLog{ Name: m, - Values: make([]*pb.MetricsValueTime, len(ls)), + Values: make([]*api_pb.MetricsValueTime, len(ls)), } for k, l := range ls { - mls[i].MetricsLogs[j].Values[k] = &pb.MetricsValueTime{ + mls[i].MetricsLogs[j].Values[k] = &api_pb.MetricsValueTime{ Value: l.Value, Time: l.Time.UTC().Format(time.RFC3339Nano), } } } } - return &pb.GetMetricsReply{MetricsLogSets: mls}, nil + return &api_pb.GetMetricsReply{MetricsLogSets: mls}, nil } -func (s *server) ReportMetricsLogs(ctx context.Context, in *pb.ReportMetricsLogsRequest) (*pb.ReportMetricsLogsReply, error) { +func (s *server) ReportMetricsLogs(ctx context.Context, in *api_pb.ReportMetricsLogsRequest) (*api_pb.ReportMetricsLogsReply, error) { for _, mls := range in.MetricsLogSets { err := dbIf.StoreWorkerLogs(mls.WorkerId, mls.MetricsLogs) if err != nil { - return &pb.ReportMetricsLogsReply{}, err + return &api_pb.ReportMetricsLogsReply{}, err } } - return &pb.ReportMetricsLogsReply{}, nil + return 
&api_pb.ReportMetricsLogsReply{}, nil } -func (s *server) UpdateWorkerState(ctx context.Context, in *pb.UpdateWorkerStateRequest) (*pb.UpdateWorkerStateReply, error) { +func (s *server) UpdateWorkerState(ctx context.Context, in *api_pb.UpdateWorkerStateRequest) (*api_pb.UpdateWorkerStateReply, error) { err := dbIf.UpdateWorker(in.WorkerId, in.Status) - return &pb.UpdateWorkerStateReply{}, err + return &api_pb.UpdateWorkerStateReply{}, err } -func (s *server) GetWorkerFullInfo(ctx context.Context, in *pb.GetWorkerFullInfoRequest) (*pb.GetWorkerFullInfoReply, error) { +func (s *server) GetWorkerFullInfo(ctx context.Context, in *api_pb.GetWorkerFullInfoRequest) (*api_pb.GetWorkerFullInfoReply, error) { return dbIf.GetWorkerFullInfo(in.StudyId, in.TrialId, in.WorkerId, in.OnlyLatestLog) } -func (s *server) SetSuggestionParameters(ctx context.Context, in *pb.SetSuggestionParametersRequest) (*pb.SetSuggestionParametersReply, error) { +func (s *server) SetSuggestionParameters(ctx context.Context, in *api_pb.SetSuggestionParametersRequest) (*api_pb.SetSuggestionParametersReply, error) { var err error var id string if in.ParamId == "" { @@ -205,10 +207,10 @@ func (s *server) SetSuggestionParameters(ctx context.Context, in *pb.SetSuggesti id = in.ParamId err = dbIf.UpdateSuggestionParam(in.ParamId, in.SuggestionParameters) } - return &pb.SetSuggestionParametersReply{ParamId: id}, err + return &api_pb.SetSuggestionParametersReply{ParamId: id}, err } -func (s *server) SetEarlyStoppingParameters(ctx context.Context, in *pb.SetEarlyStoppingParametersRequest) (*pb.SetEarlyStoppingParametersReply, error) { +func (s *server) SetEarlyStoppingParameters(ctx context.Context, in *api_pb.SetEarlyStoppingParametersRequest) (*api_pb.SetEarlyStoppingParametersReply, error) { var err error var id string if in.ParamId == "" { @@ -217,88 +219,115 @@ func (s *server) SetEarlyStoppingParameters(ctx context.Context, in *pb.SetEarly id = in.ParamId err = dbIf.UpdateEarlyStopParam(in.ParamId, 
in.EarlyStoppingParameters) } - return &pb.SetEarlyStoppingParametersReply{ParamId: id}, err + return &api_pb.SetEarlyStoppingParametersReply{ParamId: id}, err } -func (s *server) GetSuggestionParameters(ctx context.Context, in *pb.GetSuggestionParametersRequest) (*pb.GetSuggestionParametersReply, error) { +func (s *server) GetSuggestionParameters(ctx context.Context, in *api_pb.GetSuggestionParametersRequest) (*api_pb.GetSuggestionParametersReply, error) { ps, err := dbIf.GetSuggestionParam(in.ParamId) - return &pb.GetSuggestionParametersReply{SuggestionParameters: ps}, err + return &api_pb.GetSuggestionParametersReply{SuggestionParameters: ps}, err } -func (s *server) GetSuggestionParameterList(ctx context.Context, in *pb.GetSuggestionParameterListRequest) (*pb.GetSuggestionParameterListReply, error) { +func (s *server) GetSuggestionParameterList(ctx context.Context, in *api_pb.GetSuggestionParameterListRequest) (*api_pb.GetSuggestionParameterListReply, error) { pss, err := dbIf.GetSuggestionParamList(in.StudyId) - return &pb.GetSuggestionParameterListReply{SuggestionParameterSets: pss}, err + return &api_pb.GetSuggestionParameterListReply{SuggestionParameterSets: pss}, err } -func (s *server) GetEarlyStoppingParameters(ctx context.Context, in *pb.GetEarlyStoppingParametersRequest) (*pb.GetEarlyStoppingParametersReply, error) { +func (s *server) GetEarlyStoppingParameters(ctx context.Context, in *api_pb.GetEarlyStoppingParametersRequest) (*api_pb.GetEarlyStoppingParametersReply, error) { ps, err := dbIf.GetEarlyStopParam(in.ParamId) - return &pb.GetEarlyStoppingParametersReply{EarlyStoppingParameters: ps}, err + return &api_pb.GetEarlyStoppingParametersReply{EarlyStoppingParameters: ps}, err } -func (s *server) GetEarlyStoppingParameterList(ctx context.Context, in *pb.GetEarlyStoppingParameterListRequest) (*pb.GetEarlyStoppingParameterListReply, error) { +func (s *server) GetEarlyStoppingParameterList(ctx context.Context, in 
*api_pb.GetEarlyStoppingParameterListRequest) (*api_pb.GetEarlyStoppingParameterListReply, error) { pss, err := dbIf.GetEarlyStopParamList(in.StudyId) - return &pb.GetEarlyStoppingParameterListReply{EarlyStoppingParameterSets: pss}, err + return &api_pb.GetEarlyStoppingParameterListReply{EarlyStoppingParameterSets: pss}, err } -func (s *server) SaveStudy(ctx context.Context, in *pb.SaveStudyRequest) (*pb.SaveStudyReply, error) { +func (s *server) SaveStudy(ctx context.Context, in *api_pb.SaveStudyRequest) (*api_pb.SaveStudyReply, error) { var err error if s.msIf != nil { err = s.msIf.SaveStudy(in) } - return &pb.SaveStudyReply{}, err + return &api_pb.SaveStudyReply{}, err } -func (s *server) SaveModel(ctx context.Context, in *pb.SaveModelRequest) (*pb.SaveModelReply, error) { +func (s *server) SaveModel(ctx context.Context, in *api_pb.SaveModelRequest) (*api_pb.SaveModelReply, error) { if s.msIf != nil { err := s.msIf.SaveModel(in) if err != nil { log.Printf("Save Model failed %v", err) - return &pb.SaveModelReply{}, err + return &api_pb.SaveModelReply{}, err } } - return &pb.SaveModelReply{}, nil + return &api_pb.SaveModelReply{}, nil } -func (s *server) GetSavedStudies(ctx context.Context, in *pb.GetSavedStudiesRequest) (*pb.GetSavedStudiesReply, error) { - ret := []*pb.StudyOverview{} +func (s *server) GetSavedStudies(ctx context.Context, in *api_pb.GetSavedStudiesRequest) (*api_pb.GetSavedStudiesReply, error) { + ret := []*api_pb.StudyOverview{} var err error if s.msIf != nil { ret, err = s.msIf.GetSavedStudies() } - return &pb.GetSavedStudiesReply{Studies: ret}, err + return &api_pb.GetSavedStudiesReply{Studies: ret}, err } -func (s *server) GetSavedModels(ctx context.Context, in *pb.GetSavedModelsRequest) (*pb.GetSavedModelsReply, error) { - ret := []*pb.ModelInfo{} +func (s *server) GetSavedModels(ctx context.Context, in *api_pb.GetSavedModelsRequest) (*api_pb.GetSavedModelsReply, error) { + ret := []*api_pb.ModelInfo{} var err error if s.msIf != nil { ret, 
err = s.msIf.GetSavedModels(in) } - return &pb.GetSavedModelsReply{Models: ret}, err + return &api_pb.GetSavedModelsReply{Models: ret}, err } -func (s *server) GetSavedModel(ctx context.Context, in *pb.GetSavedModelRequest) (*pb.GetSavedModelReply, error) { - var ret *pb.ModelInfo = nil +func (s *server) GetSavedModel(ctx context.Context, in *api_pb.GetSavedModelRequest) (*api_pb.GetSavedModelReply, error) { + var ret *api_pb.ModelInfo = nil var err error if s.msIf != nil { ret, err = s.msIf.GetSavedModel(in) } - return &pb.GetSavedModelReply{Model: ret}, err + return &api_pb.GetSavedModelReply{Model: ret}, err +} +// Check is the health RPC: SERVING if dbIf.SelectOne succeeds, NOT_SERVING plus an error if it fails, UNKNOWN plus an error for any other service name. +func (s *server) Check(ctx context.Context, in *health_pb.HealthCheckRequest) (*health_pb.HealthCheckResponse, error) { + resp := health_pb.HealthCheckResponse{ + Status: health_pb.HealthCheckResponse_SERVING, + } + + // The optional service name is accepted only when it is empty or exactly "grpc.health.v1.Health". + if in != nil && in.Service != "" && in.Service != "grpc.health.v1.Health" { + resp.Status = health_pb.HealthCheckResponse_UNKNOWN + return &resp, fmt.Errorf("unsupported service name %q: only grpc.health.v1.Health (or an empty name) is accepted", in.Service) + } + + // Verify the connection to vizier-db, since without it the manager cannot serve most of its methods. 
+ err := dbIf.SelectOne() + if err != nil { + resp.Status = health_pb.HealthCheckResponse_NOT_SERVING + return &resp, fmt.Errorf("Failed to execute `SELECT 1` probe: %v", err) + } + + return &resp, nil } func main() { flag.Parse() var err error - dbIf = kdb.New() + + dbIf, err = kdb.New() + if err != nil { + log.Fatalf("Failed to open db connection: %v", err) + } dbIf.DBInit() listener, err := net.Listen("tcp", port) if err != nil { log.Fatalf("Failed to listen: %v", err) } + size := 1<<31 - 1 log.Printf("Start Katib manager: %s", port) s := grpc.NewServer(grpc.MaxRecvMsgSize(size), grpc.MaxSendMsgSize(size)) - pb.RegisterManagerServer(s, &server{}) + api_pb.RegisterManagerServer(s, &server{}) + health_pb.RegisterHealthServer(s, &server{}) reflection.Register(s) if err = s.Serve(listener); err != nil { log.Fatalf("Failed to serve: %v", err) diff --git a/manifests/vizier/core/deployment.yaml b/manifests/vizier/core/deployment.yaml index 1a50b8997b1..f42c4b84ed2 100644 --- a/manifests/vizier/core/deployment.yaml +++ b/manifests/vizier/core/deployment.yaml @@ -30,6 +30,14 @@ spec: ports: - name: api containerPort: 6789 + readinessProbe: + exec: + command: ["/bin/grpc_health_probe", "-addr=:6789"] + initialDelaySeconds: 5 + livenessProbe: + exec: + command: ["/bin/grpc_health_probe", "-addr=:6789"] + initialDelaySeconds: 10 # resources: # requests: # cpu: 500m diff --git a/pkg/api/Makefile b/pkg/api/Makefile index f63e1d5a8f4..8824c274827 100644 --- a/pkg/api/Makefile +++ b/pkg/api/Makefile @@ -1,2 +1,4 @@ api.pb.go: api.proto protoc -I. api.proto --go_out=plugins=grpc:. +health.pb.go: health/health.proto + protoc -I. health/health.proto --go_out=plugins=grpc:. diff --git a/pkg/api/build.sh b/pkg/api/build.sh index 380565cc4c2..bf0e8419ca4 100755 --- a/pkg/api/build.sh +++ b/pkg/api/build.sh @@ -1,4 +1,6 @@ -docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --python_out=plugins=grpc:./python --go_out=plugins=grpc:. -I. 
api.proto -docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --plugin=protoc-gen-grpc=/usr/bin/grpc_python_plugin --python_out=./python --grpc_out=./python -I. api.proto -docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --grpc-gateway_out=logtostderr=true:. -I. api.proto -docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --swagger_out=logtostderr=true:. -I. api.proto +for proto in api.proto health/health.proto; do + docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --python_out=plugins=grpc:./python --go_out=plugins=grpc:. -I. $proto + docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --plugin=protoc-gen-grpc=/usr/bin/grpc_python_plugin --python_out=./python --grpc_out=./python -I. $proto + docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --grpc-gateway_out=logtostderr=true:. -I. $proto + docker run -it --rm -v $PWD:$(pwd) -w $(pwd) znly/protoc --swagger_out=logtostderr=true:. -I. $proto +done diff --git a/pkg/api/health/health.pb.go b/pkg/api/health/health.pb.go new file mode 100644 index 00000000000..342a9f57f2c --- /dev/null +++ b/pkg/api/health/health.pb.go @@ -0,0 +1,189 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: health/health.proto + +/* +Package grpc_health_v1 is a generated protocol buffer package. + +It is generated from these files: + health/health.proto + +It has these top-level messages: + HealthCheckRequest + HealthCheckResponse +*/ +package grpc_health_v1 + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. 
+// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type HealthCheckResponse_ServingStatus int32 + +const ( + HealthCheckResponse_UNKNOWN HealthCheckResponse_ServingStatus = 0 + HealthCheckResponse_SERVING HealthCheckResponse_ServingStatus = 1 + HealthCheckResponse_NOT_SERVING HealthCheckResponse_ServingStatus = 2 +) + +var HealthCheckResponse_ServingStatus_name = map[int32]string{ + 0: "UNKNOWN", + 1: "SERVING", + 2: "NOT_SERVING", +} +var HealthCheckResponse_ServingStatus_value = map[string]int32{ + "UNKNOWN": 0, + "SERVING": 1, + "NOT_SERVING": 2, +} + +func (x HealthCheckResponse_ServingStatus) String() string { + return proto.EnumName(HealthCheckResponse_ServingStatus_name, int32(x)) +} +func (HealthCheckResponse_ServingStatus) EnumDescriptor() ([]byte, []int) { + return fileDescriptor0, []int{1, 0} +} + +type HealthCheckRequest struct { + Service string `protobuf:"bytes,1,opt,name=service" json:"service,omitempty"` +} + +func (m *HealthCheckRequest) Reset() { *m = HealthCheckRequest{} } +func (m *HealthCheckRequest) String() string { return proto.CompactTextString(m) } +func (*HealthCheckRequest) ProtoMessage() {} +func (*HealthCheckRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *HealthCheckRequest) GetService() string { + if m != nil { + return m.Service + } + return "" +} + +type HealthCheckResponse struct { + Status HealthCheckResponse_ServingStatus `protobuf:"varint,1,opt,name=status,enum=grpc.health.v1.HealthCheckResponse_ServingStatus" json:"status,omitempty"` +} + +func (m *HealthCheckResponse) Reset() { *m = HealthCheckResponse{} } +func (m *HealthCheckResponse) String() string { return proto.CompactTextString(m) } +func (*HealthCheckResponse) ProtoMessage() {} +func (*HealthCheckResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m 
*HealthCheckResponse) GetStatus() HealthCheckResponse_ServingStatus { + if m != nil { + return m.Status + } + return HealthCheckResponse_UNKNOWN +} + +func init() { + proto.RegisterType((*HealthCheckRequest)(nil), "grpc.health.v1.HealthCheckRequest") + proto.RegisterType((*HealthCheckResponse)(nil), "grpc.health.v1.HealthCheckResponse") + proto.RegisterEnum("grpc.health.v1.HealthCheckResponse_ServingStatus", HealthCheckResponse_ServingStatus_name, HealthCheckResponse_ServingStatus_value) +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Health service + +type HealthClient interface { + Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) +} + +type healthClient struct { + cc *grpc.ClientConn +} + +func NewHealthClient(cc *grpc.ClientConn) HealthClient { + return &healthClient{cc} +} + +func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) { + out := new(HealthCheckResponse) + err := grpc.Invoke(ctx, "/grpc.health.v1.Health/Check", in, out, c.cc, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Health service + +type HealthServer interface { + Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) +} + +func RegisterHealthServer(s *grpc.Server, srv HealthServer) { + s.RegisterService(&_Health_serviceDesc, srv) +} + +func _Health_Check_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthCheckRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HealthServer).Check(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/grpc.health.v1.Health/Check", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HealthServer).Check(ctx, req.(*HealthCheckRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Health_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.health.v1.Health", + HandlerType: (*HealthServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Check", + Handler: _Health_Check_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "health/health.proto", +} + +func init() { proto.RegisterFile("health/health.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 207 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0xce, 0x48, 0x4d, 0xcc, + 0x29, 0xc9, 0xd0, 0x87, 0x50, 0x7a, 0x05, 0x45, 0xf9, 0x25, 0xf9, 0x42, 0x7c, 0xe9, 0x45, 0x05, + 0xc9, 0x7a, 0x50, 0xa1, 0x32, 0x43, 0x25, 0x3d, 0x2e, 0x21, 0x0f, 0x30, 0xc7, 0x39, 0x23, 0x35, + 0x39, 0x3b, 0x28, 0xb5, 0xb0, 0x34, 0xb5, 0xb8, 0x44, 0x48, 0x82, 0x8b, 0xbd, 0x38, 0xb5, 0xa8, + 0x2c, 0x33, 0x39, 0x55, 0x82, 0x51, 0x81, 0x51, 0x83, 0x33, 0x08, 0xc6, 0x55, 0x9a, 0xc3, 0xc8, + 0x25, 0x8c, 0xa2, 0xa1, 0xb8, 0x20, 0x3f, 0xaf, 0x38, 0x55, 0xc8, 0x93, 0x8b, 0xad, 0xb8, 0x24, + 0xb1, 0xa4, 
0xb4, 0x18, 0xac, 0x81, 0xcf, 0xc8, 0x50, 0x0f, 0xd5, 0x22, 0x3d, 0x2c, 0x9a, 0xf4, + 0x82, 0x41, 0x86, 0xe6, 0xa5, 0x07, 0x83, 0x35, 0x06, 0x41, 0x0d, 0x50, 0xb2, 0xe2, 0xe2, 0x45, + 0x91, 0x10, 0xe2, 0xe6, 0x62, 0x0f, 0xf5, 0xf3, 0xf6, 0xf3, 0x0f, 0xf7, 0x13, 0x60, 0x00, 0x71, + 0x82, 0x5d, 0x83, 0xc2, 0x3c, 0xfd, 0xdc, 0x05, 0x18, 0x85, 0xf8, 0xb9, 0xb8, 0xfd, 0xfc, 0x43, + 0xe2, 0x61, 0x02, 0x4c, 0x46, 0x51, 0x5c, 0x6c, 0x10, 0x8b, 0x84, 0x02, 0xb8, 0x58, 0xc1, 0x96, + 0x09, 0x29, 0xe1, 0x75, 0x09, 0xd8, 0xbf, 0x52, 0xca, 0x44, 0xb8, 0x36, 0x89, 0x0d, 0x1c, 0x82, + 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0xb4, 0x1d, 0x77, 0xe9, 0x58, 0x01, 0x00, 0x00, +} diff --git a/pkg/api/health/health.proto b/pkg/api/health/health.proto new file mode 100644 index 00000000000..97d08baf9b5 --- /dev/null +++ b/pkg/api/health/health.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package grpc.health.v1; + +service Health { + rpc Check(HealthCheckRequest) returns (HealthCheckResponse); +} + +message HealthCheckRequest { + string service = 1; +} + +message HealthCheckResponse { + enum ServingStatus { + UNKNOWN = 0; + SERVING = 1; + NOT_SERVING = 2; + } + ServingStatus status = 1; +} diff --git a/pkg/api/health/health.swagger.json b/pkg/api/health/health.swagger.json new file mode 100644 index 00000000000..9619519fbe9 --- /dev/null +++ b/pkg/api/health/health.swagger.json @@ -0,0 +1,37 @@ +{ + "swagger": "2.0", + "info": { + "title": "health/health.proto", + "version": "version not set" + }, + "schemes": [ + "http", + "https" + ], + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": {}, + "definitions": { + "HealthCheckResponseServingStatus": { + "type": "string", + "enum": [ + "UNKNOWN", + "SERVING", + "NOT_SERVING" + ], + "default": "UNKNOWN" + }, + "v1HealthCheckResponse": { + "type": "object", + "properties": { + "status": { + "$ref": "#/definitions/HealthCheckResponseServingStatus" + } + } + } + } +} diff --git 
a/pkg/api/python/health/health_pb2.py b/pkg/api/python/health/health_pb2.py new file mode 100644 index 00000000000..be2ab95f31a --- /dev/null +++ b/pkg/api/python/health/health_pb2.py @@ -0,0 +1,280 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: health/health.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='health/health.proto', + package='grpc.health.v1', + syntax='proto3', + serialized_pb=_b('\n\x13health/health.proto\x12\x0egrpc.health.v1\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x94\x01\n\x13HealthCheckResponse\x12\x41\n\x06status\x18\x01 \x01(\x0e\x32\x31.grpc.health.v1.HealthCheckResponse.ServingStatus\":\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x32Z\n\x06Health\x12P\n\x05\x43heck\x12\".grpc.health.v1.HealthCheckRequest\x1a#.grpc.health.v1.HealthCheckResponseb\x06proto3') +) + + + +_HEALTHCHECKRESPONSE_SERVINGSTATUS = _descriptor.EnumDescriptor( + name='ServingStatus', + full_name='grpc.health.v1.HealthCheckResponse.ServingStatus', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='UNKNOWN', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='SERVING', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='NOT_SERVING', index=2, number=2, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=169, + serialized_end=227, +) 
+_sym_db.RegisterEnumDescriptor(_HEALTHCHECKRESPONSE_SERVINGSTATUS) + + +_HEALTHCHECKREQUEST = _descriptor.Descriptor( + name='HealthCheckRequest', + full_name='grpc.health.v1.HealthCheckRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='service', full_name='grpc.health.v1.HealthCheckRequest.service', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=39, + serialized_end=76, +) + + +_HEALTHCHECKRESPONSE = _descriptor.Descriptor( + name='HealthCheckResponse', + full_name='grpc.health.v1.HealthCheckResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='grpc.health.v1.HealthCheckResponse.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _HEALTHCHECKRESPONSE_SERVINGSTATUS, + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=79, + serialized_end=227, +) + +_HEALTHCHECKRESPONSE.fields_by_name['status'].enum_type = _HEALTHCHECKRESPONSE_SERVINGSTATUS +_HEALTHCHECKRESPONSE_SERVINGSTATUS.containing_type = _HEALTHCHECKRESPONSE +DESCRIPTOR.message_types_by_name['HealthCheckRequest'] = _HEALTHCHECKREQUEST +DESCRIPTOR.message_types_by_name['HealthCheckResponse'] = _HEALTHCHECKRESPONSE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +HealthCheckRequest = 
_reflection.GeneratedProtocolMessageType('HealthCheckRequest', (_message.Message,), dict( + DESCRIPTOR = _HEALTHCHECKREQUEST, + __module__ = 'health.health_pb2' + # @@protoc_insertion_point(class_scope:grpc.health.v1.HealthCheckRequest) + )) +_sym_db.RegisterMessage(HealthCheckRequest) + +HealthCheckResponse = _reflection.GeneratedProtocolMessageType('HealthCheckResponse', (_message.Message,), dict( + DESCRIPTOR = _HEALTHCHECKRESPONSE, + __module__ = 'health.health_pb2' + # @@protoc_insertion_point(class_scope:grpc.health.v1.HealthCheckResponse) + )) +_sym_db.RegisterMessage(HealthCheckResponse) + + + +_HEALTH = _descriptor.ServiceDescriptor( + name='Health', + full_name='grpc.health.v1.Health', + file=DESCRIPTOR, + index=0, + options=None, + serialized_start=229, + serialized_end=319, + methods=[ + _descriptor.MethodDescriptor( + name='Check', + full_name='grpc.health.v1.Health.Check', + index=0, + containing_service=None, + input_type=_HEALTHCHECKREQUEST, + output_type=_HEALTHCHECKRESPONSE, + options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_HEALTH) + +DESCRIPTOR.services_by_name['Health'] = _HEALTH + +try: + # THESE ELEMENTS WILL BE DEPRECATED. + # Please use the generated *_pb2_grpc.py files instead. + import grpc + from grpc.beta import implementations as beta_implementations + from grpc.beta import interfaces as beta_interfaces + from grpc.framework.common import cardinality + from grpc.framework.interfaces.face import utilities as face_utilities + + + class HealthStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Check = channel.unary_unary( + '/grpc.health.v1.Health/Check', + request_serializer=HealthCheckRequest.SerializeToString, + response_deserializer=HealthCheckResponse.FromString, + ) + + + class HealthServicer(object): + # missing associated documentation comment in .proto file + pass + + def Check(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + + def add_HealthServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Check': grpc.unary_unary_rpc_method_handler( + servicer.Check, + request_deserializer=HealthCheckRequest.FromString, + response_serializer=HealthCheckResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'grpc.health.v1.Health', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + class BetaHealthServicer(object): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. This class was generated + only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" + # missing associated documentation comment in .proto file + pass + def Check(self, request, context): + # missing associated documentation comment in .proto file + pass + context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + + + class BetaHealthStub(object): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. 
This class was generated + only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" + # missing associated documentation comment in .proto file + pass + def Check(self, request, timeout, metadata=None, with_call=False, protocol_options=None): + # missing associated documentation comment in .proto file + pass + raise NotImplementedError() + Check.future = None + + + def beta_create_Health_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. This function was + generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" + request_deserializers = { + ('grpc.health.v1.Health', 'Check'): HealthCheckRequest.FromString, + } + response_serializers = { + ('grpc.health.v1.Health', 'Check'): HealthCheckResponse.SerializeToString, + } + method_implementations = { + ('grpc.health.v1.Health', 'Check'): face_utilities.unary_unary_inline(servicer.Check), + } + server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout) + return beta_implementations.server(method_implementations, options=server_options) + + + def beta_create_Health_stub(channel, host=None, metadata_transformer=None, pool=None, pool_size=None): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. 
This function was + generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" + request_serializers = { + ('grpc.health.v1.Health', 'Check'): HealthCheckRequest.SerializeToString, + } + response_deserializers = { + ('grpc.health.v1.Health', 'Check'): HealthCheckResponse.FromString, + } + cardinalities = { + 'Check': cardinality.Cardinality.UNARY_UNARY, + } + stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size) + return beta_implementations.dynamic_stub(channel, 'grpc.health.v1.Health', cardinalities, options=stub_options) +except ImportError: + pass +# @@protoc_insertion_point(module_scope) diff --git a/pkg/api/python/health/health_pb2_grpc.py b/pkg/api/python/health/health_pb2_grpc.py new file mode 100644 index 00000000000..8ca4614e3e5 --- /dev/null +++ b/pkg/api/python/health/health_pb2_grpc.py @@ -0,0 +1,46 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +from health import health_pb2 as health_dot_health__pb2 + + +class HealthStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Check = channel.unary_unary( + '/grpc.health.v1.Health/Check', + request_serializer=health_dot_health__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=health_dot_health__pb2.HealthCheckResponse.FromString, + ) + + +class HealthServicer(object): + # missing associated documentation comment in .proto file + pass + + def Check(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_HealthServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Check': grpc.unary_unary_rpc_method_handler( + servicer.Check, + request_deserializer=health_dot_health__pb2.HealthCheckRequest.FromString, + response_serializer=health_dot_health__pb2.HealthCheckResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'grpc.health.v1.Health', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/pkg/db/db_init.go b/pkg/db/db_init.go index b8eac1bdfca..df8602fdbeb 100644 --- a/pkg/db/db_init.go +++ b/pkg/db/db_init.go @@ -1,6 +1,7 @@ package db import ( + "fmt" "log" ) @@ -94,3 +95,12 @@ func (d *dbConn) DBInit() { log.Fatalf("Error creating earlystop_param table: %v", err) } } + +func (d *dbConn) SelectOne() error { + db := d.db + _, err := db.Exec(`SELECT 1`) + if err != nil { + return fmt.Errorf("Error `SELECT 1` probing: %v", err) + } + return nil +} diff --git a/pkg/db/interface.go b/pkg/db/interface.go index 1ebdd99c076..743682eb909 100644 --- a/pkg/db/interface.go +++ b/pkg/db/interface.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "github.com/golang/protobuf/jsonpb" "log" "math/big" "math/rand" @@ -13,6 +12,8 @@ import ( "strings" "time" + "github.com/golang/protobuf/jsonpb" + api "github.com/kubeflow/katib/pkg/api" _ "github.com/go-sql-driver/mysql" @@ -20,7 
+21,7 @@ import ( const ( dbDriver = "mysql" - dbNameTmpl = "root:%s@tcp(vizier-db:3306)/vizier" + dbNameTmpl = "root:%s@tcp(vizier-db:3306)/vizier?timeout=5s" mysqlTimeFmt = "2006-01-02 15:04:05.999999" connectInterval = 5 * time.Second @@ -43,6 +44,8 @@ type WorkerLog struct { type VizierDBInterface interface { DBInit() + SelectOne() error + GetStudyConfig(string) (*api.StudyConfig, error) GetStudyList() ([]string, error) CreateStudy(*api.StudyConfig) (string, error) @@ -115,24 +118,24 @@ func openSQLConn(driverName string, dataSourceName string, interval time.Duratio } } -func NewWithSQLConn(db *sql.DB) VizierDBInterface { +func NewWithSQLConn(db *sql.DB) (VizierDBInterface, error) { d := new(dbConn) d.db = db seed, err := crand.Int(crand.Reader, big.NewInt(1<<63-1)) if err != nil { - log.Fatalf("RNG initialization failed: %v", err) + return nil, fmt.Errorf("RNG initialization failed: %v", err) } // We can do the following instead, but it creates a locking issue //d.rng = rand.New(rand.NewSource(seed.Int64())) rand.Seed(seed.Int64()) - return d + return d, nil } -func New() VizierDBInterface { +func New() (VizierDBInterface, error) { db, err := openSQLConn(dbDriver, getDbName(), connectInterval, connectTimeout) if err != nil { - log.Fatalf("DB open failed: %v", err) + return nil, fmt.Errorf("DB open failed: %v", err) } return NewWithSQLConn(db) } diff --git a/pkg/db/interface_test.go b/pkg/db/interface_test.go index 206ce2cbc5d..040ea6127fa 100644 --- a/pkg/db/interface_test.go +++ b/pkg/db/interface_test.go @@ -38,7 +38,10 @@ func TestMain(m *testing.M) { os.Exit(1) } //mock.ExpectBegin() - dbInterface = NewWithSQLConn(db) + dbInterface, err = NewWithSQLConn(db) + if err != nil { + fmt.Printf("error NewWithSQLConn: %v\n", err) + } mock.ExpectExec("CREATE TABLE IF NOT EXISTS studies").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("CREATE TABLE IF NOT EXISTS study_permissions").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1)) 
mock.ExpectExec("CREATE TABLE IF NOT EXISTS trials").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1)) @@ -48,16 +51,25 @@ func TestMain(m *testing.M) { mock.ExpectExec("CREATE TABLE IF NOT EXISTS suggestion_param").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("CREATE TABLE IF NOT EXISTS earlystop_param").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1)) dbInterface.DBInit() + err = dbInterface.SelectOne() + if err != nil { + fmt.Printf("error `SELECT 1` probing: %v\n", err) + } mysqlAddr := os.Getenv("TEST_MYSQL") if mysqlAddr != "" { mysql, err := sql.Open("mysql", "root:test123@tcp("+mysqlAddr+")/vizier") - if err != nil { fmt.Printf("error opening db: %v\n", err) os.Exit(1) } - mysqlInterface = NewWithSQLConn(mysql) + + mysqlInterface, err = NewWithSQLConn(mysql) + if err != nil { + fmt.Printf("error initializing db interface: %v\n", err) + os.Exit(1) + } + mysqlInterface.DBInit() } diff --git a/pkg/db/test/test.go b/pkg/db/test/test.go index 646eb0acf44..eec2d6dca6b 100644 --- a/pkg/db/test/test.go +++ b/pkg/db/test/test.go @@ -2,12 +2,16 @@ package main import ( "fmt" - "github.com/kubeflow/katib/pkg/db" "os" + + "github.com/kubeflow/katib/pkg/db" ) func main() { - dbInt := db.New() + dbInt, err := db.New() + if err != nil { + fmt.Printf("err: %v", err) + } study, err := dbInt.GetStudyConfig(os.Args[1]) if err != nil { fmt.Printf("err: %v", err) diff --git a/pkg/earlystopping/medianstopping.go b/pkg/earlystopping/medianstopping.go index a3afd7eac8c..5513087c3a0 100644 --- a/pkg/earlystopping/medianstopping.go +++ b/pkg/earlystopping/medianstopping.go @@ -3,11 +3,12 @@ package earlystopping import ( "context" "errors" - "github.com/kubeflow/katib/pkg/api" - vdb "github.com/kubeflow/katib/pkg/db" "log" "sort" "strconv" + + "github.com/kubeflow/katib/pkg/api" + vdb "github.com/kubeflow/katib/pkg/db" ) const ( @@ -28,8 +29,12 @@ type MedianStoppingRule struct { } func NewMedianStoppingRule() *MedianStoppingRule { + var err error m := 
&MedianStoppingRule{} - m.dbIf = vdb.New() + m.dbIf, err = vdb.New() + if err != nil { + log.Fatalf("Failed to open db connection: %v", err) + } return m } diff --git a/pkg/mock/db/db.go b/pkg/mock/db/db.go index b0b494cce83..bb681ed15c5 100644 --- a/pkg/mock/db/db.go +++ b/pkg/mock/db/db.go @@ -78,6 +78,12 @@ func (m *MockVizierDBInterface) DBInit() { m.ctrl.Call(m, "DBInit") } +func (m *MockVizierDBInterface) SelectOne() error { + ret := m.ctrl.Call(m, "SelectOne") + ret0, _ := ret[0].(error) + return ret0 +} + // DBInit indicates an expected call of DBInit func (mr *MockVizierDBInterfaceMockRecorder) DBInit() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DBInit", reflect.TypeOf((*MockVizierDBInterface)(nil).DBInit))