diff --git a/Dockerfile b/Dockerfile
index 8e7dcf7..786d04f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,33 +1,17 @@
# Build the manager binary
-FROM golang:1.19 as builder
-ARG TARGETOS
-ARG TARGETARCH
+FROM golang:1.20 as builder
WORKDIR /workspace
# Copy the Go Modules manifests
-COPY go.mod go.mod
-COPY go.sum go.sum
-# cache deps before building and copying source so that we don't need to re-download as much
-# and so that source changes don't invalidate our downloaded layer
-RUN go mod download
+COPY . .
+RUN go mod tidy
-# Copy the go source
-COPY main.go main.go
-COPY apis/ apis/
-COPY controllers/ controllers/
-
-# Build
-# the GOARCH has not a default value to allow the binary be built according to the host where the command
-# was called. For example, if we call make docker-build in a local env which has the Apple Silicon M1 SO
-# the docker BUILDPLATFORM arg will be linux/arm64 when for Apple x86 it will be linux/amd64. Therefore,
-# by leaving it empty we can ensure that the container and binary shipped on it will have the same platform.
-RUN CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o manager main.go
+RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o manager main.go
# Use distroless as minimal base image to package the manager binary
# Refer to https://github.com/GoogleContainerTools/distroless for more details
-FROM gcr.io/distroless/static:nonroot
+FROM alpine:3
WORKDIR /
COPY --from=builder /workspace/manager .
-USER 65532:65532
ENTRYPOINT ["/manager"]
diff --git a/cmd/controller-manager/app/controller_manager.go b/cmd/controller-manager/app/controller_manager.go
index 80d2126..9ab90ea 100644
--- a/cmd/controller-manager/app/controller_manager.go
+++ b/cmd/controller-manager/app/controller_manager.go
@@ -1,31 +1,41 @@
package app
import (
- "context"
+ "fmt"
"os"
- "github.com/DataTunerX/utility-server/logging"
-
"github.com/DataTunerX/finetune-experiment-controller/cmd/controller-manager/app/options"
"github.com/DataTunerX/finetune-experiment-controller/internal/controller/finetune"
+ "github.com/DataTunerX/finetune-experiment-controller/pkg/util"
corev1beta1 "github.com/DataTunerX/meta-server/api/core/v1beta1"
extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1"
finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1"
+ "github.com/DataTunerX/utility-server/logging"
"github.com/go-logr/zapr"
- "github.com/operator-framework/operator-lib/leader"
+ "github.com/open-policy-agent/cert-controller/pkg/rotator"
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
"github.com/spf13/pflag"
"k8s.io/apimachinery/pkg/runtime"
+ "k8s.io/apimachinery/pkg/types"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+ "k8s.io/client-go/kubernetes"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
_ "k8s.io/client-go/plugin/pkg/client/auth"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/healthz"
"sigs.k8s.io/controller-runtime/pkg/manager"
+ metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
+ "sigs.k8s.io/controller-runtime/pkg/webhook"
//+kubebuilder:scaffold:imports
)
-const LockName = "finetune-experiment-controller-lock"
+const (
+ LockName = "datatunerx-lock"
+ SecretName = "datatunerx-cert"
+ CaName = "datatunerx-ca"
+ CaOrganization = "datatunerx"
+ ServiceName = "finetune-experiment"
+)
var (
scheme = runtime.NewScheme()
@@ -41,38 +51,27 @@ func init() {
}
func NewControllerManager() (manager.Manager, error) {
- logging.ZLogger.Info("Start building controller manager")
opts := options.NewOptions()
flagSet := pflag.NewFlagSet("generic", pflag.ExitOnError)
opts.AddFlags(flagSet)
err := flagSet.Parse(os.Args[1:])
if err != nil {
logging.ZLogger.Errorf("Error parsing flags: %v", err)
- return nil, err
+ os.Exit(1)
}
logging.ZLogger.Info("Set logger for controller")
ctrl.SetLogger(zapr.NewLogger(logging.ZLogger.GetLogger()))
-
+ namespace := util.GetOperatorNamespace()
ctrOption := ctrl.Options{
- Scheme: scheme,
- MetricsBindAddress: opts.MetricsAddr,
- Port: 9443,
- HealthProbeBindAddress: opts.ProbeAddr,
- }
-
- if opts.LeaderElectLifeConfig.EnableLeaderLifeElect {
- err = leader.Become(context.TODO(), LockName)
- if err != nil {
- logging.ZLogger.Errorf("Failed to retry for leader lock: %v", err)
- return nil, err
- }
- } else {
- ctrOption.LeaderElection = false
- ctrOption.LeaderElectionID = LockName
- ctrOption.RetryPeriod = &opts.LeaderElectLeaseConfig.RetryPeriod
- ctrOption.RenewDeadline = &opts.LeaderElectLeaseConfig.RenewDeadline
- ctrOption.LeaseDuration = &opts.LeaderElectLeaseConfig.LeaseDuration
- ctrOption.LeaderElectionNamespace = opts.LeaderElectLeaseConfig.LeaderElectionNamespace
+ Scheme: scheme,
+ Metrics: metricsserver.Options{
+ BindAddress: opts.MetricsAddr,
+ },
+ WebhookServer: webhook.NewServer(webhook.Options{Port: 9443}),
+ HealthProbeBindAddress: opts.ProbeAddr,
+ LeaderElection: true,
+ LeaderElectionID: LockName,
+ LeaderElectionNamespace: namespace,
}
mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrOption)
@@ -80,6 +79,60 @@ func NewControllerManager() (manager.Manager, error) {
logging.ZLogger.Errorf("Build controller manager failed: %v", err)
return nil, err
}
+ setupFinished := make(chan struct{})
+ if opts.EnableCertRotator {
+ logging.ZLogger.Info("Setting up cert rotation")
+ if err := rotator.AddRotator(mgr, &rotator.CertRotator{
+ SecretKey: types.NamespacedName{
+ Namespace: namespace,
+ Name: SecretName,
+ },
+ CAName: CaName,
+ CAOrganization: CaOrganization,
+ CertDir: "/tmp/k8s-webhook-server/serving-certs",
+ DNSName: fmt.Sprintf("%s.%s.svc", ServiceName, namespace),
+ IsReady: setupFinished,
+ Webhooks: []rotator.WebhookInfo{
+ {
+ Name: namespace + "-validating-webhook-configuration",
+ Type: rotator.Validating,
+ },
+ {
+ Name: namespace + "-mutating-webhook-configuration",
+ Type: rotator.Mutating,
+ },
+ },
+ }); err != nil {
+ logging.ZLogger.Errorf("Unable to set up cert rotation, %v", err)
+ os.Exit(1)
+ }
+ } else {
+ close(setupFinished)
+ }
+ go func() {
+ <-setupFinished
+ if err := (&finetunev1beta1.FinetuneJob{}).SetupWebhookWithManager(mgr); err != nil {
+ logging.ZLogger.Errorf("Unable to create webhook, %v", err)
+ os.Exit(1)
+
+ }
+ if err := (&finetunev1beta1.FinetuneExperiment{}).SetupWebhookWithManager(mgr); err != nil {
+ logging.ZLogger.Errorf("Unable to create webhook, %v", err)
+ os.Exit(1)
+ }
+ if err := (&corev1beta1.LLM{}).SetupWebhookWithManager(mgr); err != nil {
+ logging.ZLogger.Errorf("Unable to create webhook, %v", err)
+ os.Exit(1)
+ }
+ if err := (&corev1beta1.Hyperparameter{}).SetupWebhookWithManager(mgr); err != nil {
+ logging.ZLogger.Errorf("Unable to create webhook, %v", err)
+ os.Exit(1)
+ }
+ if err := (&extensionv1beta1.Dataset{}).SetupWebhookWithManager(mgr); err != nil {
+ logging.ZLogger.Errorf("Unable to create webhook, %v", err)
+ os.Exit(1)
+ }
+ }()
if err = (&finetune.FinetuneExperimentReconciler{
Client: mgr.GetClient(),
@@ -97,6 +150,16 @@ func NewControllerManager() (manager.Manager, error) {
logging.ZLogger.Errorf("Unable to create FinetuneJob controller, %v", err)
return nil, err
}
+ if err = (&finetune.FinetuneReconciler{
+ Log: logging.ZLogger,
+ Client: mgr.GetClient(),
+ Scheme: mgr.GetScheme(),
+ Clientset: kubernetes.NewForConfigOrDie(ctrl.GetConfigOrDie()),
+ Config: ctrl.GetConfigOrDie(),
+ }).SetupWithManager(mgr); err != nil {
+ logging.ZLogger.Errorf("Unable to create Finetune controller, %v", err)
+ return nil, err
+ }
//+kubebuilder:scaffold:builder
if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil {
diff --git a/cmd/controller-manager/app/options/options.go b/cmd/controller-manager/app/options/options.go
index ced8145..a27daee 100644
--- a/cmd/controller-manager/app/options/options.go
+++ b/cmd/controller-manager/app/options/options.go
@@ -13,29 +13,24 @@ const (
defaultMetricsAddr = ":8080"
defaultProbeAddr = ":8081"
defaultNamespace = "datatunerx-dev"
+ defaultCertRotator = true
)
type Options struct {
LeaderElectLeaseConfig LeaderElectLeaseConfig
- LeaderElectLifeConfig LeaderElectLifeConfig
MetricsAddr string
ProbeAddr string
+ EnableCertRotator bool
}
type LeaderElectLeaseConfig struct {
- LeaseDuration time.Duration
- RenewDeadline time.Duration
- RetryPeriod time.Duration
- LeaderElectionNamespace string
-}
-
-type LeaderElectLifeConfig struct {
- EnableLeaderLifeElect bool
+ LeaseDuration time.Duration
+ RenewDeadline time.Duration
+ RetryPeriod time.Duration
}
func NewOptions() *Options {
return &Options{
- LeaderElectLifeConfig: LeaderElectLifeConfig{},
LeaderElectLeaseConfig: LeaderElectLeaseConfig{},
}
}
@@ -46,9 +41,8 @@ func (o *Options) AddFlags(fs *pflag.FlagSet) {
}
fs.StringVar(&o.MetricsAddr, "metrics-bind-address", defaultMetricsAddr, "The address the metric endpoint binds to.")
fs.StringVar(&o.ProbeAddr, "health-probe-bind-address", defaultProbeAddr, "The address the probe endpoint binds to.")
- fs.StringVar(&o.LeaderElectLeaseConfig.LeaderElectionNamespace, "leader-life-namespace", defaultNamespace, "LeaderElectionNamespace determines the namespace in which the leader.")
- fs.BoolVar(&o.LeaderElectLifeConfig.EnableLeaderLifeElect, "enable-leader-life", false, "Enable or disable leader election life.")
fs.DurationVar(&o.LeaderElectLeaseConfig.LeaseDuration, "lease-duration", defaultLeaseDuration, "The duration that non-leader candidates will wait after observing a leadership renewal until attempting to acquire leadership of a led but unrenewed group.")
fs.DurationVar(&o.LeaderElectLeaseConfig.RenewDeadline, "renew-deadline", defaultRenewDeadline, "Duration the clients should wait between attempting to renew the lease of the lock.")
fs.DurationVar(&o.LeaderElectLeaseConfig.RetryPeriod, "retry-period", defaultRetryPeriod, "The time duration for the client to wait between attempts of acquiring a lock.")
+ fs.BoolVar(&o.EnableCertRotator, "cert-rotator", defaultCertRotator, "Automatically apply for a certificate for Webhooks.")
}
diff --git a/go.mod b/go.mod
index 1d38b41..1bac8fe 100644
--- a/go.mod
+++ b/go.mod
@@ -1,63 +1,65 @@
module github.com/DataTunerX/finetune-experiment-controller
-go 1.19
+go 1.20
require (
- github.com/DataTunerX/meta-server v0.0.0-20231128065201-7109bd13c9cb
- github.com/DataTunerX/utility-server v0.0.0-20231107081331-e4ac0bbd2db2
- github.com/go-logr/zapr v1.2.3
- github.com/operator-framework/operator-lib v0.11.0
+ github.com/DataTunerX/meta-server v0.0.0-20231225093059-13cc8ff65bdc
+ github.com/DataTunerX/utility-server v0.0.0-20231208092112-6224f8619737
+ github.com/duke-git/lancet/v2 v2.2.8
+ github.com/go-logr/zapr v1.2.4
+ github.com/open-policy-agent/cert-controller v0.10.0
github.com/ray-project/kuberay/ray-operator v1.0.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.17.0
- k8s.io/api v0.26.0
- k8s.io/apimachinery v0.26.0
- k8s.io/client-go v0.26.0
- sigs.k8s.io/controller-runtime v0.14.1
+ k8s.io/api v0.28.1
+ k8s.io/apimachinery v0.28.1
+ k8s.io/client-go v0.28.1
+ sigs.k8s.io/controller-runtime v0.16.1
)
require (
github.com/beorn7/perks v1.0.1 // indirect
- github.com/cespare/xxhash/v2 v2.1.2 // indirect
+ github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.9.0 // indirect
github.com/evanphx/json-patch/v5 v5.6.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
- github.com/go-logr/logr v1.2.3 // indirect
- github.com/go-openapi/jsonpointer v0.19.5 // indirect
- github.com/go-openapi/jsonreference v0.20.0 // indirect
- github.com/go-openapi/swag v0.19.14 // indirect
+ github.com/go-logr/logr v1.2.4 // indirect
+ github.com/go-openapi/jsonpointer v0.19.6 // indirect
+ github.com/go-openapi/jsonreference v0.20.2 // indirect
+ github.com/go-openapi/swag v0.22.3 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
- github.com/google/gnostic v0.5.7-v3refs // indirect
+ github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.5.9 // indirect
- github.com/google/gofuzz v1.1.0 // indirect
+ github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/magiconair/properties v1.8.7 // indirect
- github.com/mailru/easyjson v0.7.6 // indirect
- github.com/matttproud/golang_protobuf_extensions v1.0.2 // indirect
+ github.com/mailru/easyjson v0.7.7 // indirect
+ github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
- github.com/prometheus/client_golang v1.14.0 // indirect
- github.com/prometheus/client_model v0.3.0 // indirect
- github.com/prometheus/common v0.37.0 // indirect
- github.com/prometheus/procfs v0.8.0 // indirect
+ github.com/prometheus/client_golang v1.16.0 // indirect
+ github.com/prometheus/client_model v0.4.0 // indirect
+ github.com/prometheus/common v0.44.0 // indirect
+ github.com/prometheus/procfs v0.10.1 // indirect
github.com/sagikazarmark/locafero v0.3.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.10.0 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
- go.uber.org/multierr v1.10.0 // indirect
+ go.uber.org/atomic v1.11.0 // indirect
+ go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.26.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.17.0 // indirect
@@ -66,19 +68,19 @@ require (
golang.org/x/term v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/time v0.3.0 // indirect
- gomodules.xyz/jsonpatch/v2 v2.2.0 // indirect
+ gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
- k8s.io/apiextensions-apiserver v0.26.0 // indirect
- k8s.io/component-base v0.26.0 // indirect
- k8s.io/klog/v2 v2.80.1 // indirect
- k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 // indirect
- k8s.io/utils v0.0.0-20221128185143-99ec85e7a448 // indirect
- sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 // indirect
+ k8s.io/apiextensions-apiserver v0.28.1 // indirect
+ k8s.io/component-base v0.28.1 // indirect
+ k8s.io/klog/v2 v2.100.1 // indirect
+ k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 // indirect
+ k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect
+ sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
)
diff --git a/go.sum b/go.sum
index b56df5e..34ba35f 100644
--- a/go.sum
+++ b/go.sum
@@ -38,24 +38,22 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
-github.com/DataTunerX/meta-server v0.0.0-20231128065201-7109bd13c9cb h1:ADOBX2XKCgG6cmTdYt4G0rt1pvDW6gVZHfrkNum8EQw=
-github.com/DataTunerX/meta-server v0.0.0-20231128065201-7109bd13c9cb/go.mod h1:MrA+U+PYANBfU8B43hrkJQ3WOIFPzUqowUO7s+KafvU=
-github.com/DataTunerX/utility-server v0.0.0-20231107081331-e4ac0bbd2db2 h1:3mBAWDqYrWtDk9xvIHDG/dN5zGcliwJnyvpWHFHcC+A=
-github.com/DataTunerX/utility-server v0.0.0-20231107081331-e4ac0bbd2db2/go.mod h1:qL3DYjQa7av0QkZoFrycHbpXHGQfBNEDke8uv+FdDn4=
-github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
-github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
-github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
-github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
-github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
+github.com/DataTunerX/meta-server v0.0.0-20231219035746-5584f4feca76 h1:rZ5BFKcF65Er5trKbxNy40XVHOFeC3g0tvMrQwweqvo=
+github.com/DataTunerX/meta-server v0.0.0-20231219035746-5584f4feca76/go.mod h1:ZegApA+ZAd5CNnWJ2YAOB876bGpnTxPDrpKL1Sa6yak=
+github.com/DataTunerX/meta-server v0.0.0-20231219100327-bc9f6700f0e6 h1:F647MVOf6G5lTeREDQBr0yv5Eg7+9w9YZINK3zsAwvs=
+github.com/DataTunerX/meta-server v0.0.0-20231219100327-bc9f6700f0e6/go.mod h1:ZegApA+ZAd5CNnWJ2YAOB876bGpnTxPDrpKL1Sa6yak=
+github.com/DataTunerX/meta-server v0.0.0-20231220083942-784fa5895015 h1:TSvZmvzdtt3RvFpvWiQqz07MfMhwfNkfb8uS7zNEyhk=
+github.com/DataTunerX/meta-server v0.0.0-20231220083942-784fa5895015/go.mod h1:ZegApA+ZAd5CNnWJ2YAOB876bGpnTxPDrpKL1Sa6yak=
+github.com/DataTunerX/meta-server v0.0.0-20231225093059-13cc8ff65bdc h1:94s49odKCAVLlokaSYBXltdonkrKKnWtHspzRezlY2A=
+github.com/DataTunerX/meta-server v0.0.0-20231225093059-13cc8ff65bdc/go.mod h1:ZegApA+ZAd5CNnWJ2YAOB876bGpnTxPDrpKL1Sa6yak=
+github.com/DataTunerX/utility-server v0.0.0-20231208092112-6224f8619737 h1:WYARNq3OZABZCWKtvApiKOg8/7CKYJImDynnOy6bfhs=
+github.com/DataTunerX/utility-server v0.0.0-20231208092112-6224f8619737/go.mod h1:6B12eUOsDcUZe/ayzODOOUG04JCJHgwXQOuXOa75whE=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
-github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
-github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
-github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
-github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
+github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
@@ -68,7 +66,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
+github.com/duke-git/lancet/v2 v2.2.8 h1:wlruXhliDe4zls1e2cYmz4qLc+WtcvrpcCnk1VJdEaA=
+github.com/duke-git/lancet/v2 v2.2.8/go.mod h1:zGa2R4xswg6EG9I6WnyubDbFO/+A/RROxIbXcwryTsc=
github.com/emicklei/go-restful/v3 v3.9.0 h1:XwGDlfxEnQZzuopoqxwSEllNcCOM9DhhFyhFIIGKwxE=
github.com/emicklei/go-restful/v3 v3.9.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -77,8 +76,7 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po=
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
-github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
-github.com/evanphx/json-patch v4.12.0+incompatible h1:4onqiflcdA9EOZ4RxV643DvftH5pOlLGNtQ5lPWQu84=
+github.com/evanphx/json-patch v5.6.0+incompatible h1:jBYDEEiFBPxA0v50tFdvOzQQTCvpL6mnFh5mB2/l16U=
github.com/evanphx/json-patch/v5 v5.6.0 h1:b91NhWfaz02IuVxO9faSllyAtNXHMPkC5J8sJCLunww=
github.com/evanphx/json-patch/v5 v5.6.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4=
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
@@ -87,30 +85,18 @@ github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbS
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
-github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
-github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
-github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
-github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
-github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
-github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
-github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
-github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
-github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
-github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
-github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
-github.com/go-logr/zapr v1.2.3 h1:a9vnzlIBPQBBkeaR9IuMUfmVOrQlkoC4YfPoFkX3T7A=
-github.com/go-logr/zapr v1.2.3/go.mod h1:eIauM6P8qSvTw5o2ez6UEAfGjQKrxQTl5EoK+Qa2oG4=
-github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
-github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
-github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
-github.com/go-openapi/jsonreference v0.20.0 h1:MYlu0sBgChmCfJxxUKZ8g1cPWFOB37YSZqewK7OKeyA=
-github.com/go-openapi/jsonreference v0.20.0/go.mod h1:Ag74Ico3lPc+zR+qjn4XBUmXymS4zJbYVCZmcgkasdo=
-github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
-github.com/go-openapi/swag v0.19.14 h1:gm3vOOXfiuw5i9p5N9xJvfjvuofpyvLA9Wr6QfK5Fng=
-github.com/go-openapi/swag v0.19.14/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
-github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
-github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
+github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
+github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
+github.com/go-logr/zapr v1.2.4 h1:QHVo+6stLbfJmYGkQ7uGHUCu5hnAFAj6mDe6Ea0SeOo=
+github.com/go-logr/zapr v1.2.4/go.mod h1:FyHWQIzQORZ0QVE1BtVHv3cKtNLuXsbNLtpuhNapBOA=
+github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE=
+github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs=
+github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE=
+github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k=
+github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g=
+github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
+github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@@ -141,13 +127,12 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
-github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/gnostic v0.5.7-v3refs h1:FhTMOKj2VhjpouxvWJAV1TL304uMlb9zcDqkl6cEI54=
-github.com/google/gnostic v0.5.7-v3refs/go.mod h1:73MKFl6jIHelAJNaBGFzt3SPtZULs9dYrGFt8OiIsHQ=
+github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I=
+github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
@@ -161,8 +146,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
-github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
+github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
@@ -176,6 +161,7 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
+github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@@ -194,24 +180,15 @@ github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
-github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
-github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
-github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
-github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
-github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
-github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
-github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
-github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
@@ -219,80 +196,50 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
-github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
-github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
-github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA=
-github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
-github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/matttproud/golang_protobuf_extensions v1.0.2 h1:hAHbPm5IJGijwng3PWk09JkG9WeqChjprR5s9bBZ+OM=
-github.com/matttproud/golang_protobuf_extensions v1.0.2/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
+github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
+github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
+github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
+github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
-github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
-github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
-github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
-github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
-github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
-github.com/onsi/ginkgo/v2 v2.6.0 h1:9t9b9vRUbFq3C4qKFCGkVuq/fIHji802N1nrtkh1mNc=
-github.com/onsi/gomega v1.24.1 h1:KORJXNNTzJXzu4ScJWssJfJMnJ+2QJqhoQSRwNlze9E=
-github.com/operator-framework/operator-lib v0.11.0 h1:eYzqpiOfq9WBI4Trddisiq/X9BwCisZd3rIzmHRC9Z8=
-github.com/operator-framework/operator-lib v0.11.0/go.mod h1:RpyKhFAoG6DmKTDIwMuO6pI3LRc8IE9rxEYWy476o6g=
+github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
+github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
+github.com/open-policy-agent/cert-controller v0.10.0 h1:9hBJsnpHsBqKR7VVtOHW19mk/a1vQvje6+QSJeRHuDg=
+github.com/open-policy-agent/cert-controller v0.10.0/go.mod h1:4uRbBLY5DsPOog+a9pqk3JLxuuhrWsbUedQW65HcLTI=
+github.com/open-policy-agent/frameworks/constraint v0.0.0-20230822235116-f0b62fe1e4c4 h1:5dum5SLEz+95JDLkMls7Z7IDPjvSq3UhJSFe4f5einQ=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
-github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
-github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
-github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
-github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
-github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
-github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw=
-github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y=
-github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
-github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8=
+github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4=
-github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
-github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
-github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
-github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls=
-github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE=
-github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA=
-github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
-github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
-github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
-github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
-github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
+github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY=
+github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
+github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
+github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
+github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg=
+github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/ray-project/kuberay/ray-operator v1.0.0 h1:i69nvbV7az2FG41VHQgxrmhD+SUl8ca+ek4RPbSE2Q0=
github.com/ray-project/kuberay/ray-operator v1.0.0/go.mod h1:7C7ebIkxtkmOX8w1iiLrKM1j4hkZs/Guzm3WdePk/yg=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
-github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
+github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/sagikazarmark/locafero v0.3.0 h1:zT7VEGWC2DTflmccN/5T1etyKvxSxpHsjb9cJvm4SvQ=
github.com/sagikazarmark/locafero v0.3.0/go.mod h1:w+v7UsPNFwzF1cHuOajOOzoq4U7v/ig1mpRjqV+Bu1U=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
-github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
-github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
-github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.10.0 h1:EaGW2JJh15aKOejeuJ+wpFSHnbd7GE6Wvp3TsNhb6LY=
@@ -303,19 +250,16 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.17.0 h1:I5txKw7MJasPL/BrfkbA0Jyo/oELqVmux4pR/UxOMfI=
github.com/spf13/viper v1.17.0/go.mod h1:BmMMMLQXSbcHK6KAOiFLz0l5JHrU89OdIRHvsk0+yVI=
-github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
-github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
@@ -324,6 +268,7 @@ github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
@@ -331,15 +276,16 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
-go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
-go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
+go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
+go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
+go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
+go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
-go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
-go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
-go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI=
+go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
+go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
-golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@@ -382,9 +328,9 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
@@ -392,7 +338,6 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
-golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -415,10 +360,8 @@ golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
-golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -430,8 +373,6 @@ golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ
golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc=
golang.org/x/oauth2 v0.12.0 h1:smVPGxink+n1ZI5pkQa8y6fZT0RW0MgCO5bFpepy4B4=
golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -444,13 +385,11 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -459,7 +398,6 @@ golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -472,8 +410,6 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -481,19 +417,16 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -526,7 +459,6 @@ golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgw
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
@@ -561,12 +493,14 @@ golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
+golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-gomodules.xyz/jsonpatch/v2 v2.2.0 h1:4pT439QV83L+G9FkcCriY6EkpcK6r6bK+A5FBUMI7qY=
-gomodules.xyz/jsonpatch/v2 v2.2.0/go.mod h1:WXp+iVDkoLQqPudfQ9GBlwB2eZ5DKOnjQZCYdOS8GPY=
+gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw=
+gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
@@ -624,7 +558,6 @@ google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6D
google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
@@ -661,29 +594,21 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
-gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
-gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
-gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
@@ -693,29 +618,30 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
-k8s.io/api v0.26.0 h1:IpPlZnxBpV1xl7TGk/X6lFtpgjgntCg8PJ+qrPHAC7I=
-k8s.io/api v0.26.0/go.mod h1:k6HDTaIFC8yn1i6pSClSqIwLABIcLV9l5Q4EcngKnQg=
-k8s.io/apiextensions-apiserver v0.26.0 h1:Gy93Xo1eg2ZIkNX/8vy5xviVSxwQulsnUdQ00nEdpDo=
-k8s.io/apiextensions-apiserver v0.26.0/go.mod h1:7ez0LTiyW5nq3vADtK6C3kMESxadD51Bh6uz3JOlqWQ=
-k8s.io/apimachinery v0.26.0 h1:1feANjElT7MvPqp0JT6F3Ss6TWDwmcjLypwoPpEf7zg=
-k8s.io/apimachinery v0.26.0/go.mod h1:tnPmbONNJ7ByJNz9+n9kMjNP8ON+1qoAIIC70lztu74=
-k8s.io/client-go v0.26.0 h1:lT1D3OfO+wIi9UFolCrifbjUUgu7CpLca0AD8ghRLI8=
-k8s.io/client-go v0.26.0/go.mod h1:I2Sh57A79EQsDmn7F7ASpmru1cceh3ocVT9KlX2jEZg=
-k8s.io/component-base v0.26.0 h1:0IkChOCohtDHttmKuz+EP3j3+qKmV55rM9gIFTXA7Vs=
-k8s.io/component-base v0.26.0/go.mod h1:lqHwlfV1/haa14F/Z5Zizk5QmzaVf23nQzCwVOQpfC8=
-k8s.io/klog/v2 v2.80.1 h1:atnLQ121W371wYYFawwYx1aEY2eUfs4l3J72wtgAwV4=
-k8s.io/klog/v2 v2.80.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0=
-k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 h1:+70TFaan3hfJzs+7VK2o+OGxg8HsuBr/5f6tVAjDu6E=
-k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280/go.mod h1:+Axhij7bCpeqhklhUTe3xmOn6bWxolyZEeyaFpjGtl4=
-k8s.io/utils v0.0.0-20221128185143-99ec85e7a448 h1:KTgPnR10d5zhztWptI952TNtt/4u5h3IzDXkdIMuo2Y=
-k8s.io/utils v0.0.0-20221128185143-99ec85e7a448/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
+k8s.io/api v0.28.1 h1:i+0O8k2NPBCPYaMB+uCkseEbawEt/eFaiRqUx8aB108=
+k8s.io/api v0.28.1/go.mod h1:uBYwID+66wiL28Kn2tBjBYQdEU0Xk0z5qF8bIBqk/Dg=
+k8s.io/apiextensions-apiserver v0.28.1 h1:l2ThkBRjrWpw4f24uq0Da2HaEgqJZ7pcgiEUTKSmQZw=
+k8s.io/apiextensions-apiserver v0.28.1/go.mod h1:sVvrI+P4vxh2YBBcm8n2ThjNyzU4BQGilCQ/JAY5kGs=
+k8s.io/apimachinery v0.28.1 h1:EJD40og3GizBSV3mkIoXQBsws32okPOy+MkRyzh6nPY=
+k8s.io/apimachinery v0.28.1/go.mod h1:X0xh/chESs2hP9koe+SdIAcXWcQ+RM5hy0ZynB+yEvw=
+k8s.io/client-go v0.28.1 h1:pRhMzB8HyLfVwpngWKE8hDcXRqifh1ga2Z/PU9SXVK8=
+k8s.io/client-go v0.28.1/go.mod h1:pEZA3FqOsVkCc07pFVzK076R+P/eXqsgx5zuuRWukNE=
+k8s.io/component-base v0.28.1 h1:LA4AujMlK2mr0tZbQDZkjWbdhTV5bRyEyAFe0TJxlWg=
+k8s.io/component-base v0.28.1/go.mod h1:jI11OyhbX21Qtbav7JkhehyBsIRfnO8oEgoAR12ArIU=
+k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg=
+k8s.io/klog/v2 v2.100.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0=
+k8s.io/kube-aggregator v0.28.1 h1:rvG4llYnQKHjj6YjjoBPEJxfD1uH0DJwkrJTNKGAaCs=
+k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 h1:LyMgNKD2P8Wn1iAwQU5OhxCKlKJy0sHc+PcDwFB24dQ=
+k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9/go.mod h1:wZK2AVp1uHCp4VamDVgBP2COHZjqD1T68Rf0CM3YjSM=
+k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 h1:qY1Ad8PODbnymg2pRbkyMT/ylpTrCM8P2RJ0yroCyIk=
+k8s.io/utils v0.0.0-20230406110748-d93618cff8a2/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
-sigs.k8s.io/controller-runtime v0.14.1 h1:vThDes9pzg0Y+UbCPY3Wj34CGIYPgdmspPm2GIpxpzM=
-sigs.k8s.io/controller-runtime v0.14.1/go.mod h1:GaRkrY8a7UZF0kqFFbUKG7n9ICiTY5T55P1RiE3UZlU=
-sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 h1:iXTIw73aPyC+oRdyqqvVJuloN1p0AC/kzH07hu3NE+k=
-sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
+sigs.k8s.io/controller-runtime v0.16.1 h1:+15lzrmHsE0s2kNl0Dl8cTchI5Cs8qofo5PGcPrV9z0=
+sigs.k8s.io/controller-runtime v0.16.1/go.mod h1:vpMu3LpI5sYWtujJOa2uPK61nB5rbwlN7BAB8aSLvGU=
+sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
+sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.2.3 h1:PRbqxJClWWYMNV1dhaG4NsibJbArud9kFxnAMREiWFE=
sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E=
sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo=
diff --git a/internal/controller/finetune/finetune_controller.go b/internal/controller/finetune/finetune_controller.go
new file mode 100644
index 0000000..756519b
--- /dev/null
+++ b/internal/controller/finetune/finetune_controller.go
@@ -0,0 +1,786 @@
+/*
+Copyright 2023.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package finetune
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+
+ corev1beta1 "github.com/DataTunerX/meta-server/api/core/v1beta1"
+ extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1"
+ finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1"
+ "github.com/DataTunerX/utility-server/logging"
+ rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
+ "github.com/ray-project/kuberay/ray-operator/controllers/ray/common"
+ batchv1 "k8s.io/api/batch/v1"
+ "k8s.io/api/core/v1"
+ apierrors "k8s.io/apimachinery/pkg/api/errors"
+ "k8s.io/apimachinery/pkg/api/resource"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/runtime"
+ "k8s.io/apimachinery/pkg/types"
+ "k8s.io/apimachinery/pkg/util/rand"
+ "k8s.io/client-go/kubernetes"
+ "k8s.io/client-go/kubernetes/scheme"
+ "k8s.io/client-go/rest"
+ "k8s.io/client-go/tools/remotecommand"
+ ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+ "sigs.k8s.io/controller-runtime/pkg/controller"
+ "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
+)
+
+const (
+	// RayVersion is the Ray runtime version requested for spawned RayClusters.
+	RayVersion = "v2.7.1"
+	// DefaultRequeueDuration is the base requeue delay used throughout Reconcile.
+	DefaultRequeueDuration = 3 * time.Second
+	// CheckpointPath is the file inside the Ray head pod from which the
+	// resulting checkpoint location is read after training finishes.
+	CheckpointPath = "/home/ray/checkpoint_path"
+)
+
+// Optional runtime configuration, read once from the environment at process
+// start. An empty METRICS_EXPORT_ADDRESS disables metrics export flags in the
+// generated entrypoint; storagePath is passed through as --storage_path.
+var metricsExportAddress = os.Getenv("METRICS_EXPORT_ADDRESS")
+var storagePath = os.Getenv("STORAGE_PATH")
+
+// FinetuneReconciler reconciles a Finetune object: it drives each Finetune
+// through RayJob creation, checkpoint extraction and LLMCheckpoint creation.
+type FinetuneReconciler struct {
+	client.Client
+	// Scheme is used to set owner references on created resources.
+	Scheme *runtime.Scheme
+	// Log is the structured logger shared by all reconcile paths.
+	Log logging.Logger
+	// Clientset provides the pod "exec" subresource (see fetchPodFile),
+	// which the controller-runtime client does not expose.
+	Clientset *kubernetes.Clientset
+	// Config is the REST config backing the SPDY executor for pod exec.
+	Config *rest.Config
+}
+
+//+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunes,verbs=get;list;watch;create;update;patch;delete
+//+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunes/status,verbs=get;update;patch
+//+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunes/finalizers,verbs=update
+
+// Reconcile drives a Finetune resource toward its desired state:
+//
+//  1. manage the finalizer and short-circuit on terminal (Successful/Failed)
+//     states;
+//  2. ensure a RayJob with the same namespaced name exists, creating it on
+//     the first pass and forcing ShutdownAfterJobFinishes so the Ray cluster
+//     is reaped after the job ends;
+//  3. once the Ray job finishes, record the submitter pod info, read the
+//     checkpoint path from the head pod, create the LLMCheckpoint, and mark
+//     the Finetune Successful (or Failed when the job stopped or failed).
+//
+// For more details, check Reconcile and its Result here:
+// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.16.1/pkg/reconcile
+func (r *FinetuneReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
+	r.Log.Infof("Reconciling Finetune: %+v", req.NamespacedName)
+	finetuneInstance := &finetunev1beta1.Finetune{}
+
+	err := r.Get(ctx, req.NamespacedName, finetuneInstance)
+
+	if err != nil {
+		if apierrors.IsNotFound(err) {
+			r.Log.Infof("Finetune: %+v not found. Ignoring since object must be deleted", req.NamespacedName)
+			return ctrl.Result{}, nil
+		}
+		// Error reading the object - requeue the request.
+		r.Log.Error("Failed to get Finetune")
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+	}
+
+	// Finalizer handling: attach it while the object is live; on deletion,
+	// detach it so the API server can finish removing the object.
+	// NOTE(review): these two Update calls use context.Background() instead
+	// of the reconcile ctx — confirm that is intentional.
+	if finetuneInstance.ObjectMeta.DeletionTimestamp.IsZero() {
+		if !controllerutil.ContainsFinalizer(finetuneInstance, finetunev1beta1.FinetuneGroupFinalizer) {
+			controllerutil.AddFinalizer(finetuneInstance, finetunev1beta1.FinetuneGroupFinalizer)
+			if err := r.Update(context.Background(), finetuneInstance); err != nil {
+				r.Log.Errorf("Failed to update Finetune with finalizer: %v", err)
+				return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+			}
+		}
+	} else {
+		r.Log.Infof("Finetune: %+v is being deleted", req.NamespacedName)
+		controllerutil.RemoveFinalizer(finetuneInstance, finetunev1beta1.FinetuneGroupFinalizer)
+		if err := r.Update(context.Background(), finetuneInstance); err != nil {
+			r.Log.Errorf("Failed to update Finetune without finalizer: %v", err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, nil
+	}
+
+	// Terminal states: nothing further to do.
+	if finetuneInstance.Status.State == finetunev1beta1.FinetuneSuccessful {
+		r.Log.Infof("Finetune: %+v is Successful.", req.NamespacedName)
+		return ctrl.Result{}, nil
+	}
+
+	if finetuneInstance.Status.State == finetunev1beta1.FinetuneFailed {
+		r.Log.Infof("Finetune: %+v is Failed.", req.NamespacedName)
+		return ctrl.Result{}, nil
+	}
+
+	// First reconcile of a fresh object: move it into the Init state.
+	if finetuneInstance.Status.State == "" {
+		if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneInit); err != nil {
+			r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+	}
+
+	// The RayJob shares the Finetune's namespaced name.
+	rayJobInstance := &rayv1.RayJob{}
+
+	err = r.Get(ctx, req.NamespacedName, rayJobInstance)
+	if err != nil {
+		if apierrors.IsNotFound(err) {
+			r.Log.Info("RayJob not found. Create a new one.")
+			err = r.createRayJob(ctx, finetuneInstance)
+			if err != nil {
+				// Creation failed (e.g. referenced Dataset/Hyperparameter
+				// missing): park the Finetune in Pending and retry.
+				r.Log.Errorf("Failed to create RayJob: %v", err)
+				if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetunePending); err != nil {
+					r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err)
+					return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+				}
+				return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+			}
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, nil
+		} else {
+			r.Log.Error("Failed to get RayJob")
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+	}
+
+	// The RayJob exists, so the fine-tune is running.
+	if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneRunning); err != nil {
+		r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err)
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+	}
+
+	// Force cluster shutdown after the job so Ray resources are reclaimed.
+	if rayJobInstance.Spec.ShutdownAfterJobFinishes == false {
+		rayJobInstance.Spec.ShutdownAfterJobFinishes = true
+		r.Log.Infof("RayJob %s/%s set ShutdownAfterJobFinishes true", rayJobInstance.Namespace, rayJobInstance.Name)
+		if err := r.Update(ctx, rayJobInstance); err != nil {
+			r.Log.Errorf("Failed to update RayJob %v: %v", req.NamespacedName, err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+	}
+
+	// check rayjob status
+	r.Log.Infof("RayJob %s/%s status is %s", rayJobInstance.Namespace, rayJobInstance.Name, rayJobInstance.Status.JobStatus)
+	if rayJobInstance.Status.JobStatus == "" {
+		// Status not yet reported; poll at a relaxed interval.
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration * 10}, nil
+	}
+
+	// Record the submitter pod info once, so job logs can be located later.
+	if finetuneInstance.Status.RayJobInfo == nil {
+		rajJobInfo, err := r.getRayJobPodInfo(ctx, rayJobInstance)
+		if err != nil {
+			r.Log.Errorf("getRayJobPodInfo err: %s", err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+		finetuneInstance.Status.RayJobInfo = rajJobInfo
+
+		if err = r.Status().Update(ctx, finetuneInstance); err != nil {
+			r.Log.Errorf("Failed to update Finetune status %v: %v", req.NamespacedName, err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+	}
+
+	if isJobPendingOrRunning(rayJobInstance.Status.JobStatus) {
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration * 10}, nil
+	}
+
+	if isJobStoppedOrFailed(rayJobInstance.Status.JobStatus) {
+		if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneFailed); err != nil {
+			r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, nil
+	}
+
+	// Job succeeded: read the checkpoint path out of the head pod once and
+	// persist it (with a generated LLMCheckpoint name) into the status.
+	if finetuneInstance.Status.LLMCheckpoint == nil {
+		headPod, err := r.getRayClusterHeadPod(ctx, rayJobInstance)
+		if err != nil {
+			r.Log.Errorf("getRayClusterHeadPod err: %s", err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+
+		checkpointPath, err := r.fetchPodFile(ctx, headPod, CheckpointPath)
+		if err != nil {
+			r.Log.Errorf("fetchPodFile err: %s", err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+
+		finetuneInstance.Status.LLMCheckpoint = &finetunev1beta1.Checkpoint{
+			LLMCheckpointRef: GenerateLLMCheckpointName(finetuneInstance.Name),
+			CheckpointPath:   checkpointPath,
+		}
+		if err = r.Status().Update(ctx, finetuneInstance); err != nil {
+			r.Log.Errorf("Failed to update Finetune status %v: %v", req.NamespacedName, err)
+			return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+		}
+
+	}
+
+	// Create the LLMCheckpoint resource referenced by the status.
+	if err := r.reconcileLLMCheckpoint(ctx, finetuneInstance); err != nil {
+		r.Log.Errorf("reconcileLLMCheckpoint err: %s", err)
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+	}
+
+	if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneSuccessful); err != nil {
+		r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err)
+		return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err
+	}
+
+	return ctrl.Result{}, nil
+}
+
+// getRayJobPodInfo resolves the name of the earliest-created submitter pod of
+// the Kubernetes Job that backs the RayJob (the job and RayJob share a name),
+// so its logs can be located later. Returns an error when the Job cannot be
+// fetched or has no pods yet.
+func (r *FinetuneReconciler) getRayJobPodInfo(ctx context.Context, rayJob *rayv1.RayJob) (*finetunev1beta1.RayJobInfo, error) {
+	// Verify the submitter Job exists before listing its pods.
+	job := &batchv1.Job{}
+	if err := r.Get(ctx, client.ObjectKey{Namespace: rayJob.Namespace, Name: rayJob.Name}, job); err != nil {
+		return nil, err
+	}
+
+	jobPods := &v1.PodList{}
+	err := r.List(ctx, jobPods, client.InNamespace(rayJob.Namespace), client.MatchingLabels{"batch.kubernetes.io/job-name": rayJob.Name})
+	if err != nil {
+		return nil, err
+	}
+	if len(jobPods.Items) == 0 {
+		return nil, fmt.Errorf("RayJob: %s/%s has no pod", rayJob.Namespace, rayJob.Name)
+	}
+
+	// Pick the pod with the earliest creation timestamp. The timestamp must
+	// be tracked by value: the original code stored a pointer into the loop
+	// variable, which (pre Go 1.22) aliases every iteration, so the
+	// comparison always saw the current pod's own timestamp and the "first"
+	// pod was simply whichever came first in the list.
+	var firstPod string
+	var podStartTime *metav1.Time
+	for i := range jobPods.Items {
+		pod := &jobPods.Items[i]
+		if podStartTime == nil || podStartTime.Time.After(pod.CreationTimestamp.Time) {
+			firstPod = pod.Name
+			ts := pod.CreationTimestamp
+			podStartTime = &ts
+		}
+	}
+
+	return &finetunev1beta1.RayJobInfo{RayJobPodName: firstPod, RayJobPodContainerName: "ray-job-submitter"}, nil
+}
+
+// getRayClusterHeadPod returns the head-node pod of the RayCluster backing
+// the given RayJob, or an error when no head pod exists yet.
+func (r *FinetuneReconciler) getRayClusterHeadPod(ctx context.Context, rayJob *rayv1.RayJob) (*v1.Pod, error) {
+	pods := &v1.PodList{}
+	selector := client.MatchingLabels{
+		common.RayClusterLabelKey:  rayJob.Status.RayClusterName,
+		common.RayNodeTypeLabelKey: string(rayv1.HeadNode),
+	}
+	if err := r.List(ctx, pods, client.InNamespace(rayJob.Namespace), selector); err != nil {
+		return nil, err
+	}
+	if len(pods.Items) == 0 {
+		return nil, fmt.Errorf("RayCluster: %s/%s has no head node", rayJob.Namespace, rayJob.Status.RayClusterName)
+	}
+	return &pods.Items[0], nil
+}
+
+// fetchPodFile reads filePath from the first container of the given pod by
+// exec'ing `cat` inside it over SPDY and returns the captured stdout.
+func (r *FinetuneReconciler) fetchPodFile(ctx context.Context, pod *v1.Pod, filePath string) (string, error) {
+	req := r.Clientset.CoreV1().RESTClient().Post().
+		Resource("pods").
+		Name(pod.Name).
+		Namespace(pod.Namespace).
+		SubResource("exec").
+		VersionedParams(&v1.PodExecOptions{
+			Container: pod.Spec.Containers[0].Name,
+			Command:   []string{"cat", filePath},
+			Stdout:    true,
+			Stderr:    true,
+		}, scheme.ParameterCodec)
+
+	executor, err := remotecommand.NewSPDYExecutor(r.Config, "POST", req.URL())
+	if err != nil {
+		return "", err
+	}
+
+	var outBuf, errBuf bytes.Buffer
+	if err := executor.StreamWithContext(ctx, remotecommand.StreamOptions{
+		Stdout: &outBuf,
+		Stderr: &errBuf,
+	}); err != nil {
+		return "", err
+	}
+	return outBuf.String(), nil
+}
+
+// reconcileLLMCheckpoint ensures the LLMCheckpoint referenced by the Finetune
+// status exists, creating it from the current Dataset, Hyperparameter and LLM
+// resources when it is missing. Non-NotFound lookup errors are returned as-is.
+func (r *FinetuneReconciler) reconcileLLMCheckpoint(ctx context.Context, finetune *finetunev1beta1.Finetune) error {
+	ns := finetune.Namespace
+	checkpointNamespacedName := types.NamespacedName{
+		Namespace: ns,
+		Name:      finetune.Status.LLMCheckpoint.LLMCheckpointRef,
+	}
+
+	checkpoint := &corev1beta1.LLMCheckpoint{}
+	err := r.Get(ctx, checkpointNamespacedName, checkpoint)
+	if err == nil {
+		// The checkpoint already exists; nothing to do.
+		return nil
+	}
+	if !apierrors.IsNotFound(err) {
+		return err
+	}
+
+	r.Log.Infof("LLMCheckpoint: %s/%s is not found", ns, finetune.Name)
+
+	// Resolve the resources the checkpoint snapshots. Keyed struct literals
+	// keep `go vet` clean (the original used unkeyed NamespacedName fields).
+	datasetInstance, err := r.getDataset(ctx, types.NamespacedName{Namespace: ns, Name: finetune.Spec.Dataset})
+	if err != nil {
+		return err
+	}
+
+	hyperparameterInstance, err := r.getHyperparameter(ctx, types.NamespacedName{Namespace: ns, Name: finetune.Spec.Hyperparameter.HyperparameterRef})
+	if err != nil {
+		return err
+	}
+
+	llmInstance, err := r.getLLM(ctx, types.NamespacedName{Namespace: ns, Name: finetune.Spec.LLM})
+	if err != nil {
+		return err
+	}
+
+	llmCheckpointInstance, err := generateLLMCheckpoint(checkpointNamespacedName, finetune, datasetInstance, hyperparameterInstance, llmInstance)
+	if err != nil {
+		return err
+	}
+
+	// NOTE(review): no controller reference is set on the LLMCheckpoint (the
+	// original code had it commented out), so it survives Finetune deletion —
+	// confirm that is the intended lifecycle.
+	return r.Create(ctx, llmCheckpointInstance)
+}
+
+// getLLM fetches the LLM resource identified by namespacedName.
+func (r *FinetuneReconciler) getLLM(ctx context.Context, namespacedName types.NamespacedName) (*corev1beta1.LLM, error) {
+	instance := &corev1beta1.LLM{}
+	if err := r.Get(ctx, namespacedName, instance); err != nil {
+		r.Log.Errorf("Failed to get LLM %v", namespacedName)
+		return nil, err
+	}
+	return instance, nil
+}
+
+// getDataset fetches the Dataset resource identified by namespacedName.
+func (r *FinetuneReconciler) getDataset(ctx context.Context, namespacedName types.NamespacedName) (*extensionv1beta1.Dataset, error) {
+	instance := &extensionv1beta1.Dataset{}
+	if err := r.Get(ctx, namespacedName, instance); err != nil {
+		r.Log.Errorf("Failed to get Dataset %v", namespacedName)
+		return nil, err
+	}
+	return instance, nil
+}
+
+// getHyperparameter fetches the Hyperparameter resource identified by
+// namespacedName.
+func (r *FinetuneReconciler) getHyperparameter(ctx context.Context, namespacedName types.NamespacedName) (*corev1beta1.Hyperparameter, error) {
+	instance := &corev1beta1.Hyperparameter{}
+	if err := r.Get(ctx, namespacedName, instance); err != nil {
+		r.Log.Errorf("Failed to get Hyperparameter %v", namespacedName)
+		return nil, err
+	}
+	return instance, nil
+}
+
+// createRayJob materializes a RayJob for the given Finetune. The Dataset and
+// Hyperparameter referenced by the spec are resolved first; when either
+// lookup fails the Finetune is parked in the Pending state (via
+// updateFinetuneState, for consistency with the rest of the file) and the
+// lookup error is returned so the caller requeues.
+func (r *FinetuneReconciler) createRayJob(ctx context.Context, finetune *finetunev1beta1.Finetune) error {
+	ns := finetune.Namespace
+
+	// Keyed composite literals keep `go vet` clean (the original used
+	// unkeyed types.NamespacedName literals).
+	datasetInstance, err := r.getDataset(ctx, types.NamespacedName{Namespace: ns, Name: finetune.Spec.Dataset})
+	if err != nil {
+		if updateErr := r.updateFinetuneState(ctx, finetune, finetunev1beta1.FinetunePending); updateErr != nil {
+			return updateErr
+		}
+		return err
+	}
+
+	hyperparameterInstance, err := r.getHyperparameter(ctx, types.NamespacedName{Namespace: ns, Name: finetune.Spec.Hyperparameter.HyperparameterRef})
+	if err != nil {
+		if updateErr := r.updateFinetuneState(ctx, finetune, finetunev1beta1.FinetunePending); updateErr != nil {
+			return updateErr
+		}
+		return err
+	}
+
+	// Merge spec-level overrides onto the Hyperparameter defaults.
+	newParameters := updateHyperparameters(&hyperparameterInstance.Spec.Parameters, finetune.Spec.Hyperparameter.Overrides)
+	r.Log.Debugf("newParameters: %+v", newParameters)
+	rayJobEntrypoint, err := getRayJobEntrypoint(ctx, finetune, datasetInstance, newParameters)
+	if err != nil {
+		return err
+	}
+
+	r.Log.Info("create ray cluster")
+	rayJobInstance, err := generateRayJob(ctx, &types.NamespacedName{Namespace: ns, Name: finetune.Name}, rayJobEntrypoint, finetune)
+	if err != nil {
+		return err
+	}
+
+	// Own the RayJob so it is garbage-collected with the Finetune and its
+	// events re-trigger reconciliation (see SetupWithManager's Owns).
+	if err := controllerutil.SetControllerReference(finetune, rayJobInstance, r.Scheme); err != nil {
+		return err
+	}
+
+	return r.Create(ctx, rayJobInstance)
+}
+
+// updateFinetuneState transitions the Finetune instance to finetuneState and
+// persists the change through the status subresource.
+//
+// Parameters:
+//	ctx: controls cancellation of the status update call.
+//	instance: the Finetune whose state is updated in place.
+//	finetuneState: the target state.
+//
+// Returns the error from the status update, or nil (including when the
+// instance is already in the target state, in which case no API call is made).
+func (r *FinetuneReconciler) updateFinetuneState(ctx context.Context, instance *finetunev1beta1.Finetune, finetuneState finetunev1beta1.FinetuneState) error {
+	if instance.Status.State == finetuneState {
+		// Already in the desired state; skip the no-op API round trip.
+		return nil
+	}
+	instance.Status.State = finetuneState
+	r.Log.Infof("Update Finetune CR Status.State: %s", finetuneState)
+	return r.Status().Update(ctx, instance)
+}
+
+// getRayJobEntrypoint assembles the training command line executed by the Ray
+// job submitter: `python /tuning/train.py` plus flags derived from the
+// Finetune spec, the Dataset split files, and the merged hyperparameters.
+// It fails when the model path is missing from the Finetune image spec.
+func getRayJobEntrypoint(ctx context.Context, finetune *finetunev1beta1.Finetune, dataset *extensionv1beta1.Dataset, parameters *corev1beta1.Parameters) (string, error) {
+	// TODO check parameters include blank
+	// At least one worker is always requested.
+	replicas := int32(finetune.Spec.Node)
+	if replicas <= 0 {
+		replicas = 1
+	}
+	entrypoint := []string{"python"}
+	entrypoint = append(entrypoint, "/tuning/train.py")
+
+	finetunePath := finetune.Spec.Image.Path
+	if finetunePath == "" {
+		return "", fmt.Errorf("%s/%s: finetune.Spec.Image.Path is required", finetune.Namespace, finetune.Name)
+	}
+	entrypoint = append(entrypoint, "--model_name_or_path", finetunePath)
+
+	entrypoint = append(entrypoint, "--train_path", dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Train.File)
+
+	// The validation split is optional.
+	if dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Validate != nil && dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Validate.File != "" {
+		entrypoint = append(entrypoint, "--evaluation_path", dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Validate.File)
+	}
+
+	featuresMapJson, err := getFeaturesMapJson(dataset.Spec.DatasetMetadata.DatasetInfo.Features)
+	if err != nil {
+		return "", err
+	}
+	if featuresMapJson != "" {
+		entrypoint = append(entrypoint, "--columns", strconv.Quote(featuresMapJson))
+	}
+
+	entrypoint = append(entrypoint, "--output_dir", "result")
+	entrypoint = append(entrypoint, "--deepspeed", "/tuning/ds_config.json")
+	entrypoint = append(entrypoint, "--lora_target", "q_proj,v_proj")
+	entrypoint = append(entrypoint, "--lr_scheduler_type", string(parameters.Scheduler))
+	entrypoint = append(entrypoint, "--optim", string(parameters.Optimizer))
+
+	// int8 takes precedence over int4 when both are set.
+	quantization := ""
+	if parameters.Int8 {
+		quantization = "int8"
+	} else if parameters.Int4 {
+		quantization = "int4"
+	}
+	if quantization != "" {
+		entrypoint = append(entrypoint, "--quantization", quantization)
+	}
+
+	entrypoint = append(entrypoint, "--lora_r", parameters.LoRA_R)
+	entrypoint = append(entrypoint, "--lora_alpha", parameters.LoRA_Alpha)
+	entrypoint = append(entrypoint, "--lora_dropout", parameters.LoRA_Dropout)
+	entrypoint = append(entrypoint, "--learning_rate", parameters.LearningRate)
+	entrypoint = append(entrypoint, "--num_train_epochs", fmt.Sprintf("%d", parameters.Epochs))
+	entrypoint = append(entrypoint, "--block_size", fmt.Sprintf("%d", parameters.BlockSize))
+	// Fix: the original flag literal carried a trailing space
+	// ("--per_device_train_batch_size "), producing a doubled space in the
+	// joined command line.
+	entrypoint = append(entrypoint, "--per_device_train_batch_size", fmt.Sprintf("%d", parameters.BatchSize))
+	entrypoint = append(entrypoint, "--warmup_ratio", parameters.WarmupRatio)
+	entrypoint = append(entrypoint, "--weight_decay", parameters.WeightDecay)
+	entrypoint = append(entrypoint, "--gradient_accumulation_steps", fmt.Sprintf("%d", parameters.GradAccSteps))
+	entrypoint = append(entrypoint, "--fp16", fmt.Sprintf("%t", parameters.FP16))
+	entrypoint = append(entrypoint, "--num_workers", fmt.Sprintf("%d", replicas))
+	entrypoint = append(entrypoint, "--storage_path", storagePath)
+
+	if metricsExportAddress != "" {
+		entrypoint = append(entrypoint, "--metrics_export_address", metricsExportAddress)
+		// The UID is already a string type; a plain conversion replaces the
+		// needless fmt.Sprintf("%s", ...) round trip.
+		entrypoint = append(entrypoint, "--uid", string(finetune.UID))
+	}
+	return strings.Join(entrypoint, " "), nil
+}
+
+// generateRayJob builds the RayJob object (head group plus one GPU worker
+// group of finetune.Spec.Node replicas, minimum 1) that runs the supplied
+// entrypoint. It fails when the Finetune spec carries no container image.
+// NOTE(review): ctx is currently unused — kept for signature symmetry with
+// the other generators; confirm before removing.
+func generateRayJob(ctx context.Context, namespacedName *types.NamespacedName, entrypoint string, finetune *finetunev1beta1.Finetune) (*rayv1.RayJob, error) {
+	replicas := int32(finetune.Spec.Node)
+	if replicas <= 0 {
+		replicas = 1
+	}
+	if finetune.Spec.Image.Name == "" {
+		return nil, fmt.Errorf("%s/%s: finetune.Spec.Image.Name is required", finetune.Namespace, finetune.Name)
+	}
+
+	var rayJobInstance = &rayv1.RayJob{
+		TypeMeta: metav1.TypeMeta{
+			Kind:       "RayJob",
+			APIVersion: "ray.io/v1",
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      namespacedName.Name,
+			Namespace: namespacedName.Namespace,
+		},
+		Spec: rayv1.RayJobSpec{
+			// Shutdown is not set here; Reconcile flips
+			// ShutdownAfterJobFinishes to true after creation.
+			//ShutdownAfterJobFinishes: true,
+			Entrypoint: entrypoint,
+			RayClusterSpec: &rayv1.RayClusterSpec{
+				RayVersion: RayVersion,
+				// Head group: single node, no GPUs, exposes GCS, dashboard
+				// and client ports via a NodePort service.
+				HeadGroupSpec: rayv1.HeadGroupSpec{
+					ServiceType:   "NodePort",
+					HeadService:   nil,
+					EnableIngress: nil,
+					Replicas:      nil,
+					RayStartParams: map[string]string{
+						"dashboard-host": "0.0.0.0",
+						"num-gpus":       "0",
+					},
+					Template: v1.PodTemplateSpec{
+						Spec: v1.PodSpec{
+							Containers: []v1.Container{
+								v1.Container{
+									Name:            "ray-head",
+									Image:           finetune.Spec.Image.Name,
+									ImagePullPolicy: finetune.Spec.Image.ImagePullPolicy,
+									Ports: []v1.ContainerPort{
+										v1.ContainerPort{
+											Name:          "gcs-server",
+											ContainerPort: 6379,
+										},
+										v1.ContainerPort{
+											Name:          "dashboard",
+											ContainerPort: 8265,
+										},
+										v1.ContainerPort{
+											Name:          "client",
+											ContainerPort: 10001,
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+				// Worker group: one GPU requested and limited per replica;
+				// `ray stop` runs as a preStop hook for clean shutdown.
+				WorkerGroupSpecs: []rayv1.WorkerGroupSpec{rayv1.WorkerGroupSpec{
+					GroupName:      "finetune-group",
+					Replicas:       &replicas,
+					RayStartParams: map[string]string{},
+					Template: v1.PodTemplateSpec{
+						Spec: v1.PodSpec{
+							Containers: []v1.Container{
+								v1.Container{
+									Name:            "ray-worker",
+									Image:           finetune.Spec.Image.Name,
+									ImagePullPolicy: finetune.Spec.Image.ImagePullPolicy,
+									Lifecycle: &v1.Lifecycle{
+										PreStop: &v1.LifecycleHandler{
+											Exec: &v1.ExecAction{
+												Command: []string{
+													"/bin/sh", "-c", "ray stop",
+												},
+											},
+										},
+									},
+									Resources: v1.ResourceRequirements{
+										Requests: v1.ResourceList{
+											"nvidia.com/gpu": resource.MustParse("1"),
+										},
+										Limits: v1.ResourceList{
+											"nvidia.com/gpu": resource.MustParse("1"),
+										},
+									},
+								},
+							},
+						},
+					},
+					ScaleStrategy: rayv1.ScaleStrategy{},
+				}},
+				EnableInTreeAutoscaling: nil,
+				AutoscalerOptions:       nil,
+				HeadServiceAnnotations:  nil,
+			},
+			ClusterSelector: nil,
+			Suspend:         false,
+		},
+	}
+	return rayJobInstance, nil
+}
+
+// generateLLMCheckpoint builds an LLMCheckpoint resource that snapshots the
+// LLM, Dataset and Hyperparameter specs used by a finished Finetune run,
+// together with the checkpoint path recorded in its status. It fails when the
+// Finetune status does not carry a checkpoint path yet.
+func generateLLMCheckpoint(checkpointNamespacedName types.NamespacedName, finetune *finetunev1beta1.Finetune, dataset *extensionv1beta1.Dataset, hyperparameter *corev1beta1.Hyperparameter, llm *corev1beta1.LLM) (*corev1beta1.LLMCheckpoint, error) {
+	if finetune.Status.LLMCheckpoint == nil || finetune.Status.LLMCheckpoint.CheckpointPath == "" {
+		// Fix: the original message ("CheckpointPath is nil") was capitalized
+		// and inaccurate — the path is missing or empty, not nil — and named
+		// no object. Match the file's "%s/%s: ..." error style.
+		return nil, fmt.Errorf("%s/%s: finetune.Status.LLMCheckpoint.CheckpointPath is empty", finetune.Namespace, finetune.Name)
+	}
+
+	var llmCheckpointInstance = &corev1beta1.LLMCheckpoint{
+		TypeMeta: metav1.TypeMeta{
+			Kind:       "LLMCheckpoint",
+			APIVersion: "core.datatunerx.io/v1beta1",
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      checkpointNamespacedName.Name,
+			Namespace: checkpointNamespacedName.Namespace,
+		},
+		Spec: corev1beta1.LLMCheckpointSpec{
+			// DeepCopy the referenced specs so the checkpoint is immune to
+			// later edits of the source resources.
+			LLM: &corev1beta1.LLMRefSpec{
+				LLMRef: finetune.Spec.LLM,
+				Spec:   llm.Spec.DeepCopy(),
+			},
+			Dataset: &corev1beta1.DatasetRefSpec{
+				DatasetRef: finetune.Spec.Dataset,
+				Spec:       dataset.Spec.DeepCopy(),
+			},
+			Hyperparameter: &corev1beta1.HyperparameterRefSpec{
+				HyperparameterRef: finetune.Spec.Hyperparameter.HyperparameterRef,
+				Spec:              hyperparameter.Spec.DeepCopy(),
+			},
+			Image:      &finetune.Spec.Image,
+			Checkpoint: finetune.Status.LLMCheckpoint.CheckpointPath,
+		},
+	}
+	return llmCheckpointInstance, nil
+}
+
+// getFeaturesMapJson builds a JSON object mapping the well-known feature
+// names "instruction" and "response" to their MapTo targets. It returns the
+// empty string when the feature list is nil or contains neither name.
+func getFeaturesMapJson(features []extensionv1beta1.Feature) (string, error) {
+	if features == nil {
+		return "", nil
+	}
+
+	mapping := make(map[string]string)
+	for _, feature := range features {
+		switch feature.Name {
+		case "instruction":
+			mapping["instruction"] = feature.MapTo
+		case "response":
+			mapping["response"] = feature.MapTo
+		}
+	}
+
+	if len(mapping) == 0 {
+		return "", nil
+	}
+
+	encoded, err := json.Marshal(mapping)
+	if err != nil {
+		return "", err
+	}
+	return string(encoded), nil
+}
+
+// updateHyperparameters returns a copy of parameters with every non-zero /
+// non-nil field of overrides applied on top. The input parameters value is
+// never mutated; a nil overrides returns an unmodified deep copy.
+func updateHyperparameters(parameters *corev1beta1.Parameters, overrides *finetunev1beta1.Parameters) *corev1beta1.Parameters {
+	merged := parameters.DeepCopy()
+	if overrides == nil {
+		return merged
+	}
+
+	// String-typed fields: empty string means "not overridden".
+	if overrides.Scheduler != "" {
+		merged.Scheduler = overrides.Scheduler
+	}
+	if overrides.Optimizer != "" {
+		merged.Optimizer = overrides.Optimizer
+	}
+
+	// Pointer-typed fields: nil means "not overridden".
+	if overrides.Int4 != nil {
+		merged.Int4 = *overrides.Int4
+	}
+	if overrides.Int8 != nil {
+		merged.Int8 = *overrides.Int8
+	}
+	if overrides.LoRA_R != nil {
+		merged.LoRA_R = *overrides.LoRA_R
+	}
+	if overrides.LoRA_Alpha != nil {
+		merged.LoRA_Alpha = *overrides.LoRA_Alpha
+	}
+	if overrides.LoRA_Dropout != nil {
+		merged.LoRA_Dropout = *overrides.LoRA_Dropout
+	}
+	if overrides.LearningRate != nil {
+		merged.LearningRate = *overrides.LearningRate
+	}
+	if overrides.WarmupRatio != nil {
+		merged.WarmupRatio = *overrides.WarmupRatio
+	}
+	if overrides.WeightDecay != nil {
+		merged.WeightDecay = *overrides.WeightDecay
+	}
+	if overrides.TrainerType != nil {
+		merged.TrainerType = *overrides.TrainerType
+	}
+	if overrides.PEFT != nil {
+		merged.PEFT = *overrides.PEFT
+	}
+	if overrides.FP16 != nil {
+		merged.FP16 = *overrides.FP16
+	}
+
+	// Numeric fields: zero means "not overridden".
+	if overrides.Epochs != 0 {
+		merged.Epochs = overrides.Epochs
+	}
+	if overrides.BlockSize != 0 {
+		merged.BlockSize = overrides.BlockSize
+	}
+	if overrides.BatchSize != 0 {
+		merged.BatchSize = overrides.BatchSize
+	}
+	if overrides.GradAccSteps != 0 {
+		merged.GradAccSteps = overrides.GradAccSteps
+	}
+
+	return merged
+}
+
+// isJobPendingOrRunning reports whether the Ray job is still pending or
+// actively running.
+func isJobPendingOrRunning(status rayv1.JobStatus) bool {
+	switch status {
+	case rayv1.JobStatusPending, rayv1.JobStatusRunning:
+		return true
+	default:
+		return false
+	}
+}
+
+// isJobStoppedOrFailed indicates whether the job terminated unsuccessfully
+// (stopped or failed). (The original comment was a copy-paste of
+// isJobPendingOrRunning's.)
+func isJobStoppedOrFailed(status rayv1.JobStatus) bool {
+	return (status == rayv1.JobStatusStopped) || (status == rayv1.JobStatusFailed)
+}
+
+// GenerateLLMCheckpointName derives a unique LLMCheckpoint name by appending
+// a random 5-character suffix to the Finetune name.
+func GenerateLLMCheckpointName(finetuneName string) string {
+	suffix := rand.String(5)
+	return finetuneName + "-" + suffix
+}
+
+// SetupWithManager sets up the controller with the Manager: it reconciles
+// Finetune objects and also watches owned RayJob and LLMCheckpoint resources
+// so their changes re-trigger reconciliation. CacheSyncTimeout bounds the
+// wait for informer cache sync at controller startup.
+func (r *FinetuneReconciler) SetupWithManager(mgr ctrl.Manager) error {
+	return ctrl.NewControllerManagedBy(mgr).
+		For(&finetunev1beta1.Finetune{}).
+		Owns(&rayv1.RayJob{}).
+		Owns(&corev1beta1.LLMCheckpoint{}).
+		WithOptions(controller.Options{CacheSyncTimeout: 10 * time.Second}).
+		Complete(r)
+}
diff --git a/internal/controller/finetune/finetuneexperiment_controller.go b/internal/controller/finetune/finetuneexperiment_controller.go
index 65115d3..e04b6db 100644
--- a/internal/controller/finetune/finetuneexperiment_controller.go
+++ b/internal/controller/finetune/finetuneexperiment_controller.go
@@ -18,10 +18,11 @@ package finetune
import (
"context"
- "fmt"
"reflect"
+ "sort"
"time"
+ "github.com/DataTunerX/finetune-experiment-controller/pkg/util"
"github.com/DataTunerX/finetune-experiment-controller/pkg/util/handlererr"
finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1"
"github.com/DataTunerX/utility-server/logging"
@@ -37,7 +38,6 @@ import (
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/predicate"
- "sigs.k8s.io/controller-runtime/pkg/source"
)
// FinetuneExperimentReconciler reconciles a FinetuneExperiment object
@@ -83,7 +83,25 @@ func (r *FinetuneExperimentReconciler) Reconcile(ctx context.Context, req ctrl.R
}
}
- if finetuneExperiment.Spec.Pending {
+ if finetuneExperiment.Spec.Pending && finetuneExperiment.Status.State != finetunev1beta1.FinetuneExperimentPending {
+ for i := range finetuneExperiment.Spec.FinetuneJobs {
+ finetuneJob := finetuneExperiment.Spec.FinetuneJobs[i]
+ existFinetuneJob := &finetunev1beta1.FinetuneJob{}
+ if err := r.Client.Get(ctx, types.NamespacedName{
+ Name: finetuneJob.Name,
+ Namespace: finetuneExperiment.Namespace,
+ }, existFinetuneJob); err != nil {
+ if errors.IsNotFound(err) {
+ r.Log.Infof("FinetuneJob %s/%s not found, continue", finetuneExperiment.Namespace, finetuneJob.Name)
+ continue
+ }
+ return handlererr.HandlerErr(err)
+ }
+ if err := r.Client.Delete(ctx, existFinetuneJob); err != nil {
+ return handlererr.HandlerErr(err)
+ }
+ }
+ finetuneExperiment.Status.JobsStatus = make([]*finetunev1beta1.FinetuneJobStatusSetting, 0)
finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentPending
finetuneExperiment.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05")
if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil {
@@ -91,14 +109,19 @@ func (r *FinetuneExperimentReconciler) Reconcile(ctx context.Context, req ctrl.R
return handlererr.HandlerErr(err)
}
return handlererr.HandlerErr(nil)
+ } else if finetuneExperiment.Spec.Pending {
+ return handlererr.HandlerErr(nil)
}
+ if finetuneExperiment.Status.State == "" {
+ finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentProcessing
+ if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil {
+ r.Log.Errorf("Update finetuneExperiment %s/%s status failed", finetuneExperiment.Name, finetuneExperiment.Namespace)
+ return handlererr.HandlerErr(err)
+ }
+ }
for i := range finetuneExperiment.Spec.FinetuneJobs {
finetuneJob := finetuneExperiment.Spec.FinetuneJobs[i]
- if finetuneJob.Name == "" {
- finetuneJob.Name = fmt.Sprintf("%s-%s-%d", finetuneExperiment.Name, "finetunejob", i+1)
- finetuneExperiment.Spec.FinetuneJobs[i].Name = fmt.Sprintf("%s-%s-%d", finetuneExperiment.Name, "finetunejob", i+1)
- }
existFinetuneJob := &finetunev1beta1.FinetuneJob{}
if err := r.Client.Get(ctx, types.NamespacedName{
Name: finetuneJob.Name,
@@ -127,52 +150,71 @@ func (r *FinetuneExperimentReconciler) Reconcile(ctx context.Context, req ctrl.R
}
}
}
- if finetuneExperiment.Status.State == finetunev1beta1.FinetuneExperimentProcessing {
- for i := range finetuneExperiment.Spec.FinetuneJobs {
- if finetuneExperiment.Spec.FinetuneJobs[i].Name == "" {
- finetuneExperiment.Spec.FinetuneJobs[i].Name = fmt.Sprintf("%s-%s-%d", finetuneExperiment.Name, "finetunejob", i+1)
- }
- finetuneJobInstance := &finetunev1beta1.FinetuneJob{}
- if err := r.Client.Get(ctx, types.NamespacedName{Name: finetuneExperiment.Spec.FinetuneJobs[i].Name, Namespace: finetuneExperiment.Namespace}, finetuneJobInstance); err != nil {
- r.Log.Errorf("Get finetuneJob %s/%s failed, err: %v", finetuneExperiment.Spec.FinetuneJobs[i].Name, finetuneExperiment.Namespace, err)
- return handlererr.HandlerErr(err)
- }
- if finetuneJobInstance.Status.FinetuneState == "" {
- finetuneJobInstance.Status.State = finetunev1beta1.FinetuneJobInit
- }
- if finetuneExperiment.Status.JobsStatus == nil {
- finetuneExperiment.Status.JobsStatus = make([]*finetunev1beta1.FinetuneJobStatusSetting, len(finetuneExperiment.Spec.FinetuneJobs))
+ success := true
+ for i := range finetuneExperiment.Spec.FinetuneJobs {
+ finetuneJobInstance := &finetunev1beta1.FinetuneJob{}
+ if err := r.Client.Get(ctx, types.NamespacedName{Name: finetuneExperiment.Spec.FinetuneJobs[i].Name, Namespace: finetuneExperiment.Namespace}, finetuneJobInstance); err != nil {
+ r.Log.Errorf("Get finetuneJob %s/%s failed, err: %v", finetuneExperiment.Spec.FinetuneJobs[i].Name, finetuneExperiment.Namespace, err)
+ return handlererr.HandlerErr(err)
+ }
+ if finetuneJobInstance.Status.FinetuneStatus == nil {
+ finetuneJobInstance.Status.FinetuneStatus = &finetunev1beta1.FinetuneStatus{
+ State: finetunev1beta1.FinetuneInit,
}
- if finetuneExperiment.Status.JobsStatus[i] != nil {
- if !reflect.DeepEqual(finetuneExperiment.Status.JobsStatus[i].FinetuneJobStatus, finetuneJobInstance.Status) {
- finetuneExperiment.Status.JobsStatus[i] = &finetunev1beta1.FinetuneJobStatusSetting{
- Name: finetuneJobInstance.Name,
- FinetuneJobStatus: finetuneJobInstance.Status,
- }
- }
- } else {
+ }
+
+ if finetuneExperiment.Status.JobsStatus == nil {
+ finetuneExperiment.Status.JobsStatus = make([]*finetunev1beta1.FinetuneJobStatusSetting, len(finetuneExperiment.Spec.FinetuneJobs))
+ }
+ if finetuneExperiment.Status.JobsStatus[i] != nil {
+ r.Log.Infof("Update finetuneExperiment %s/%s status", finetuneExperiment.Namespace, finetuneExperiment.Name)
+ if !reflect.DeepEqual(finetuneExperiment.Status.JobsStatus[i].FinetuneJobStatus, finetuneJobInstance.Status) {
finetuneExperiment.Status.JobsStatus[i] = &finetunev1beta1.FinetuneJobStatusSetting{
Name: finetuneJobInstance.Name,
FinetuneJobStatus: finetuneJobInstance.Status,
}
}
+ } else {
+ r.Log.Infof("Set finetuneExperiment %s/%s status", finetuneExperiment.Namespace, finetuneExperiment.Name)
+ finetuneExperiment.Status.JobsStatus[i] = &finetunev1beta1.FinetuneJobStatusSetting{
+ Name: finetuneJobInstance.Name,
+ FinetuneJobStatus: finetunev1beta1.FinetuneJobStatus{
+ State: finetunev1beta1.FinetuneJobInit,
+ FinetuneStatus: &finetunev1beta1.FinetuneStatus{
+ State: finetunev1beta1.FinetuneInit,
+ },
+ },
+ }
}
- if err := r.Client.Update(ctx, finetuneExperiment); err != nil {
- r.Log.Errorf("Update fineExperiment %s/%s failed", finetuneExperiment.Name, finetuneExperiment.Namespace)
- return handlererr.HandlerErr(err)
- }
- if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil {
- r.Log.Errorf("Update fineExperiment %s/%s status failed", finetuneExperiment.Name, finetuneExperiment.Namespace)
- return handlererr.HandlerErr(err)
+ if finetuneJobInstance.Status.State != finetunev1beta1.FinetuneJobSuccessful {
+ success = false
}
}
- if finetuneExperiment.Status.State == "" {
- finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentProcessing
- if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil {
- r.Log.Errorf("Update fineExperiment %s/%s status failed", finetuneExperiment.Name, finetuneExperiment.Namespace)
- return handlererr.HandlerErr(err)
+
+ if success {
+ finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentSuccess
+ jobs := finetuneExperiment.Status.JobsStatus
+ sort.Slice(jobs, func(i, j int) bool {
+ return util.ParseScore(jobs[i].FinetuneJobStatus.Result.Score) > util.ParseScore(jobs[j].FinetuneJobStatus.Result.Score)
+ })
+ finetuneJobBestVersion := &finetunev1beta1.FinetuneJob{}
+ if err := r.Client.Get(ctx, types.NamespacedName{Name: jobs[0].Name, Namespace: finetuneExperiment.Namespace}, finetuneJobBestVersion); err != nil {
+ r.Log.Errorf("Get finetuneJob %s/%s failed: %v", jobs[0].Name, finetuneExperiment.Namespace, err)
+ }
+ finetuneExperiment.Status.BestVersion = &finetunev1beta1.BestVersion{
+ Score: jobs[0].FinetuneJobStatus.Result.Score,
+ Image: jobs[0].FinetuneJobStatus.Result.Image,
+ LLM: finetuneJobBestVersion.Spec.FineTune.FinetuneSpec.LLM,
+ Hyperparameter: finetuneJobBestVersion.Spec.FineTune.FinetuneSpec.Hyperparameter,
+ Dataset: finetuneJobBestVersion.Spec.FineTune.FinetuneSpec.Dataset,
}
+ finetuneExperiment.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05")
+ }
+
+ if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil {
+ r.Log.Errorf("Update finetuneExperiment %s/%s status failed", finetuneExperiment.Namespace, finetuneExperiment.Name)
+ return handlererr.HandlerErr(err)
}
return handlererr.HandlerErr(nil)
}
@@ -181,25 +223,24 @@ func (r *FinetuneExperimentReconciler) Reconcile(ctx context.Context, req ctrl.R
func (r *FinetuneExperimentReconciler) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
For(&finetunev1beta1.FinetuneExperiment{}).
- Watches(&source.Kind{Type: &finetunev1beta1.FinetuneJob{}}, &handler.EnqueueRequestForOwner{
- OwnerType: &finetunev1beta1.FinetuneExperiment{},
- IsController: true,
- }, builder.WithPredicates(predicate.Funcs{
- UpdateFunc: func(updateEvent event.UpdateEvent) bool {
- oldFinetuneJob := updateEvent.ObjectOld.(*finetunev1beta1.FinetuneJob)
- newFinetuneJob := updateEvent.ObjectNew.(*finetunev1beta1.FinetuneJob)
- if oldFinetuneJob.Status.State != newFinetuneJob.Status.State {
- r.Log.Infof("Get finetuneJob %s/%s update event oldStatus: %s, newStatus: %s", oldFinetuneJob.Namespace, oldFinetuneJob.Name, oldFinetuneJob.Status.State, newFinetuneJob.Status.State)
- return true
- }
- return false
- },
- CreateFunc: func(createEvent event.CreateEvent) bool {
- finetuneJob := createEvent.Object.(*finetunev1beta1.FinetuneJob)
- r.Log.Infof("Get finetuneJob %s/%s crate event, skip", finetuneJob.Name, finetuneJob.Namespace)
- return false
- },
- })).
+ Watches(&finetunev1beta1.FinetuneJob{},
+ handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneExperiment{}, handler.OnlyControllerOwner()),
+ builder.WithPredicates(predicate.Funcs{
+ UpdateFunc: func(updateEvent event.UpdateEvent) bool {
+ oldFinetuneJob := updateEvent.ObjectOld.(*finetunev1beta1.FinetuneJob)
+ newFinetuneJob := updateEvent.ObjectNew.(*finetunev1beta1.FinetuneJob)
+ if oldFinetuneJob.Status.State != newFinetuneJob.Status.State {
+ r.Log.Infof("Get finetuneJob %s/%s update event oldStatus: %s, newStatus: %s", oldFinetuneJob.Namespace, oldFinetuneJob.Name, oldFinetuneJob.Status.State, newFinetuneJob.Status.State)
+ return true
+ }
+ return false
+ },
+ CreateFunc: func(createEvent event.CreateEvent) bool {
+ finetuneJob := createEvent.Object.(*finetunev1beta1.FinetuneJob)
+ r.Log.Infof("Get finetuneJob %s/%s create event, skip", finetuneJob.Name, finetuneJob.Namespace)
+ return false
+ },
+ })).
WithOptions(controller.Options{
CacheSyncTimeout: 10 * time.Second,
MaxConcurrentReconciles: 1}).
diff --git a/internal/controller/finetune/finetunejob_controller.go b/internal/controller/finetune/finetunejob_controller.go
index 7503a79..2e343ab 100644
--- a/internal/controller/finetune/finetunejob_controller.go
+++ b/internal/controller/finetune/finetunejob_controller.go
@@ -31,6 +31,7 @@ import (
extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1"
finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1"
"github.com/DataTunerX/utility-server/logging"
+ "github.com/duke-git/lancet/v2/slice"
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
batchv1 "k8s.io/api/batch/v1"
"k8s.io/apimachinery/pkg/api/errors"
@@ -45,7 +46,6 @@ import (
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/predicate"
- "sigs.k8s.io/controller-runtime/pkg/source"
)
// FinetuneJobReconciler reconciles a FinetuneJob object
@@ -83,7 +83,10 @@ func (r *FinetuneJobReconciler) Reconcile(ctx context.Context, req ctrl.Request)
if finetuneJob.GetDeletionTimestamp() != nil {
r.Log.Infof("Delete finetuneJob: %s/%s", finetuneJob.Namespace, finetuneJob.Name)
if controllerutil.ContainsFinalizer(finetuneJob, finetunev1beta1.FinetuneGroupFinalizer) {
- // todo cleaner
+ if err := r.reconcileCleaner(ctx, finetuneJob); err != nil {
+ r.Log.Errorf("cleaner failed: %s/%s, Err: %v", finetuneJob.Namespace, finetuneJob.Name, err)
+ return handlererr.HandlerErr(err)
+ }
controllerutil.RemoveFinalizer(finetuneJob, finetunev1beta1.FinetuneGroupFinalizer)
if err := r.Update(ctx, finetuneJob); err != nil {
r.Log.Errorf("Remove finalizer failed: %s/%s, Err: %v", finetuneJob.Namespace, finetuneJob.Name, err)
@@ -113,8 +116,6 @@ func (r *FinetuneJobReconciler) Reconcile(ctx context.Context, req ctrl.Request)
return handlererr.HandlerErr(err)
}
- // Phase II of the fine-tuning exercise.
- // Generate finetune CR.
existFinetune, err := r.reconcileFinetuneSend(ctx, finetuneJob)
if err != nil {
return handlererr.HandlerErr(err)
@@ -132,31 +133,10 @@ func (r *FinetuneJobReconciler) Reconcile(ctx context.Context, req ctrl.Request)
return handlererr.HandlerErr(err)
}
- scoringName := fmt.Sprintf("%s-scoring", finetuneJob.Name)
- scoring := &extensionv1beta1.Scoring{}
- if err := r.Get(ctx, types.NamespacedName{
- Name: scoringName,
- Namespace: finetuneJob.Namespace,
- }, scoring); err != nil {
- if errors.IsNotFound(err) {
- r.Log.Infof("Scoring %s/%s not found, err: %v", scoringName, finetuneJob.Namespace, err)
- return ctrl.Result{RequeueAfter: 30 * time.Second}, nil
- }
- r.Log.Errorf("Get scoring %s/%s failed: %v", scoringName, finetuneJob.Namespace, err)
+ if err := r.reconcileByScoringStatus(ctx, finetuneJob); err != nil {
return handlererr.HandlerErr(err)
}
- // todo(tigerK) get scoring result, update finetuneJob status
- if scoring.Status.Score != nil {
- finetuneJob.Status.State = finetunev1beta1.FinetuneJobSuccessful
- finetuneJob.Status.Result.Score = *scoring.Status.Score
- finetuneJob.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05")
- if err := r.Client.Status().Update(ctx, finetuneJob); err != nil {
- r.Log.Errorf("Update finetuneJob status failed: %v", err)
- return handlererr.HandlerErr(err)
- }
- }
-
// Phase IIII of the fine-tuning exercise.
// Check finetune cr status, if finetune cr status is SUCCESSFUL, start next
return handlererr.HandlerErr(nil)
@@ -179,55 +159,51 @@ func (r *FinetuneJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
return false
},
})).
- Watches(&source.Kind{Type: &finetunev1beta1.Finetune{}}, &handler.EnqueueRequestForOwner{
- OwnerType: &finetunev1beta1.FinetuneJob{},
- IsController: true,
- }, builder.WithPredicates(predicate.Funcs{
- UpdateFunc: func(updateEvent event.UpdateEvent) bool {
- oldFinetune := updateEvent.ObjectOld.(*finetunev1beta1.Finetune)
- newFinetune := updateEvent.ObjectNew.(*finetunev1beta1.Finetune)
- if oldFinetune.Status.State != newFinetune.Status.State {
- r.Log.Infof("Get finetun %s/%s update event oldStatus: %s, newStatus: %s", oldFinetune.Name, oldFinetune.Namespace, oldFinetune.Status.State, newFinetune.Status.State)
+ Watches(&finetunev1beta1.Finetune{},
+ handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()),
+ builder.WithPredicates(predicate.Funcs{
+ UpdateFunc: func(updateEvent event.UpdateEvent) bool {
+ oldFinetune := updateEvent.ObjectOld.(*finetunev1beta1.Finetune)
+ newFinetune := updateEvent.ObjectNew.(*finetunev1beta1.Finetune)
+ if oldFinetune.Status.State != newFinetune.Status.State {
+ r.Log.Infof("Get finetune %s/%s update event oldStatus: %s, newStatus: %s", oldFinetune.Name, oldFinetune.Namespace, oldFinetune.Status.State, newFinetune.Status.State)
+ return true
+ }
+ return false
+ },
+ CreateFunc: func(createEvent event.CreateEvent) bool {
+ finetune := createEvent.Object.(*finetunev1beta1.Finetune)
+ r.Log.Infof("Get finetune %s/%s create event, skip", finetune.Name, finetune.Namespace)
+ return false
+ },
+ })).
+ Watches(&batchv1.Job{},
+ handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()),
+ builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool {
+ job := object.(*batchv1.Job)
+ if job.Status.CompletionTime != nil {
return true
}
return false
- },
- CreateFunc: func(createEvent event.CreateEvent) bool {
- finetune := createEvent.Object.(*finetunev1beta1.Finetune)
- r.Log.Infof("Get finetun %s/%s crate event, skip", finetune.Name, finetune.Namespace)
+ }))).
+ Watches(&rayv1.RayService{},
+ handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()),
+ builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool {
+ rayService := object.(*rayv1.RayService)
+ if rayService.Status.ServiceStatus == rayv1.Running {
+ return true
+ }
return false
- },
- })).
- Watches(&source.Kind{Type: &batchv1.Job{}}, &handler.EnqueueRequestForOwner{
- OwnerType: &finetunev1beta1.FinetuneJob{},
- IsController: true,
- }, builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool {
- job := object.(*batchv1.Job)
- if job.Status.CompletionTime != nil {
- return true
- }
- return false
- }))).
- Watches(&source.Kind{Type: &rayv1.RayService{}}, &handler.EnqueueRequestForOwner{
- OwnerType: &finetunev1beta1.FinetuneJob{},
- IsController: true,
- }, builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool {
- rayService := object.(*rayv1.RayService)
- if rayService.Status.ServiceStatus == rayv1.Running {
- return true
- }
- return false
- }))).
- Watches(&source.Kind{Type: &extensionv1beta1.Scoring{}}, &handler.EnqueueRequestForOwner{
- OwnerType: &finetunev1beta1.FinetuneJob{},
- IsController: true,
- }, builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool {
- scoring := object.(*extensionv1beta1.Scoring)
- if scoring.Status.Score != nil {
- return true
- }
- return false
- }))).
+ }))).
+ Watches(&extensionv1beta1.Scoring{},
+ handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()),
+ builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool {
+ scoring := object.(*extensionv1beta1.Scoring)
+ if scoring.Status.Score != nil {
+ return true
+ }
+ return false
+ }))).
WithOptions(controller.Options{
CacheSyncTimeout: 10 * time.Second,
MaxConcurrentReconciles: 1}).
@@ -244,11 +220,44 @@ func (r *FinetuneJobReconciler) reconcilePreCondition(ctx context.Context, finet
r.Log.Errorf("Get %s: %s/%s failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
return err
}
+ switch obj.(type) {
+ case *corev1beta1.LLM:
+ llm := obj.(*corev1beta1.LLM)
+ if llm.Status.ReferenceFinetuneName == nil {
+ llm.Status.ReferenceFinetuneName = make([]string, 0)
+ }
+ llm.Status.ReferenceFinetuneName = slice.AppendIfAbsent(llm.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
+ if err := r.Client.Status().Update(ctx, llm); err != nil {
+ r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
+ return err
+ }
+ case *extensionv1beta1.Dataset:
+ dataset := obj.(*extensionv1beta1.Dataset)
+ if dataset.Status.ReferenceFinetuneName == nil {
+ dataset.Status.ReferenceFinetuneName = make([]string, 0)
+ }
+ dataset.Status.ReferenceFinetuneName = slice.AppendIfAbsent(dataset.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
+ if err := r.Client.Status().Update(ctx, dataset); err != nil {
+ r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
+ return err
+ }
+ case *corev1beta1.Hyperparameter:
+ hyperparameter := obj.(*corev1beta1.Hyperparameter)
+ if hyperparameter.Status.ReferenceFinetuneName == nil {
+ hyperparameter.Status.ReferenceFinetuneName = make([]string, 0)
+ }
+ hyperparameter.Status.ReferenceFinetuneName = slice.AppendIfAbsent(hyperparameter.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
+ if err := r.Client.Status().Update(ctx, hyperparameter); err != nil {
+ r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
+ return err
+ }
+ }
}
return nil
}
func (r *FinetuneJobReconciler) reconcileFinetuneSend(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) (*finetunev1beta1.Finetune, error) {
+
specFinetuneInstance := generate.GenerateFinetune(finetuneJob)
existFinetuneInstance := &finetunev1beta1.Finetune{}
if err := r.Get(ctx, types.NamespacedName{
@@ -267,7 +276,7 @@ func (r *FinetuneJobReconciler) reconcileFinetuneSend(ctx context.Context, finet
return nil, err
}
}
- return nil, fmt.Errorf("Finetune %s/%s creating, reReconcile", specFinetuneInstance.Namespace, specFinetuneInstance.Name)
+ return nil, valueobject.ErrRecalibrate
}
}
return existFinetuneInstance, nil
@@ -278,7 +287,7 @@ func (r *FinetuneJobReconciler) reconcileByFinetuneStatus(ctx context.Context, f
if finetuneInstance.Status.State == finetunev1beta1.FinetuneInit || finetuneInstance.Status.State == finetunev1beta1.FinetuneRunning {
r.Log.Infof("Update finetuneJob %s/%s status %s.", finetuneJobInstance.Namespace, finetuneJobInstance.Name, finetunev1beta1.FinetuneJobFinetune)
finetuneJobInstance.Status.State = finetunev1beta1.FinetuneJobFinetune
- finetuneJobInstance.Status.FinetuneState = finetuneInstance.Status.State
+ finetuneJobInstance.Status.FinetuneStatus = &finetuneInstance.Status
if err := r.Client.Status().Update(ctx, finetuneJobInstance); err != nil {
r.Log.Errorf("Update finetuneJob %s/%s status failed: %v", finetuneJobInstance.Namespace, finetuneJobInstance.Name, err)
return err
@@ -297,22 +306,27 @@ func (r *FinetuneJobReconciler) reconcileByFinetuneStatus(ctx context.Context, f
return err
}
// build llmCheckpoint image server. job
- buildImageJobName := fmt.Sprintf("%s-buildimage", finetuneJobInstance.Name)
+
+ imageName := fmt.Sprintf("ray271-llama2-7b-finetune-checkpoint-%s", finetuneJobInstance.Name)
+ imageTag := time.Now().Format("20060102")
checkPointFilePath := finetuneInstance.Status.LLMCheckpoint.CheckpointPath
checkPointFilePath = util.RemoveBucketName(checkPointFilePath, config.GetS3Bucket())
- buildImageJob := generate.GenerateBuildImageJob(buildImageJobName, finetuneJobInstance.Namespace, checkPointFilePath)
+ buildImageJob := generate.GenerateBuildImageJob(checkPointFilePath, imageName, imageTag, finetuneJobInstance)
if err := ctrl.SetControllerReference(finetuneJobInstance, buildImageJob, r.Scheme); err != nil {
r.Log.Errorf("Set owner failed: %v", err)
return err
}
- if err := r.Client.Create(ctx, buildImageJob); err != nil {
- if !errors.IsAlreadyExists(err) {
- r.Log.Errorf("Create job %s/%s failed, err: %v", buildImageJob.Name, buildImageJob.Namespace, err)
- return err
+ if err := r.Client.Get(ctx, types.NamespacedName{Name: buildImageJob.Name, Namespace: buildImageJob.Namespace}, buildImageJob); err != nil {
+ if errors.IsNotFound(err) {
+ if err := r.Client.Create(ctx, buildImageJob); err != nil {
+ r.Log.Errorf("Create job %s/%s failed, err: %v", buildImageJob.Name, buildImageJob.Namespace, err)
+ return err
+ }
}
}
+
llmCheckpoint.Spec.CheckpointImage = &corev1beta1.CheckpointImage{}
- llmImage := fmt.Sprintf("%s/%s/%s:%s", config.GetRegistryUrl(), config.GetRepositoryName(), config.GetImageName(), config.GetImageTag())
+ llmImage := fmt.Sprintf("%s/%s/%s:%s", config.GetRegistryUrl(), config.GetRepositoryName(), imageName, imageTag)
llmCheckpoint.Spec.CheckpointImage.Name = &llmImage
llmCheckpoint.Spec.CheckpointImage.CheckPointPath = fmt.Sprintf("/checkpoint/%s", checkPointFilePath)
llmCheckpoint.Spec.CheckpointImage.LLMPath = llmCheckpoint.Spec.Image.Path
@@ -322,7 +336,7 @@ func (r *FinetuneJobReconciler) reconcileByFinetuneStatus(ctx context.Context, f
}
finetuneJobInstance.Status.State = finetunev1beta1.FinetuneJobBuildImage
- finetuneJobInstance.Status.FinetuneState = finetuneInstance.Status.State
+ finetuneJobInstance.Status.FinetuneStatus = &finetuneInstance.Status
if err := r.Client.Status().Update(ctx, finetuneJobInstance); err != nil {
r.Log.Errorf("Update finetuneJob %s/%s status failed: %v", finetuneJobInstance.Namespace, finetuneInstance.Name, err)
return err
@@ -331,7 +345,7 @@ func (r *FinetuneJobReconciler) reconcileByFinetuneStatus(ctx context.Context, f
if finetuneInstance.Status.State == finetunev1beta1.FinetuneFailed {
finetuneJobInstance.Status.State = finetunev1beta1.FinetuneJobFailed
- finetuneJobInstance.Status.FinetuneState = finetuneInstance.Status.State
+ finetuneJobInstance.Status.FinetuneStatus = &finetuneInstance.Status
if err := r.Client.Status().Update(ctx, finetuneJobInstance); err != nil {
r.Log.Errorf("Update finetuneJob %s/%s status failed: %v", finetuneJobInstance.Namespace, finetuneInstance.Name, err)
return err
@@ -372,15 +386,18 @@ func (r *FinetuneJobReconciler) reconcileByJobStatus(ctx context.Context, finetu
r.Log.Errorf("Set owner failed: %v", err)
return err
}
- if err := r.Create(ctx, rayService); err != nil {
- if !errors.IsAlreadyExists(err) {
- r.Log.Errorf("Create rayService %s/%s failed: %v", rayServiceName, finetuneJob.Namespace, err)
- return err
+
+ if err := r.Client.Get(ctx, types.NamespacedName{Name: rayServiceName, Namespace: finetuneJob.Namespace}, rayService); err != nil {
+ if errors.IsNotFound(err) {
+ if err := r.Create(ctx, rayService); err != nil {
+ r.Log.Errorf("Create rayService %s/%s failed: %v", rayServiceName, finetuneJob.Namespace, err)
+ return err
+ }
}
}
- r.Log.Infof("Send serve successful")
+
finetuneJob.Status.State = finetunev1beta1.FinetuneJobServe
- finetuneJob.Status.FinetuneState = finetune.Status.State
+ finetuneJob.Status.FinetuneStatus = &finetune.Status
finetuneJob.Status.Result = &finetunev1beta1.FinetuneJobResult{
ModelExportResult: true,
Image: *llmCheckpoint.Spec.CheckpointImage.Name,
@@ -404,17 +421,23 @@ func (r *FinetuneJobReconciler) reconcileByRayServiceStatus(ctx context.Context,
return err
}
if finetuneJob.Status.State == finetunev1beta1.FinetuneJobServe && rayService.Status.ServiceStatus == rayv1.Running {
- //serveNodePort := rayService.Status.ActiveServiceStatus.RayClusterStatus.Endpoints["serve"]
- //dashboardNodePort := rayService.Status.ActiveServiceStatus.RayClusterStatus.Endpoints["dashboard"]
- finetuneJob.Status.Result.Serve = fmt.Sprintf("%s.%s.svc:%s", finetuneJob.Name, finetuneJob.Namespace, "8000")
- finetuneJob.Status.Result.Dashboard = fmt.Sprintf("%s.%s.svc:%s", finetuneJob.Name, finetuneJob.Namespace, "8265")
+ if rayService.Status.ActiveServiceStatus.Applications["default"].Deployments["LlamaDeployment"].Status == "HEALTHY" {
+ // todo(tigerK) no time for optimisation
+ //serveNodePort := rayService.Status.ActiveServiceStatus.RayClusterStatus.Endpoints["serve"]
+ //dashboardNodePort := rayService.Status.ActiveServiceStatus.RayClusterStatus.Endpoints["dashboard"]
+ finetuneJob.Status.Result.Serve = fmt.Sprintf("%s.%s.svc:%s", finetuneJob.Name, finetuneJob.Namespace, "8000")
+ finetuneJob.Status.Result.Dashboard = fmt.Sprintf("%s.%s.svc:%s", finetuneJob.Name, finetuneJob.Namespace, "8080")
+ } else {
+ return valueobject.ErrRecalibrate
+ }
+ infrencePath := fmt.Sprintf("http://%s/chat/completions", finetuneJob.Status.Result.Serve)
if err := r.Client.Status().Update(ctx, finetuneJob); err != nil {
r.Log.Errorf("Update finetuneJob status failed: %v", err)
return err
}
scoringName := fmt.Sprintf("%s-scoring", finetuneJob.Name)
- if finetuneJob.Spec.ScoringConfig == nil {
- scoring := generate.GenerateBuiltInScoring(scoringName, finetuneJob.Namespace, finetuneJob.Status.Result.Serve)
+ if finetuneJob.Spec.ScoringPluginConfig == nil {
+ scoring := generate.GenerateBuiltInScoring(scoringName, finetuneJob.Namespace, infrencePath)
if err := ctrl.SetControllerReference(finetuneJob, scoring, r.Scheme); err != nil {
r.Log.Errorf("Set owner failed: %v", err)
return err
@@ -427,7 +450,7 @@ func (r *FinetuneJobReconciler) reconcileByRayServiceStatus(ctx context.Context,
}
return nil
}
- scoring := generate.GeneratePluginScoring(scoringName, finetuneJob.Namespace, finetuneJob.Spec.ScoringConfig.Name, finetuneJob.Spec.ScoringConfig.Parameters, finetuneJob.Status.Result.Serve)
+ scoring := generate.GeneratePluginScoring(scoringName, finetuneJob.Namespace, finetuneJob.Spec.ScoringPluginConfig.Name, finetuneJob.Spec.ScoringPluginConfig.Parameters, infrencePath)
if err := ctrl.SetControllerReference(finetuneJob, scoring, r.Scheme); err != nil {
r.Log.Errorf("Set owner failed: %v", err)
return err
@@ -441,3 +464,97 @@ func (r *FinetuneJobReconciler) reconcileByRayServiceStatus(ctx context.Context,
}
return nil
}
+
+func (r *FinetuneJobReconciler) reconcileByScoringStatus(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) error {
+
+ scoringName := fmt.Sprintf("%s-scoring", finetuneJob.Name)
+ scoring := &extensionv1beta1.Scoring{}
+ if err := r.Get(ctx, types.NamespacedName{
+ Name: scoringName,
+ Namespace: finetuneJob.Namespace,
+ }, scoring); err != nil {
+ if errors.IsNotFound(err) {
+ r.Log.Infof("Scoring %s/%s not found, err: %v", scoringName, finetuneJob.Namespace, err)
+ return valueobject.ErrRecalibrate
+ }
+ r.Log.Errorf("Get scoring %s/%s failed: %v", scoringName, finetuneJob.Namespace, err)
+ return err
+ }
+
+ // todo(tigerK) get scoring result, update finetuneJob status
+ if scoring.Status.Score != nil {
+ finetuneJob.Status.State = finetunev1beta1.FinetuneJobSuccessful
+ finetuneJob.Status.Result.Score = *scoring.Status.Score
+ finetuneJob.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05")
+ if err := r.Client.Status().Update(ctx, finetuneJob); err != nil {
+ r.Log.Errorf("Update finetuneJob status failed: %v", err)
+ return err
+ }
+ rayServiceName := finetuneJob.Name
+ rayService := &rayv1.RayService{}
+ if err := r.Get(ctx, types.NamespacedName{
+ Name: rayServiceName,
+ Namespace: finetuneJob.Namespace,
+ }, rayService); err != nil {
+ if errors.IsNotFound(err) {
+ return nil
+ }
+ r.Log.Errorf("Get rayService %s/%s failed: %v", finetuneJob.Namespace, rayServiceName, err)
+ return err
+ }
+ if err := r.Delete(ctx, rayService); err != nil {
+ r.Log.Errorf("Delete rayService %s/%s failed: %v", finetuneJob.Namespace, rayServiceName, err)
+ return err
+ }
+ }
+ return nil
+}
+
+func (r *FinetuneJobReconciler) reconcileCleaner(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) error {
+ preCondition := make(map[string]client.Object, 3)
+ preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.LLM] = &corev1beta1.LLM{}
+ preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.Hyperparameter.HyperparameterRef] = &corev1beta1.Hyperparameter{}
+ preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.Dataset] = &extensionv1beta1.Dataset{}
+ for name, obj := range preCondition {
+ if err := r.Get(ctx, types.NamespacedName{Name: name, Namespace: finetuneJob.Namespace}, obj); err != nil {
+ r.Log.Errorf("Get %s: %s/%s failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
+ return err
+ }
+ switch obj.(type) {
+ case *corev1beta1.LLM:
+ llm := obj.(*corev1beta1.LLM)
+ if llm.Status.ReferenceFinetuneName == nil {
+ continue
+ }
+ result := slice.IndexOf(llm.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
+ llm.Status.ReferenceFinetuneName = slice.DeleteAt(llm.Status.ReferenceFinetuneName, result)
+ if err := r.Client.Status().Update(ctx, llm); err != nil {
+ r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
+ return err
+ }
+ case *extensionv1beta1.Dataset:
+ dataset := obj.(*extensionv1beta1.Dataset)
+ if dataset.Status.ReferenceFinetuneName == nil {
+ continue
+ }
+ result := slice.IndexOf(dataset.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
+ dataset.Status.ReferenceFinetuneName = slice.DeleteAt(dataset.Status.ReferenceFinetuneName, result)
+ if err := r.Client.Status().Update(ctx, dataset); err != nil {
+ r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
+ return err
+ }
+ case *corev1beta1.Hyperparameter:
+ hyperparameter := obj.(*corev1beta1.Hyperparameter)
+ if hyperparameter.Status.ReferenceFinetuneName == nil {
+ continue
+ }
+ result := slice.IndexOf(hyperparameter.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
+ hyperparameter.Status.ReferenceFinetuneName = slice.DeleteAt(hyperparameter.Status.ReferenceFinetuneName, result)
+ if err := r.Client.Status().Update(ctx, hyperparameter); err != nil {
+ r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
+ return err
+ }
+ }
+ }
+ return nil
+}
diff --git a/main.go b/main.go
index f81b479..0278ac0 100644
--- a/main.go
+++ b/main.go
@@ -19,10 +19,9 @@ package main
import (
"os"
+ "github.com/DataTunerX/finetune-experiment-controller/cmd/controller-manager/app"
"github.com/DataTunerX/finetune-experiment-controller/pkg/config"
"github.com/DataTunerX/utility-server/logging"
-
- "github.com/DataTunerX/finetune-experiment-controller/cmd/controller-manager/app"
ctrl "sigs.k8s.io/controller-runtime"
)
diff --git a/pkg/config/config.go b/pkg/config/config.go
index deaf915..5498a6e 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -18,9 +18,10 @@ func init() {
config.BindEnv("repositoryName", "REPOSITORY_NAME")
config.BindEnv("userName", "USERNAME")
config.BindEnv("password", "PASSWORD")
- config.BindEnv("imageName", "IMAGE_NAME")
- config.BindEnv("imageTag", "IMAGE_TAG")
config.BindEnv("mountPath", "MOUNT_PATH")
+ config.BindEnv("baseImage", "BASE_IMAGE")
+ config.BindEnv("llmUrl", "LLM_URL")
+ config.SetDefault("llmUrl", "/tmp/llama2-7b/")
}
func GetS3Endpoint() string {
@@ -51,16 +52,12 @@ func GetUserName() string {
return config.GetString("userName")
}
-func GetPassword() string {
- return config.GetString("password")
-}
-
-func GetImageName() string {
- return config.GetString("imageName")
+func GetBaseImage() string {
+ return config.GetString("baseImage")
}
-func GetImageTag() string {
- return config.GetString("imageTag")
+func GetPassword() string {
+ return config.GetString("password")
}
func GetRegistryUrl() string {
@@ -74,3 +71,7 @@ func GetRepositoryName() string {
func GetMountPath() string {
return config.GetString("mountPath")
}
+
+func GetLLMUrl() string {
+ return config.GetString("llmUrl")
+}
diff --git a/pkg/tuning/Dockerfile b/pkg/tuning/Dockerfile
new file mode 100644
index 0000000..786fbfe
--- /dev/null
+++ b/pkg/tuning/Dockerfile
@@ -0,0 +1,8 @@
+FROM rayproject/ray271-py39-gpu-llama2-7b-inference:20231220
+
+WORKDIR /tuning
+
+COPY requirements.txt .
+RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+COPY . .
\ No newline at end of file
diff --git a/pkg/tuning/README.md b/pkg/tuning/README.md
new file mode 100644
index 0000000..48dbd0e
--- /dev/null
+++ b/pkg/tuning/README.md
@@ -0,0 +1,3 @@
+# Dataset
+- instruction
+- response
\ No newline at end of file
diff --git a/pkg/tuning/build_image.sh b/pkg/tuning/build_image.sh
new file mode 100644
index 0000000..26d3be0
--- /dev/null
+++ b/pkg/tuning/build_image.sh
@@ -0,0 +1 @@
+docker build . -t rayproject/ray271-llama2-7b-finetune:20231220
\ No newline at end of file
diff --git a/pkg/tuning/callback.py b/pkg/tuning/callback.py
new file mode 100644
index 0000000..cedf15a
--- /dev/null
+++ b/pkg/tuning/callback.py
@@ -0,0 +1,166 @@
+import os
+import json
+import time
+from typing import TYPE_CHECKING
+from datetime import timedelta
+
+from transformers import TrainerCallback
+from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+
+from prometheus.metrics import export_train_metrics, export_eval_metrics
+
+if TYPE_CHECKING:
+ from transformers import TrainingArguments, TrainerState, TrainerControl
+
+LOG_FILE_NAME = "trainer_log.jsonl"
+
+
+class LogCallback(TrainerCallback):
+
+ def __init__(self, runner=None, metrics_export_address=None, uid=None):
+ self.runner = runner
+ self.in_training = False
+ self.start_time = time.time()
+ self.cur_steps = 0
+ self.max_steps = 0
+ self.elapsed_time = ""
+ self.remaining_time = ""
+ self.metrics_export_address = metrics_export_address
+ self.uid = uid
+
+ def timing(self):
+ cur_time = time.time()
+ elapsed_time = cur_time - self.start_time
+ avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
+ remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
+ self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
+ self.remaining_time = str(timedelta(seconds=int(remaining_time)))
+
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the beginning of training.
+ """
+ if state.is_local_process_zero:
+ self.in_training = True
+ self.start_time = time.time()
+ self.max_steps = state.max_steps
+ if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
+ print("Previous log file in this folder will be deleted.")
+ os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if state.is_local_process_zero:
+ self.in_training = False
+ self.cur_steps = 0
+ self.max_steps = 0
+
+ def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of a substep during gradient accumulation.
+ """
+ if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of a training step.
+ """
+ if state.is_local_process_zero:
+ self.cur_steps = state.global_step
+ self.timing()
+ if self.runner is not None and self.runner.aborted:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after an evaluation phase.
+ """
+ if state.is_local_process_zero and not self.in_training:
+ self.cur_steps = 0
+ self.max_steps = 0
+
+ def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
+ r"""
+ Event called after a successful prediction.
+ """
+ if state.is_local_process_zero and not self.in_training:
+ self.cur_steps = 0
+ self.max_steps = 0
+
+ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
+ r"""
+ Event called after logging the last logs.
+ """
+ if not state.is_local_process_zero:
+ return
+
+ print('log_history: ', state.log_history[-1])  # debug: inspect which keys the last log entry contains
+ if "eval_loss" in state.log_history[-1].keys():
+ eval_log = dict(
+ uid=self.uid,
+ current_steps=self.cur_steps,
+ total_steps=self.max_steps,
+ eval_loss=state.log_history[-1].get("eval_loss", None),
+ eval_perplexity=state.log_history[-1].get("eval_perplexity", None),
+ eval_rouge_1=state.log_history[-1].get("eval_rouge-1", None),
+ eval_rouge_2=state.log_history[-1].get("eval_rouge-2", None),
+ eval_rouge_l=state.log_history[-1].get("eval_rouge-l", None),
+ eval_bleu_4=state.log_history[-1].get("eval_bleu-4", None),
+ epoch=state.log_history[-1].get("epoch", None),
+ percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+ elapsed_time=self.elapsed_time,
+ remaining_time=self.remaining_time
+ )
+ else:
+ logs = dict(
+ uid=self.uid,
+ current_steps=self.cur_steps,
+ total_steps=self.max_steps,
+ loss=state.log_history[-1].get("loss", None),
+ eval_loss=state.log_history[-1].get("eval_loss", None),
+ val_perplexity=state.log_history[-1].get("eval_perplexity", None),
+ eval_rouge_1=state.log_history[-1].get("eval_rouge-1", None),
+ eval_rouge_2=state.log_history[-1].get("eval_rouge-2", None),
+ eval_rouge_l=state.log_history[-1].get("eval_rouge-l", None),
+ eval_bleu_4=state.log_history[-1].get("eval_bleu-4", None),
+ predict_loss=state.log_history[-1].get("predict_loss", None),
+ reward=state.log_history[-1].get("reward", None),
+ learning_rate=state.log_history[-1].get("learning_rate", None),
+ epoch=state.log_history[-1].get("epoch", None),
+ percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+ elapsed_time=self.elapsed_time,
+ remaining_time=self.remaining_time
+ )
+ if self.runner is not None:
+ print("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
+ logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
+ ))
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ os.makedirs(os.path.join(args.output_dir, 'watch'), exist_ok=True)
+ if "eval_loss" in state.log_history[-1].keys():
+ with open(os.path.join(args.output_dir, 'watch', "eval_log.jsonl"), "a", encoding="utf-8") as f:
+ f.write(json.dumps(eval_log) + "\n")
+ if self.metrics_export_address:
+ export_eval_metrics(self.metrics_export_address, eval_log)
+ else:
+ with open(os.path.join(args.output_dir, 'watch', "trainer_log.jsonl"), "a", encoding="utf-8") as f:
+ f.write(json.dumps(logs) + "\n")
+ if self.metrics_export_address:
+ export_train_metrics(self.metrics_export_address, logs)
+
+ def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after a prediction step.
+ """
+ eval_dataloader = kwargs.pop("eval_dataloader", None)
+ if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
+ if self.max_steps == 0:
+ self.max_steps = len(eval_dataloader)
+ self.cur_steps += 1
+ self.timing()
diff --git a/pkg/tuning/ds_config.json b/pkg/tuning/ds_config.json
new file mode 100644
index 0000000..5cf1c15
--- /dev/null
+++ b/pkg/tuning/ds_config.json
@@ -0,0 +1,13 @@
+{
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "zero_allow_untested_optimizer": true,
+ "fp16": {
+ "enabled": "auto"
+ },
+ "zero_optimization": {
+ "stage": 0
+ }
+}
\ No newline at end of file
diff --git a/pkg/tuning/parser.py b/pkg/tuning/parser.py
new file mode 100644
index 0000000..a464712
--- /dev/null
+++ b/pkg/tuning/parser.py
@@ -0,0 +1,266 @@
+import json
+import logging
+from dataclasses import field, dataclass
+from typing import Optional, Dict, Any, Tuple, Literal
+
+import transformers
+from transformers import Seq2SeqTrainingArguments, HfArgumentParser
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArguments:
+ r"""
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+ """
+ model_name_or_path: str = field(
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
+ )
+ use_fast_tokenizer: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
+ )
+ split_special_tokens: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
+ )
+ use_auth_token: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
+ )
+ model_revision: Optional[str] = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
+ )
+ quantization_bit: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of bits to quantize the model."}
+ )
+ quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use in int4 training."}
+ )
+ double_quantization: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether to use double quantization in int4 training or not."}
+ )
+ quantization: Optional[str] = field(
+ default=None,
+ metadata={"help": "quantize the model, int4, int8, or None."}
+ )
+
+ rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+ default=None,
+ metadata={"help": "Adopt scaled rotary positional embeddings."}
+ )
+ checkpoint_dir: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+ )
+ flash_attn: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Enable FlashAttention-2 for faster training."}
+ )
+ shift_attn: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
+ )
+ reward_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+ )
+ plot_loss: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+ )
+ hf_auth_token: Optional[str] = field(
+ default=None,
+ metadata={"help": "Auth token to log in with Hugging Face Hub."}
+ )
+ export_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the directory to save the exported model."}
+ )
+
+ def __post_init__(self):
+ self.compute_dtype = None
+ self.model_max_length = None
+
+ if self.split_special_tokens and self.use_fast_tokenizer:
+ raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
+
+ if self.checkpoint_dir is not None: # support merging multiple lora weights
+ self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
+
+ if self.quantization_bit is not None:
+ assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+ if self.quantization is not None:
+ assert self.quantization in ["int4", "int8"], "We only accept int4 or int8 quantization."
+
+ if self.use_auth_token == True and self.hf_auth_token is not None:
+ from huggingface_hub.hf_api import HfFolder # lazy load
+ HfFolder.save_token(self.hf_auth_token)
+
+
+@dataclass
+class FinetuningArguments:
+ r"""
+ Arguments pertaining to which techniques we are going to fine-tuning with.
+ """
+ stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+ default="sft",
+ metadata={"help": "Which stage will be performed in training."}
+ )
+ finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
+ default="lora",
+ metadata={"help": "Which fine-tuning method to use."}
+ )
+ num_layer_trainable: Optional[int] = field(
+ default=3,
+ metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
+ )
+ name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
+ default="mlp",
+ metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+ LLaMA choices: [\"mlp\", \"self_attn\"], \
+ BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
+ Qwen choices: [\"mlp\", \"attn\"], \
+ Phi-1.5 choices: [\"mlp\", \"mixer\"], \
+ LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
+ )
+ lora_rank: Optional[int] = field(
+ default=8,
+ metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
+ )
+ lora_alpha: Optional[float] = field(
+ default=32.0,
+ metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
+ )
+ lora_dropout: Optional[float] = field(
+ default=0.1,
+ metadata={"help": "Dropout rate for the LoRA fine-tuning."}
+ )
+ lora_target: Optional[str] = field(
+ default=None,
+ metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
+ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+ BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
+ Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
+ LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
+ )
+ additional_target: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
+ )
+ resume_lora_training: Optional[bool] = field(
+ default=True,
+ metadata={
+ "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+ )
+ ppo_score_norm: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Use score normalization in PPO training."}
+ )
+ ppo_logger: Optional[str] = field(
+ default=None,
+ metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
+ )
+ ppo_target: Optional[float] = field(
+ default=6.0,
+ metadata={"help": "Target KL value for adaptive KL control in PPO training."}
+ )
+ dpo_beta: Optional[float] = field(
+ default=0.1,
+ metadata={"help": "The beta parameter for the DPO loss."}
+ )
+ upcast_layernorm: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+ )
+ neft_alpha: Optional[float] = field(
+ default=0,
+ metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
+ )
+ num_workers: Optional[int] = field(
+ default=1,
+ metadata={"help": "Number of workers."}
+ )
+ storage_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "storage_path is used to store checkpoints."}
+ )
+ metrics_export_address: Optional[str] = field(
+ default=None,
+ metadata={"help": "address to export train metrics."}
+ )
+ uid: Optional[str] = field(
+ default=None,
+ metadata={"help": "finetune crd uid."}
+ )
+
+ def __post_init__(self):
+ if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
+ self.lora_target = [target.strip() for target in self.lora_target.split(",")]
+
+ if isinstance(self.additional_target, str):
+ self.additional_target = [target.strip() for target in self.additional_target.split(",")]
+
+ assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
+
+ if not self.storage_path:
+ raise ValueError("--storage_path must be specified")
+
+
+@dataclass
+class DataArguments:
+ train_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to train dataset"}
+ )
+
+ evaluation_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to evaluation dataset"}
+ )
+
+ columns: Optional[str] = field(
+ default=None,
+ metadata={"help": "columns map for dataset"}
+ )
+ block_size: Optional[int] = field(
+ default=1024,
+ metadata={"help": "length of input."}
+ )
+
+ def __post_init__(self):
+ if self.train_path is None:
+ raise ValueError("--train_path must be specified")
+
+
+def get_train_args() -> Tuple[Seq2SeqTrainingArguments, FinetuningArguments, ModelArguments, DataArguments]:
+ parser = HfArgumentParser((Seq2SeqTrainingArguments, FinetuningArguments, ModelArguments, DataArguments))
+
+ training_args, finetuning_args, model_args, data_args = parser.parse_args_into_dataclasses()
+
+ training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
+
+ # Log on each process the small summary:
+ logger.info(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+ + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ )
+
+ # Set seed before initializing model.
+ transformers.set_seed(training_args.seed)
+
+ return training_args, finetuning_args, model_args, data_args
diff --git a/pkg/tuning/prometheus/__init__.py b/pkg/tuning/prometheus/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pkg/tuning/prometheus/metrics.py b/pkg/tuning/prometheus/metrics.py
new file mode 100644
index 0000000..4e5f794
--- /dev/null
+++ b/pkg/tuning/prometheus/metrics.py
@@ -0,0 +1,135 @@
+from typing import List, Dict
+from datetime import datetime
+from urllib.parse import urljoin
+from .prometheus_pb2 import (
+ WriteRequest,
+ TimeSeries
+)
+import calendar
+import logging
+import requests
+import snappy
+
+
+def dt2ts(dt):
+ """Converts a datetime object to UTC timestamp
+ naive datetime will be considered UTC.
+ """
+ return calendar.timegm(dt.utctimetuple())
+
+
+def write(address: str, series: List[TimeSeries]):
+ write_request = WriteRequest()
+ write_request.timeseries.extend(series)
+
+ uncompressed = write_request.SerializeToString()
+ compressed = snappy.compress(uncompressed)
+
+ url = urljoin(address, "/api/v1/write")
+ headers = {
+ "Content-Encoding": "snappy",
+ "Content-Type": "application/x-protobuf",
+ "X-Prometheus-Remote-Write-Version": "0.1.0",
+ "User-Agent": "metrics-worker"
+ }
+ try:
+ response = requests.post(url, headers=headers, data=compressed)
+ print(response)
+ except Exception as e:
+ print(e)
+
+
+def export_train_metrics(address: str, metrics: Dict):
+ series = TimeSeries()
+ label = series.labels.add()
+ label.name = "__name__"
+ label.value = "train_metrics"
+
+ label = series.labels.add()
+ label.name = "uid"
+ label.value = str(metrics["uid"])
+
+ label = series.labels.add()
+ label.name = "total_steps"
+ label.value = str(metrics.get("total_steps", ""))
+
+ label = series.labels.add()
+ label.name = "current_steps"
+ label.value = str(metrics.get("current_steps", ""))
+
+ label = series.labels.add()
+ label.name = "loss"
+ label.value = str(metrics.get("loss", ""))
+
+ label = series.labels.add()
+ label.name = "learning_rate"
+ label.value = str(metrics.get("learning_rate", ""))
+
+ label = series.labels.add()
+ label.name = "epoch"
+ label.value = str(metrics.get("epoch", ""))
+
+ sample = series.samples.add()
+ sample.value = 1
+ sample.timestamp = dt2ts(datetime.utcnow()) * 1000
+
+ write(address, [series])
+
+
+def export_eval_metrics(address: str, metrics: Dict):
+ series = TimeSeries()
+ label = series.labels.add()
+ label.name = "__name__"
+ label.value = "eval_metrics"
+
+ label = series.labels.add()
+ label.name = "uid"
+ label.value = str(metrics["uid"])
+
+ label = series.labels.add()
+ label.name = "total_steps"
+ label.value = str(metrics.get("total_steps", ""))
+
+ label = series.labels.add()
+ label.name = "current_steps"
+ label.value = str(metrics.get("current_steps", ""))
+
+ label = series.labels.add()
+ label.name = "eval_loss"
+ label.value = str(metrics.get("eval_loss", ""))
+
+ label = series.labels.add()
+ label.name = "eval_perplexity"
+ label.value = str(metrics.get("eval_perplexity", ""))
+
+ label = series.labels.add()
+ label.name = "epoch"
+ label.value = str(metrics.get("epoch", ""))
+
+ sample = series.samples.add()
+ sample.value = 1
+ sample.timestamp = dt2ts(datetime.utcnow()) * 1000
+
+ write(address, [series])
+
+
+if __name__ == '__main__':
+ train_metrics = {
+ "uid": "1",
+ "current_steps": 10,
+ "total_steps": 84,
+ "loss": 3.088,
+ "learning_rate": 4.404761904761905e-05,
+ "epoch": 0.71
+ }
+ export_train_metrics("http://10.33.1.10:30722", train_metrics)
+
+ eval_metrics = {
+ "uid": "1",
+ "current_steps": 10,
+ "total_steps": 84,
+ "eval_loss": 3.088,
+ "eval_perplexity": 4.404761904761905e-05,
+ "epoch": 0.71
+ }
+ export_eval_metrics("http://10.33.1.10:30722", eval_metrics)
diff --git a/pkg/tuning/prometheus/prometheus.proto b/pkg/tuning/prometheus/prometheus.proto
new file mode 100644
index 0000000..38eed5b
--- /dev/null
+++ b/pkg/tuning/prometheus/prometheus.proto
@@ -0,0 +1,81 @@
+// Copyright 2016 Prometheus Team
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+package prometheus;
+
+option go_package = "prompb";
+
+message WriteRequest {
+ repeated prometheus.TimeSeries timeseries = 1;
+}
+
+message ReadRequest {
+ repeated Query queries = 1;
+}
+
+message ReadResponse {
+ // In same order as the request's queries.
+ repeated QueryResult results = 1;
+}
+
+message Query {
+ int64 start_timestamp_ms = 1;
+ int64 end_timestamp_ms = 2;
+ repeated prometheus.LabelMatcher matchers = 3;
+ prometheus.ReadHints hints = 4;
+}
+
+message QueryResult {
+ // Samples within a time series must be ordered by time.
+ repeated prometheus.TimeSeries timeseries = 1;
+}
+
+message Sample {
+ double value = 1;
+ int64 timestamp = 2;
+}
+
+message TimeSeries {
+ repeated Label labels = 1;
+ repeated Sample samples = 2;
+}
+
+message Label {
+ string name = 1;
+ string value = 2;
+}
+
+message Labels {
+ repeated Label labels = 1;
+}
+
+// Matcher specifies a rule, which can match a set of labels or not.
+message LabelMatcher {
+ enum Type {
+ EQ = 0;
+ NEQ = 1;
+ RE = 2;
+ NRE = 3;
+ }
+ Type type = 1;
+ string name = 2;
+ string value = 3;
+}
+
+message ReadHints {
+ int64 step_ms = 1; // Query step size in milliseconds.
+ string func = 2; // String representation of surrounding function or aggregation.
+ int64 start_ms = 3; // Start time in milliseconds.
+ int64 end_ms = 4; // End time in milliseconds.
+}
\ No newline at end of file
diff --git a/pkg/tuning/prometheus/prometheus_pb2.py b/pkg/tuning/prometheus/prometheus_pb2.py
new file mode 100644
index 0000000..d51c3eb
--- /dev/null
+++ b/pkg/tuning/prometheus/prometheus_pb2.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: prometheus.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10prometheus.proto\x12\nprometheus\":\n\x0cWriteRequest\x12*\n\ntimeseries\x18\x01 \x03(\x0b\x32\x16.prometheus.TimeSeries\"1\n\x0bReadRequest\x12\"\n\x07queries\x18\x01 \x03(\x0b\x32\x11.prometheus.Query\"8\n\x0cReadResponse\x12(\n\x07results\x18\x01 \x03(\x0b\x32\x17.prometheus.QueryResult\"\x8f\x01\n\x05Query\x12\x1a\n\x12start_timestamp_ms\x18\x01 \x01(\x03\x12\x18\n\x10\x65nd_timestamp_ms\x18\x02 \x01(\x03\x12*\n\x08matchers\x18\x03 \x03(\x0b\x32\x18.prometheus.LabelMatcher\x12$\n\x05hints\x18\x04 \x01(\x0b\x32\x15.prometheus.ReadHints\"9\n\x0bQueryResult\x12*\n\ntimeseries\x18\x01 \x03(\x0b\x32\x16.prometheus.TimeSeries\"*\n\x06Sample\x12\r\n\x05value\x18\x01 \x01(\x01\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"T\n\nTimeSeries\x12!\n\x06labels\x18\x01 \x03(\x0b\x32\x11.prometheus.Label\x12#\n\x07samples\x18\x02 \x03(\x0b\x32\x12.prometheus.Sample\"$\n\x05Label\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"+\n\x06Labels\x12!\n\x06labels\x18\x01 \x03(\x0b\x32\x11.prometheus.Label\"\x82\x01\n\x0cLabelMatcher\x12+\n\x04type\x18\x01 \x01(\x0e\x32\x1d.prometheus.LabelMatcher.Type\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\"(\n\x04Type\x12\x06\n\x02\x45Q\x10\x00\x12\x07\n\x03NEQ\x10\x01\x12\x06\n\x02RE\x10\x02\x12\x07\n\x03NRE\x10\x03\"L\n\tReadHints\x12\x0f\n\x07step_ms\x18\x01 \x01(\x03\x12\x0c\n\x04\x66unc\x18\x02 \x01(\t\x12\x10\n\x08start_ms\x18\x03 \x01(\x03\x12\x0e\n\x06\x65nd_ms\x18\x04 \x01(\x03\x42\x08Z\x06prompbb\x06proto3')
+
+
+
+_WRITEREQUEST = DESCRIPTOR.message_types_by_name['WriteRequest']
+_READREQUEST = DESCRIPTOR.message_types_by_name['ReadRequest']
+_READRESPONSE = DESCRIPTOR.message_types_by_name['ReadResponse']
+_QUERY = DESCRIPTOR.message_types_by_name['Query']
+_QUERYRESULT = DESCRIPTOR.message_types_by_name['QueryResult']
+_SAMPLE = DESCRIPTOR.message_types_by_name['Sample']
+_TIMESERIES = DESCRIPTOR.message_types_by_name['TimeSeries']
+_LABEL = DESCRIPTOR.message_types_by_name['Label']
+_LABELS = DESCRIPTOR.message_types_by_name['Labels']
+_LABELMATCHER = DESCRIPTOR.message_types_by_name['LabelMatcher']
+_READHINTS = DESCRIPTOR.message_types_by_name['ReadHints']
+_LABELMATCHER_TYPE = _LABELMATCHER.enum_types_by_name['Type']
+WriteRequest = _reflection.GeneratedProtocolMessageType('WriteRequest', (_message.Message,), {
+ 'DESCRIPTOR' : _WRITEREQUEST,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.WriteRequest)
+ })
+_sym_db.RegisterMessage(WriteRequest)
+
+ReadRequest = _reflection.GeneratedProtocolMessageType('ReadRequest', (_message.Message,), {
+ 'DESCRIPTOR' : _READREQUEST,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.ReadRequest)
+ })
+_sym_db.RegisterMessage(ReadRequest)
+
+ReadResponse = _reflection.GeneratedProtocolMessageType('ReadResponse', (_message.Message,), {
+ 'DESCRIPTOR' : _READRESPONSE,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.ReadResponse)
+ })
+_sym_db.RegisterMessage(ReadResponse)
+
+Query = _reflection.GeneratedProtocolMessageType('Query', (_message.Message,), {
+ 'DESCRIPTOR' : _QUERY,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.Query)
+ })
+_sym_db.RegisterMessage(Query)
+
+QueryResult = _reflection.GeneratedProtocolMessageType('QueryResult', (_message.Message,), {
+ 'DESCRIPTOR' : _QUERYRESULT,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.QueryResult)
+ })
+_sym_db.RegisterMessage(QueryResult)
+
+Sample = _reflection.GeneratedProtocolMessageType('Sample', (_message.Message,), {
+ 'DESCRIPTOR' : _SAMPLE,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.Sample)
+ })
+_sym_db.RegisterMessage(Sample)
+
+TimeSeries = _reflection.GeneratedProtocolMessageType('TimeSeries', (_message.Message,), {
+ 'DESCRIPTOR' : _TIMESERIES,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.TimeSeries)
+ })
+_sym_db.RegisterMessage(TimeSeries)
+
+Label = _reflection.GeneratedProtocolMessageType('Label', (_message.Message,), {
+ 'DESCRIPTOR' : _LABEL,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.Label)
+ })
+_sym_db.RegisterMessage(Label)
+
+Labels = _reflection.GeneratedProtocolMessageType('Labels', (_message.Message,), {
+ 'DESCRIPTOR' : _LABELS,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.Labels)
+ })
+_sym_db.RegisterMessage(Labels)
+
+LabelMatcher = _reflection.GeneratedProtocolMessageType('LabelMatcher', (_message.Message,), {
+ 'DESCRIPTOR' : _LABELMATCHER,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.LabelMatcher)
+ })
+_sym_db.RegisterMessage(LabelMatcher)
+
+ReadHints = _reflection.GeneratedProtocolMessageType('ReadHints', (_message.Message,), {
+ 'DESCRIPTOR' : _READHINTS,
+ '__module__' : 'prometheus_pb2'
+ # @@protoc_insertion_point(class_scope:prometheus.ReadHints)
+ })
+_sym_db.RegisterMessage(ReadHints)
+
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = b'Z\006prompb'
+ _WRITEREQUEST._serialized_start=32
+ _WRITEREQUEST._serialized_end=90
+ _READREQUEST._serialized_start=92
+ _READREQUEST._serialized_end=141
+ _READRESPONSE._serialized_start=143
+ _READRESPONSE._serialized_end=199
+ _QUERY._serialized_start=202
+ _QUERY._serialized_end=345
+ _QUERYRESULT._serialized_start=347
+ _QUERYRESULT._serialized_end=404
+ _SAMPLE._serialized_start=406
+ _SAMPLE._serialized_end=448
+ _TIMESERIES._serialized_start=450
+ _TIMESERIES._serialized_end=534
+ _LABEL._serialized_start=536
+ _LABEL._serialized_end=572
+ _LABELS._serialized_start=574
+ _LABELS._serialized_end=617
+ _LABELMATCHER._serialized_start=620
+ _LABELMATCHER._serialized_end=750
+ _LABELMATCHER_TYPE._serialized_start=710
+ _LABELMATCHER_TYPE._serialized_end=750
+ _READHINTS._serialized_start=752
+ _READHINTS._serialized_end=828
+# @@protoc_insertion_point(module_scope)
diff --git a/pkg/tuning/requirements.txt b/pkg/tuning/requirements.txt
new file mode 100644
index 0000000..4d8fc9d
--- /dev/null
+++ b/pkg/tuning/requirements.txt
@@ -0,0 +1,10 @@
+bitsandbytes==0.41.3.post2
+datasets==2.14.5
+deepspeed==0.12.2
+evaluate==0.4.1
+peft==0.5.0
+protobuf==3.19.6
+python-snappy==0.6.1
+torch==2.1.0
+transformers==4.34.0
+
diff --git a/pkg/tuning/template.py b/pkg/tuning/template.py
new file mode 100644
index 0000000..f06a772
--- /dev/null
+++ b/pkg/tuning/template.py
@@ -0,0 +1,620 @@
+import logging
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
@dataclass
class Template:
    """Chat prompt template that turns (system, history, query, response)
    conversations into token-id sequences for supervised fine-tuning.

    Context pieces (``prefix``/``prompt``/``sep``) are lists whose elements are
    either plain strings — which may contain the ``{{system}}``, ``{{query}}``
    and ``{{idx}}`` placeholders — or ``{"token": name}`` dicts that insert the
    id of one special token via the tokenizer's vocabulary.
    """

    # Pieces emitted once before the first turn (may reference {{system}}).
    prefix: List[Union[str, Dict[str, str]]]
    # Pieces emitted for every user turn (may reference {{query}} / {{idx}}).
    prompt: List[Union[str, Dict[str, str]]]
    # Default system message used when the caller supplies none.
    system: str
    # Separator pieces inserted between turns.
    sep: List[Union[str, Dict[str, str]]]
    # Extra stop strings registered as additional special tokens.
    stop_words: List[str]
    # If False, any supplied history is discarded (single-turn template).
    use_history: bool
    # If True, no eos id is appended after each response (used by models that
    # emit their own end-of-turn token; see _get_special_ids).
    efficient_eos: bool

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.

        All turns except the last are flattened into the prompt ids; the final
        turn contributes the last query (prompt side) and the response.
        """
        system, history = self._format(query, resp, history, system)
        encoded_pairs = self._encode(tokenizer, system, history)
        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids = prompt_ids + query_ids + resp_ids
        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
        return prompt_ids, answer_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        One (prompt_ids, response_ids) pair per conversation turn.
        """
        system, history = self._format(query, resp, history, system)
        encoded_pairs = self._encode(tokenizer, system, history)
        return encoded_pairs

    def _format(
        self,
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None
    ) -> Tuple[str, List[Tuple[str, str]]]:
        r"""
        Aligns inputs to the standard format: resolves the system message and
        appends the current (query, resp) turn to the (possibly dropped) history.
        """
        system = system or self.system  # use system if provided
        history = history if (history and self.use_history) else []
        history = history + [(query, resp)]
        return system, history

    def _get_special_ids(
        self,
        tokenizer: "PreTrainedTokenizer"
    ) -> Tuple[List[int], List[int]]:
        # bos is emitted only when the tokenizer both defines one and is
        # configured to add it (add_bos_token defaults to True when absent).
        if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
            bos_ids = [tokenizer.bos_token_id]
        else: # baichuan, qwen and gpt2 models have no bos token
            bos_ids = []

        if tokenizer.eos_token_id is None:
            raise ValueError("EOS token is required.")

        if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
            eos_ids = []
        else:
            eos_ids = [tokenizer.eos_token_id]

        return bos_ids, eos_ids

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        system: str,
        history: List[Tuple[str, str]]
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: bos + prefix + sep + query resp + eos
        Turn t: sep + bos + query resp + eos

        The ``{{idx}}`` placeholder receives the 0-based turn index as a string.
        """
        bos_ids, eos_ids = self._get_special_ids(tokenizer)
        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
        encoded_pairs = []
        for turn_idx, (query, resp) in enumerate(history):
            if turn_idx == 0:
                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
                if len(prefix_ids) != 0: # has prefix
                    prefix_ids = bos_ids + prefix_ids + sep_ids
                else:
                    prefix_ids = bos_ids
            else:
                prefix_ids = sep_ids + bos_ids

            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
        return encoded_pairs

    def _convert_inputs_to_ids(
        self,
        tokenizer: "PreTrainedTokenizer",
        context: List[Union[str, Dict[str, str]]],
        system: Optional[str] = None,
        query: Optional[str] = None,
        idx: Optional[str] = None
    ) -> List[int]:
        r"""
        Converts context to token ids.

        Strings have their first placeholder occurrence substituted and are
        tokenized without special tokens; ``{"token": name}`` dicts are mapped
        to a single id through the tokenizer's vocabulary.
        """

        kwargs = dict(add_special_tokens=False)

        token_ids = []
        for elem in context:
            if isinstance(elem, str):
                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
                if len(elem) != 0:
                    token_ids = token_ids + tokenizer.encode(elem, **kwargs)
            elif isinstance(elem, dict):
                token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            else:
                raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))

        return token_ids
+
+
@dataclass
class Llama2Template(Template):
    """Template variant for the Llama-2 chat format: no separator pieces are
    inserted between turns, and the system block is folded directly into the
    first turn's query text instead of being tokenized as a standalone prefix."""

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        system: str,
        history: List[Tuple[str, str]]
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: bos + prefix + query resp + eos
        Turn t: bos + query resp + eos
        """
        bos_ids, eos_ids = self._get_special_ids(tokenizer)
        pairs = []
        for idx, (user_text, assistant_text) in enumerate(history):
            # The llama2 format carries the system message inside the first
            # user turn (llama2 template has no sep_ids).
            if idx == 0:
                user_text = self.prefix[0].replace("{{system}}", system) + user_text
            prompt_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=user_text)
            answer_ids = self._convert_inputs_to_ids(tokenizer, context=[assistant_text])
            pairs.append((bos_ids + prompt_ids, answer_ids + eos_ids))
        return pairs
+
+
+templates: Dict[str, Template] = {}
+
+
def register_template(
    name: str,
    prefix: List[Union[str, Dict[str, str]]],
    prompt: List[Union[str, Dict[str, str]]],
    system: str,
    sep: List[Union[str, Dict[str, str]]],
    stop_words: Optional[List[str]] = None,
    use_history: Optional[bool] = True,
    efficient_eos: Optional[bool] = False
) -> None:
    """Register a prompt template in the module-level ``templates`` registry.

    Names containing "llama2" get the Llama2Template subclass (system message
    folded into the first turn); everything else uses the base Template.

    FIX: the previous default ``stop_words=[]`` was a shared mutable default
    argument; use None and substitute a fresh list per call instead.
    """
    template_class = Llama2Template if "llama2" in name else Template
    templates[name] = template_class(
        prefix=prefix,
        prompt=prompt,
        system=system,
        sep=sep,
        stop_words=stop_words if stop_words is not None else [],
        use_history=use_history,
        efficient_eos=efficient_eos
    )
+
+
def get_template_and_fix_tokenizer(
    name: str,
    tokenizer: "PreTrainedTokenizer"
) -> Template:
    """Look up the template registered under ``name`` and patch the tokenizer.

    Side effects on ``tokenizer``: ensures an eos token ("<|endoftext|>" when
    missing), falls back to eos as the pad token, and registers the template's
    stop words as additional special tokens.

    Returns None when ``name`` is None; raises ValueError for unknown names.
    FIX: the lookup previously used ``assert``, which is silently stripped
    under ``python -O``; raise an explicit error instead.
    """
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token = "<|endoftext|>"
        logger.info("Add eos token: {}".format(tokenizer.eos_token))

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if name is None:
        return None

    template = templates.get(name, None)
    if template is None:
        raise ValueError("Template {} does not exist.".format(name))
    tokenizer.add_special_tokens(
        dict(additional_special_tokens=template.stop_words),
        replace_additional_special_tokens=False
    )
    return template
+
+
r"""
Supports language model inference without histories.
"""
register_template(
    name="vanilla",
    prefix=[],
    prompt=[
        "{{query}}"
    ],
    system="",
    sep=[],
    use_history=False  # single-turn: prior history is always dropped
)


r"""
Default template.
"""
register_template(
    name="default",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}\nAssistant: "
    ],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    sep=[
        "\n"
    ]
)
+
+
r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
          https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
          https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
# FIX: the system-block markers had been stripped down to "<>"; the official
# Llama-2 chat format wraps the system message in <<SYS>> ... <</SYS>>.
register_template(
    name="llama2",
    prefix=[
        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
    ],
    prompt=[
        "[INST] {{query}} [/INST] "
    ],
    system=(
        "You are a helpful, respectful and honest assistant. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
    sep=[]
)
+
+
r"""
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
          https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
"""
# FIX: restore the <<SYS>> ... <</SYS>> system markers (stripped to "<>");
# chinese-alpaca-2 follows the standard Llama-2 chat format.
register_template(
    name="llama2_zh",
    prefix=[
        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
    ],
    prompt=[
        "[INST] {{query}} [/INST] "
    ],
    system="You are a helpful assistant. 你是一个乐于助人的助手。",
    sep=[]
)
+
+
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
"""
register_template(
    name="alpaca",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "### Instruction:\n{{query}}\n\n### Response:\n"
    ],
    system=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    ),
    sep=[
        "\n\n"
    ]
)


r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
          https://huggingface.co/lmsys/vicuna-13b-v1.5
"""
register_template(
    name="vicuna",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "USER: {{query}} ASSISTANT:"
    ],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    sep=[]  # vicuna v1.5 uses no explicit separator between turns
)


r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
    name="belle",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}\n\nBelle: "
    ],
    system="",
    sep=[
        "\n\n"
    ]
)
+
+
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
          https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
          https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
"""
# FIX: the role tokens had been stripped to empty strings; Ziya prompts use
# the <human> and <bot> special tokens around each turn.
register_template(
    name="ziya",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        {"token": "<human>"},
        ":{{query}}\n",
        {"token": "<bot>"},
        ":"
    ],
    system="",
    sep=[
        "\n"
    ]
)
+
+
r"""
Supports: https://huggingface.co/BAAI/AquilaChat-7B
          https://huggingface.co/BAAI/AquilaChat2-7B
          https://huggingface.co/BAAI/AquilaChat2-34B
"""
# FIX: the stop word had been stripped to an empty string; AquilaChat stops
# on the literal "</s>" string (efficient_eos: the model emits it itself).
register_template(
    name="aquila",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}###Assistant:"
    ],
    system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    sep=[
        "###"
    ],
    stop_words=[
        "</s>"
    ],
    efficient_eos=True
)
+
+
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
          https://huggingface.co/internlm/internlm-chat-20b
"""
# FIX: the InternLM end-of-human (<eoh>) / end-of-assistant (<eoa>) special
# tokens and the stop words had been stripped to empty strings.
register_template(
    name="intern",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "<|User|>:{{query}}",
        {"token": "<eoh>"},
        "\n<|Bot|>:"
    ],
    system="",
    sep=[
        {"token": "<eoa>"},
        "\n"
    ],
    stop_words=[
        "</s>",
        "<eoa>"
    ],
    efficient_eos=True
)
+
+
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
# FIX: the user/assistant role tokens had been stripped to empty strings;
# Baichuan-13B-Chat uses the reserved special tokens below.
register_template(
    name="baichuan",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        {"token": "<reserved_102>"},  # user token
        "{{query}}",
        {"token": "<reserved_103>"}   # assistant token
    ],
    system="",
    sep=[],
    efficient_eos=True
)
+
+
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
"""
# FIX: the user/assistant role tokens had been stripped to empty strings;
# Baichuan2 chat models use the reserved special tokens below.
register_template(
    name="baichuan2",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        {"token": "<reserved_106>"},  # user token
        "{{query}}",
        {"token": "<reserved_107>"}   # assistant token
    ],
    system="",
    sep=[],
    efficient_eos=True
)
+
+
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
          https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
    name="starchat",
    prefix=[
        {"token": "<|system|>"},
        "\n{{system}}",
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n{{query}}",
        {"token": "<|end|>"},
        "\n",
        {"token": "<|assistant|>"}
    ],
    system="",
    sep=[
        {"token": "<|end|>"},
        "\n"
    ],
    stop_words=[
        "<|end|>"
    ],
    efficient_eos=True  # the model emits <|end|> itself; no eos is appended
)


r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
          https://huggingface.co/Qwen/Qwen-14B-Chat
"""
register_template(
    name="chatml",
    prefix=[
        {"token": "<|im_start|>"},
        "system\n{{system}}"
    ],
    prompt=[
        {"token": "<|im_start|>"},
        "user\n{{query}}",
        {"token": "<|im_end|>"},
        "\n",
        {"token": "<|im_start|>"},
        "assistant\n"
    ],
    system="You are a helpful assistant.",
    sep=[
        {"token": "<|im_end|>"},
        "\n"
    ],
    stop_words=[
        "<|im_end|>"
    ],
    efficient_eos=True
)


r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
    name="chatglm2",
    prefix=[
        {"token": "[gMASK]"},
        {"token": "sop"},
        "{{system}}"
    ],
    prompt=[
        # NOTE(review): {{idx}} is filled with the 0-based turn index (see
        # Template._encode), while ChatGLM2's official prompt numbers rounds
        # from 1 ("[Round 1]") — verify against the model card.
        "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
    ],
    system="",
    sep=[
        "\n\n"
    ],
    efficient_eos=True
)


r"""
Supports: https://huggingface.co/THUDM/chatglm3-6b
"""
register_template(
    name="chatglm3",
    prefix=[
        {"token": "[gMASK]"},
        {"token": "sop"},
        "{{system}}"
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n",
        "{{query}}",
        {"token": "<|assistant|>"}
    ],
    system="",
    sep=[],
    stop_words=[
        "<|user|>",
        "<|observation|>"
    ],
    efficient_eos=True
)


r"""
Supports: https://huggingface.co/openchat/openchat_v3.2_super
"""
register_template(
    name="openchat",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "GPT4 User: {{query}}",
        {"token": "<|end_of_turn|>"},
        "GPT4 Assistant:"
    ],
    system="",
    sep=[
        {"token": "<|end_of_turn|>"}
    ],
    efficient_eos=True
)


r"""
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
          https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_template(
    name="xverse",
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "Human: {{query}}\n\nAssistant: "
    ],
    system="",
    sep=[]
)
diff --git a/pkg/tuning/train.py b/pkg/tuning/train.py
new file mode 100644
index 0000000..cc4c28d
--- /dev/null
+++ b/pkg/tuning/train.py
@@ -0,0 +1,393 @@
+import functools
+import json
+import logging
+import math
+import os
+import sys
+from types import MethodType
+
+import numpy as np
+from dataclasses import dataclass
+from typing import Sequence, Union, Tuple, Dict, List, Any, Generator, Optional
+
+import ray
+import ray.data
+import torch
+import evaluate
+from pandas import DataFrame
+from peft import PeftModel, LoraConfig, TaskType, get_peft_model
+from ray.air import ScalingConfig, RunConfig
+from ray.train import Checkpoint
+from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback
+from ray.train.torch import TorchTrainer, get_device, TorchConfig
+from torch import nn
+from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, \
+ AutoConfig, AutoModelForCausalLM, Seq2SeqTrainingArguments, BitsAndBytesConfig
+from ray.train.huggingface import TransformersTrainer
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.tokenization_utils import PreTrainedTokenizer
+from datasets import load_dataset, Dataset
+
+from callback import LogCallback
+from parser import get_train_args
+from template import get_template_and_fix_tokenizer
+from trainer import SFTTrainer
+
logging.basicConfig(level=logging.INFO)
# Earlier hand-rolled logger/handler wiring, kept for reference:
# logger = logging.getLogger(__name__)
# formatter = logging.Formatter(
#     fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
#     datefmt="%m/%d/%Y %H:%M:%S"
# )
# handler = logging.StreamHandler(sys.stdout)
# handler.setFormatter(formatter)
#
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)
# logger.addHandler(handler)

# CPU cores reserved per Ray training worker (see ScalingConfig in main()).
cpus_per_worker = 8
# Label id skipped by the loss; matches the -100 ignore index convention.
IGNORE_INDEX = -100
# Max total (prompt + response) token length; overridden by
# data_args.block_size in main() when it is positive.
cutoff_len = 1024
+
+
def rename_columns(batch: DataFrame, columns):
    """Return a copy of ``batch`` whose columns are renamed per the
    ``columns`` mapping (old name -> new name); the input is not mutated."""
    renamed = batch.rename(columns=columns)
    return renamed
+
+
def preprocess_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    tokenizer: "PreTrainedTokenizer",
    training_args: "Seq2SeqTrainingArguments",
    template_name: str = "llama2"
) -> Union["Dataset", "IterableDataset"]:
    """Tokenize a supervised fine-tuning dataset with a chat template.

    Args:
        dataset: Ray dataset with "instruction"/"response" columns and optional
            "query"/"history"/"system" columns.
        tokenizer: tokenizer of the target model (patched by the template).
        training_args: only ``should_log`` is read here.
        template_name: registered template to apply. Previously hard-coded to
            "llama2"; kept as the default for backward compatibility.

    Returns:
        The dataset mapped to "input_ids"/"attention_mask"/"labels" batches.
    """
    template = get_template_and_fix_tokenizer(template_name, tokenizer)

    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
        # Yield (query, response, history, system) per row; "query" (if any)
        # is appended to the instruction text.
        for i in range(len(examples["instruction"])):
            query, response = examples["instruction"][i], examples["response"][i]
            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
            history = examples["history"][i] if "history" in examples else None
            system = examples["system"][i] if "system" in examples else None
            yield query, response, history, system

    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
        # build inputs with format ` X Y ` and labels with format ` ... Y `
        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

        for query, response, history, system in construct_example(examples):
            # Skip rows with empty or non-string query/response.
            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
                continue

            input_ids, labels = [], []
            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                tokenizer, query, response, history, system
            )):
                # Truncate source/target proportionally to their share of the turn.
                total_len = len(source_ids) + len(target_ids)
                max_source_len = int(cutoff_len * (len(source_ids) / total_len))
                max_target_len = int(cutoff_len * (len(target_ids) / total_len))

                if len(source_ids) > max_source_len:
                    source_ids = source_ids[:max_source_len]
                if len(target_ids) > max_target_len:
                    target_ids = target_ids[:max_target_len]

                if turn_idx != 0 and template.efficient_eos:
                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                else:
                    source_mask = [IGNORE_INDEX] * len(source_ids)

                input_ids += source_ids + target_ids
                labels += source_mask + target_ids

            if template.efficient_eos:
                input_ids += [tokenizer.eos_token_id]
                labels += [tokenizer.eos_token_id]

            if len(input_ids) > cutoff_len:
                input_ids = input_ids[:cutoff_len]
                labels = labels[:cutoff_len]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)

        return model_inputs

    def print_supervised_dataset_example(example):
        # Debug helper: dump one tokenized example in both id and text form.
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
        ))

    preprocess_func = preprocess_supervised_dataset
    print_function = print_supervised_dataset_example
    new_dataset = dataset.map_batches(preprocess_func)
    if training_args.should_log:
        try:
            print_function(new_dataset.take(1)[0])
        except (IndexError, StopIteration):
            # FIX: take(1) on an empty dataset returns an empty list, so the
            # failing subscript raises IndexError — the old bare
            # `except StopIteration` could never catch it.
            raise RuntimeError("Empty dataset!")
    return new_dataset
+
+
def trainer_init_per_worker(config):
    """Per-worker training loop executed by Ray TorchTrainer.

    Pins the worker to its GPU, materializes its dataset shards, builds the
    quantized base model plus LoRA adapters, and runs the HF trainer,
    reporting metrics/checkpoint back to Ray from rank 0.

    ``config`` carries training_args/finetuning_args/model_args/data_args and
    the tokenizer (see main()).
    """
    print("--- train_task, pid: ", os.getpid())

    # Pin this worker to the single GPU Ray assigned it.
    # assumes CUDA_VISIBLE_DEVICES and LOCAL_RANK are set by Ray — TODO confirm
    cuda_visible_device = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    print("CUDA_VISIBLE_DEVICES", os.environ["CUDA_VISIBLE_DEVICES"])
    local_rank = int(os.environ["LOCAL_RANK"])
    print("local_rank:", local_rank)
    device_id = cuda_visible_device[local_rank]
    print("device_id:", device_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{device_id}"
    torch.cuda.set_device(int(device_id))

    # device setting
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:", torch.cuda.current_device())
    device_ids = torch._utils._get_all_device_indices()
    print("device_ids:", device_ids)
    if len(device_ids) <= 0:
        print("invalid device_ids, exit")
        return

    training_args = config.get("training_args", None)
    finetuning_args = config.get("finetuning_args", None)
    model_args = config.get("model_args", None)
    data_args = config.get("data_args", None)
    tokenizer = config.get("tokenizer", None)

    # read dataset
    train_ds = ray.train.get_dataset_shard("train")
    print(f"train_ds: {train_ds}")

    def train_gen():
        # Stream this worker's shard row by row into a HF Dataset.
        for row in train_ds.iter_rows():
            yield row

    train_dataset = Dataset.from_generator(train_gen)
    print(train_dataset)
    print('------')
    print(train_dataset[0])

    eval_ds = ray.train.get_dataset_shard("evaluation")
    print(f"eval_ds: {eval_ds}")

    def eval_gen():
        for row in eval_ds.iter_rows():
            yield row

    # Evaluation is enabled only when an "evaluation" shard was provided.
    eval_dataset = None
    evaluation_strategy = "no"
    if eval_ds:
        eval_dataset = Dataset.from_generator(eval_gen)
        print(eval_dataset)
        evaluation_strategy = "steps"

    # NOTE(review): counting batches materializes the whole shard once, just
    # to log steps_per_epoch.
    train_ds_len = len(list(train_ds.iter_batches(batch_size=1)))
    steps_per_epoch = math.ceil(train_ds_len / training_args.per_device_train_batch_size)
    print(f"train_ds_len: {train_ds_len}, steps_per_epoch: {steps_per_epoch}")

    # Rebuild the training arguments locally, forcing settings required in
    # this distributed context (no intermediate saves, no HF hub, etc.).
    new_training_args = Seq2SeqTrainingArguments(
        training_args.output_dir,
        logging_steps=10,
        save_strategy="no",
        evaluation_strategy=evaluation_strategy,
        num_train_epochs=training_args.num_train_epochs,
        learning_rate=training_args.learning_rate,
        weight_decay=training_args.weight_decay,
        warmup_steps=training_args.warmup_steps,
        per_device_train_batch_size=training_args.per_device_train_batch_size,
        per_device_eval_batch_size=training_args.per_device_eval_batch_size,
        optim=training_args.optim,
        lr_scheduler_type=training_args.lr_scheduler_type,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        push_to_hub=False,
        report_to="none",
        disable_tqdm=False, # declutter the output a little
        fp16=training_args.fp16,
        gradient_checkpointing=True,
        deepspeed=training_args.deepspeed,
        log_level="info",
    )

    print(f"new_training_args: {new_training_args}".replace("\n", " "))

    # NOTE(review): from here on `config` (the worker dict argument) is
    # shadowed by the model's AutoConfig; all dict reads happen above.
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    compute_dtype = getattr(config, "torch_dtype", None)

    # Optional bitsandbytes quantization of the base model.
    if model_args.quantization == "int4":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
    elif model_args.quantization == "int8":
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        quantization_config = None

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=compute_dtype,
        quantization_config=quantization_config,
        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
    )

    # Make embedding outputs require grad so gradient checkpointing works
    # with frozen base weights.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    model.gradient_checkpointing_enable()
    model.config.use_cache = False # turn off when gradient checkpointing is enabled
    print("Gradient checkpointing enabled.")

    output_layer_name = "lm_head"

    # Cast the LM head output to fp32 for numerically stable loss computation.
    if hasattr(model, output_layer_name):
        output_layer = getattr(model, output_layer_name)
        if isinstance(output_layer, torch.nn.Linear):
            def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor:
                return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32)

            output_layer.forward = MethodType(forward_in_fp32, output_layer)

    target_modules = finetuning_args.lora_target

    # Attach LoRA adapters; only these (plus additional_target) are trained.
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=finetuning_args.lora_rank,
        lora_alpha=finetuning_args.lora_alpha,
        lora_dropout=finetuning_args.lora_dropout,
        target_modules=target_modules,
        modules_to_save=finetuning_args.additional_target
    )
    model = get_peft_model(model, lora_config)
    if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
        model.base_model.peft_config = model.peft_config
    model.train()

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=4, # for shift short attention
        label_pad_token_id=IGNORE_INDEX
    )

    trainer = SFTTrainer(
        model=model,
        args=new_training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=[LogCallback(metrics_export_address=finetuning_args.metrics_export_address, uid=finetuning_args.uid)],
    )

    trainer = prepare_trainer(trainer)
    train_result = trainer.train()
    trainer.save_model(training_args.output_dir)

    # Only rank 0 publishes the checkpoint; other ranks report metrics only.
    checkpoint = None
    if ray.train.get_context().get_world_rank() == 0:
        checkpoint = Checkpoint.from_directory(training_args.output_dir)
    ray.train.report(metrics=train_result.metrics, checkpoint=checkpoint)
+
+
def main():
    """Entry point: parse args, build and tokenize Ray datasets, run LoRA
    fine-tuning with TorchTrainer, and persist the checkpoint path to
    /home/ray/checkpoint_path for the controller to pick up."""
    print("init")
    ray.init()

    training_args, finetuning_args, model_args, data_args = get_train_args()

    print(f"training_args: {training_args}".replace("\n", " "))
    print(finetuning_args)
    print(model_args)
    print(data_args)

    model_path = model_args.model_name_or_path
    use_gpu = True
    num_workers = finetuning_args.num_workers

    # A positive block_size overrides the module-level truncation length.
    if data_args.block_size > 0:
        global cutoff_len
        cutoff_len = data_args.block_size

    # read dataset
    print("preprocess_dataset")
    # Map dataset column names to the template's expected names; user-supplied
    # data_args.columns is a JSON object of {template_name: csv_name}.
    columns_map = {
        "instruction": "instruction",
        "output": "response"
    }
    if data_args.columns:
        print(data_args.columns)
        columns_map.update({v: k for k, v in json.loads(data_args.columns).items()})

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    train_dataset = ray.data.read_csv(data_args.train_path). \
        map_batches(rename_columns, fn_args=[columns_map], batch_format="pandas")
    print(train_dataset)
    train_dataset = preprocess_dataset(train_dataset, tokenizer, training_args)

    input_datasets = {"train": train_dataset}

    if data_args.evaluation_path:
        # FIX: previously read data_args.train_path here, so evaluation
        # silently reused the training split instead of the evaluation file.
        evaluation_dataset = ray.data.read_csv(data_args.evaluation_path). \
            map_batches(rename_columns, fn_args=[columns_map], batch_format="pandas")
        print(evaluation_dataset)
        evaluation_dataset = preprocess_dataset(evaluation_dataset, tokenizer, training_args)
        input_datasets["evaluation"] = evaluation_dataset

    # One GPU per worker; the (rank-0) trainer process itself takes no GPU.
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu,
                                   resources_per_worker={"GPU": 1, "CPU": cpus_per_worker},
                                   trainer_resources={"GPU": 0}
                                   )

    ray_trainer = TorchTrainer(
        train_loop_per_worker=trainer_init_per_worker,
        train_loop_config={
            "training_args": training_args,
            "finetuning_args": finetuning_args,
            "model_args": model_args,
            "data_args": data_args,
            "tokenizer": tokenizer,
        },
        scaling_config=scaling_config,
        datasets=input_datasets,
        run_config=RunConfig(
            storage_path=finetuning_args.storage_path,
            # checkpoint_config=ray.train.CheckpointConfig(
            #     num_to_keep=1,
            #     checkpoint_score_attribute="eval_loss",
            #     checkpoint_score_order="min",
            # ),
        )
    )
    result = ray_trainer.fit()
    checkpoint_path = result.checkpoint.path

    print(f"result path {checkpoint_path}")

    # Write the checkpoint location to a well-known marker file.
    file_path = "/home/ray/checkpoint_path"

    directory = os.path.dirname(file_path)
    # exist_ok avoids a race between the existence check and creation.
    os.makedirs(directory, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(checkpoint_path)
diff --git a/pkg/tuning/trainer.py b/pkg/tuning/trainer.py
new file mode 100644
index 0000000..6a168e0
--- /dev/null
+++ b/pkg/tuning/trainer.py
@@ -0,0 +1,507 @@
+import os
+import json
+import torch, time, math
+from copy import deepcopy
+from pathlib import Path
+
+import numpy as np
+import torch.nn as nn
+from torch.utils.data import Dataset
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Callable
+from transformers import Seq2SeqTrainer, Trainer
+
+if TYPE_CHECKING:
+ from transformers.data.data_collator import DataCollator
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+ from transformers.trainer_callback import TrainerCallback
+ from transformers.training_args import TrainingArguments
+
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import logging
+from transformers.trainer_utils import EvalPrediction, PredictionOutput, speed_metrics
+
+IGNORE_INDEX = -100
+logger = logging.get_logger(__name__)
+
+
+class GenEvalSeq2SeqTrainer(Seq2SeqTrainer):
+    """Seq2SeqTrainer variant that evaluates via generation.
+
+    Additions over the stock Seq2SeqTrainer:
+    - an optional ``gen_args`` mapping used as the default ``generate()``
+      kwargs when ``evaluate`` is called without explicit gen_kwargs;
+    - left padding forced (tokenizer + data collator) while evaluating;
+    - prompt tokens masked out of the generated sequence in
+      ``prediction_step``;
+    - ``save_predictions`` dumps decoded predictions to a JSONL file.
+    """
+
+    def __init__(
+        self,
+        model: Union["PreTrainedModel", nn.Module] = None,
+        args: "TrainingArguments" = None,
+        data_collator: Optional["DataCollator"] = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
+        compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
+        callbacks: Optional[List["TrainerCallback"]] = None,
+        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        gen_args: Optional[Any] = None,  # ADD: default generate() kwargs used by evaluate()
+    ):
+        # Stored before delegating construction to Seq2SeqTrainer.
+        self.gen_args = gen_args
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            model_init=model_init,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+        # Override self.model.generation_config if a GenerationConfig is specified in args.
+        # Priority: args.generation_config > model.generation_config > default GenerationConfig.
+        if self.args.generation_config is not None:
+            gen_config = self.load_generation_config(self.args.generation_config)
+            self.model.generation_config = gen_config
+        print(f'trainer init : tokenizer{tokenizer}')
+
+    def evaluate(
+        self,
+        eval_dataset: Optional[Dataset] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+        **gen_kwargs,
+    ) -> Dict[str, float]:
+        """Run evaluation with left padding forced for generation.
+
+        The tokenizer's padding side and the data collator are backed up,
+        swapped for left-padding variants for the duration of the parent
+        ``evaluate`` call, and restored afterwards.
+        """
+
+        # force left pad
+        bak_tp = self.tokenizer.padding_side
+        bak_dc = self.data_collator
+        print(f'**** evaluate origin pad style: {self.tokenizer.padding_side} ****', '\n')
+
+        # Prefer the trainer-level gen_args when the caller passed no kwargs.
+        if self.gen_args is not None and (gen_kwargs is None or len(gen_kwargs) == 0):
+            logger.info("*" * 5 + "Using Initial Trainer gen_args" + "*" * 5)
+            gen_kwargs = self.gen_args.copy()
+        else:
+            logger.info("*" * 5 + "Using Default Trainer gen_kwargs" + "*" * 5)
+            gen_kwargs = gen_kwargs.copy()
+        # Add your own logic here...
+        # e.g. extend gen_kwargs or adjust the parameters by other means.
+
+        # Call the parent class's evaluate method and return its result.
+        # NOTE(review): left_collat_fn is defined elsewhere in this module;
+        # presumably it builds a left-padding data collator — confirm.
+        self.data_collator = left_collat_fn(self.tokenizer)
+        # @ the eval_dataset's original data collator pads on the right
+        eval_metrics = super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix,
+                                        **gen_kwargs)
+
+        # Restore padding side and collator for subsequent training steps.
+        self.tokenizer.padding_side = bak_tp
+        self.data_collator = bak_dc
+        return eval_metrics
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        r"""
+        Removes the prompt part in the generated tokens.
+
+        Subclass and override to inject custom behavior.
+        """
+        labels = inputs["labels"].clone() if "labels" in inputs else None  # backup labels
+        # force left pad
+        if self.args.predict_with_generate:
+            self.tokenizer.padding_side = "left"
+            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            # Align label length with the prompt length so both tensors agree.
+            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+            if prompt_len > label_len:
+                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+            if label_len > prompt_len:
+                inputs["labels"] = inputs["labels"][:, :prompt_len]  # truncate the labels instead of padding the inputs
+
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        )
+        # Mask the prompt prefix so only newly generated tokens remain.
+        if generated_tokens is not None and self.args.predict_with_generate:
+            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+            generated_tokens = generated_tokens.contiguous()
+
+        return loss, generated_tokens, labels
+
+    def _pad_tensors_to_target_len(
+        self,
+        src_tensor: torch.Tensor,
+        tgt_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Pads the tensor to the same length as the target tensor.
+        """
+        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
+        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
+        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
+        return padded_tensor.contiguous()  # in contiguous memory
+
+    def save_predictions(
+        self,
+        predict_results: "PredictionOutput"
+    ) -> None:
+        r"""
+        Saves model predictions to `output_dir`.
+        A custom behavior that not contained in Seq2SeqTrainer.
+        Custom behavior.
+        """
+        # Only the main process writes the prediction file.
+        if not self.is_world_process_zero():
+            return
+
+        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+        logger.info(f"Saving prediction results to {output_prediction_file}")
+
+        # Replace IGNORE_INDEX (-100) with the pad token so decoding works.
+        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions,
+                         self.tokenizer.pad_token_id)
+        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids,
+                          self.tokenizer.pad_token_id)
+
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True,
+                                                     clean_up_tokenization_spaces=True)
+
+        # One JSON object per line: {"label": ..., "predict": ...}.
+        with open(output_prediction_file, "w", encoding="utf-8") as writer:
+            res: List[str] = []
+            for pred, label in zip(decoded_preds, decoded_labels):
+                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
+            writer.write("\n".join(res))
+
+
+class SFTTrainer(Trainer):
+    """Supervised fine-tuning trainer with generation-aware evaluation.
+
+    Re-creates Seq2SeqTrainer-style generation support on top of the base
+    Trainer: it honors ``args.generation_config`` /
+    ``args.generation_max_length`` / ``args.generation_num_beams``, derives
+    an ``eval_perplexity`` metric from ``eval_loss`` inside ``evaluate``,
+    and implements a ``generate()``-based ``prediction_step``.
+    """
+
+    def __init__(
+        self,
+        model: Union["PreTrainedModel", nn.Module] = None,
+        args: "TrainingArguments" = None,
+        data_collator: Optional["DataCollator"] = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
+        compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
+        callbacks: Optional[List["TrainerCallback"]] = None,
+        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+    ):
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            model_init=model_init,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        # Override self.model.generation_config if a GenerationConfig is specified in args.
+        # Priority: args.generation_config > model.generation_config > default GenerationConfig.
+        if self.args.generation_config is not None:
+            gen_config = self.load_generation_config(self.args.generation_config)
+            self.model.generation_config = gen_config
+
+    @staticmethod
+    def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig:
+        """
+        Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` arguments.
+
+        Args:
+            gen_config_arg (`str` or [`~generation.GenerationConfig`]):
+                `Seq2SeqTrainingArguments.generation_config` argument.
+
+        Returns:
+            A `~generation.GenerationConfig`.
+        """
+
+        # GenerationConfig provided, nothing to do
+        if isinstance(gen_config_arg, GenerationConfig):
+            return deepcopy(gen_config_arg)
+
+        # str or Path
+        pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg
+        config_file_name = None
+
+        # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL
+        # This step is required in order to determine config_file_name
+        if pretrained_model_name.is_file():
+            config_file_name = pretrained_model_name.name
+            pretrained_model_name = pretrained_model_name.parent
+        # dir path
+        elif pretrained_model_name.is_dir():
+            pass
+        # model id or URL
+        else:
+            pretrained_model_name = gen_config_arg
+
+        gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name)
+        return gen_config
+
+    def evaluate(
+        self,
+        eval_dataset: Optional[Dataset] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+        **gen_kwargs,
+    ) -> Dict[str, float]:
+        """
+        Run evaluation and returns metrics.
+
+        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+        (pass it to the init `compute_metrics` argument).
+
+        You can also subclass and override this method to inject custom behavior.
+
+        Args:
+            eval_dataset (`Dataset`, *optional*):
+                Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
+                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+                method.
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+                "eval_bleu" if the prefix is `"eval"` (default)
+            max_length (`int`, *optional*):
+                The maximum target length to use when predicting with the generate method.
+            num_beams (`int`, *optional*):
+                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
+                beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+
+        Returns:
+            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+            dictionary also contains the epoch number which comes from the training state.
+        """
+
+        # Resolve legacy generation arguments from TrainingArguments when
+        # the caller did not pass them explicitly.
+        gen_kwargs = gen_kwargs.copy()
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
+        # Stashed for prediction_step, which picks these up when called
+        # without explicit gen_kwargs.
+        self._gen_kwargs = gen_kwargs
+
+        # **** Original Code Start ****
+        self._memory_tracker.start()
+
+        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        start_time = time.time()
+
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
+            eval_dataloader,
+            description="Evaluation",
+            prediction_loss_only=True if self.compute_metrics is None else None,
+            ignore_keys=ignore_keys,
+            metric_key_prefix=metric_key_prefix,
+        )
+
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
+        # # **** Original Code End ****
+        # @ compute perplexity: eval_perplexity = exp(mean eval loss)
+        if 'eval_loss' in output.metrics.keys():
+            mean_loss = output.metrics.get('eval_loss')
+            perplexity = math.exp(mean_loss)
+            output.metrics['eval_perplexity'] = perplexity
+
+        # logger.info(output.metrics)
+        self.log(output.metrics)
+        # self.log also appends the metrics to state.log_history / console.
+
+        # if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+        #     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+        #     xm.master_print(met.metrics_report())
+        # @ on_evaluate
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+        return output.metrics
+
+    def predict(
+        self,
+        test_dataset: Dataset,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "test",
+        **gen_kwargs,
+    ) -> "PredictionOutput":
+        """
+        Run prediction and returns predictions and potential metrics.
+
+        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+        will also return metrics, like in `evaluate()`.
+
+        Args:
+            test_dataset (`Dataset`):
+                Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
+                `model.forward()` method are automatically removed. Has to implement the method `__len__`
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+                "eval_bleu" if the prefix is `"eval"` (default)
+            max_length (`int`, *optional*):
+                The maximum target length to use when predicting with the generate method.
+            num_beams (`int`, *optional*):
+                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
+                beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+
+
+
+        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
+        padding in a token classification task) the predictions will be padded (on the right) to allow for
+        concatenation into one array. The padding index is -100.
+
+
+
+        Returns: *NamedTuple* A namedtuple with the following keys:
+
+            - predictions (`np.ndarray`): The predictions on `test_dataset`.
+            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+              labels).
+        """
+
+        gen_kwargs = gen_kwargs.copy()
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
+        # prediction_step reads the resolved kwargs from _gen_kwargs.
+        self._gen_kwargs = gen_kwargs
+
+        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+        # Without predict_with_generate, fall back to the stock Trainer path.
+        if not self.args.predict_with_generate or prediction_loss_only:
+            return super().prediction_step(
+                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            )
+        # ADD: generation-based prediction from here on.
+        has_labels = "labels" in inputs
+        inputs = self._prepare_inputs(inputs)
+
+        # Priority (handled in generate):
+        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
+        if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
+            gen_kwargs = self._gen_kwargs.copy()
+        if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
+            gen_kwargs.pop("num_beams")
+        if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
+            gen_kwargs.pop("max_length")
+
+        default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
+        gen_kwargs["synced_gpus"] = (
+            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
+        )
+
+        generation_inputs = inputs.copy()
+        # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
+        # (otherwise, it would continue generating from the padded `decoder_input_ids`)
+        if (
+            "labels" in generation_inputs
+            and "decoder_input_ids" in generation_inputs
+            and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
+        ):
+            generation_inputs = {
+                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
+            }
+        # @ 1 generated_tokens
+        print(gen_kwargs)
+        generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
+
+        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
+        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
+        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
+        if self.model.generation_config._from_model_config:
+            self.model.generation_config._from_model_config = False
+
+        # Retrieves GenerationConfig from model.generation_config
+        gen_config = self.model.generation_config
+        # in case the batch is shorter than max length, the output should be padded
+        if generated_tokens.shape[-1] < gen_config.max_length:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
+        elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
+
+        # @ 2 outputs loss: forward pass (no grad) purely to obtain the loss.
+        with torch.no_grad():
+            if has_labels:
+                with self.compute_loss_context_manager():
+                    outputs = model(**inputs)
+                if self.label_smoother is not None:
+                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                else:
+                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+            else:
+                loss = None
+
+        if self.args.prediction_loss_only:
+            return loss, None, None
+
+        # Pad labels to the same target length as the generated tokens.
+        if has_labels:
+            labels = inputs["labels"]
+            if labels.shape[-1] < gen_config.max_length:
+                labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
+            elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
+                labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
+        else:
+            labels = None
+
+        return loss, generated_tokens, labels
+
+    def _pad_tensors_to_max_len(self, tensor, max_length):
+        r"""Right-pads `tensor` with the pad (or EOS) token id up to `max_length`."""
+        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+            # If PAD token is not defined at least EOS token has to be defined
+            pad_token_id = (
+                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+            )
+        else:
+            if self.model.config.pad_token_id is not None:
+                pad_token_id = self.model.config.pad_token_id
+            else:
+                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+
+        padded_tensor = pad_token_id * torch.ones(
+            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
+        )
+        padded_tensor[:, : tensor.shape[-1]] = tensor
+        return padded_tensor
diff --git a/pkg/util/generate/generate.go b/pkg/util/generate/generate.go
index 8a87a58..e4dd477 100644
--- a/pkg/util/generate/generate.go
+++ b/pkg/util/generate/generate.go
@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/DataTunerX/finetune-experiment-controller/pkg/config"
+ "github.com/DataTunerX/finetune-experiment-controller/pkg/util/label"
corev1beta1 "github.com/DataTunerX/meta-server/api/core/v1beta1"
extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1"
finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1"
@@ -16,26 +17,20 @@ import (
)
const (
- defaultFinetuneImage = "rayproject/ray271-llama2-7b-finetune:20231124"
// todo llm file path
defaultFinetuneCodePath = "/tmp/llama2-7b/"
- zeroString = ""
defaultBuildImageJobContainerName = "imagebuild"
defaultBuildImageJobImage = "release.daocloud.io/datatunerx/buildimage:v0.0.1"
)
func GenerateFinetune(finetuneJob *finetunev1beta1.FinetuneJob) *finetunev1beta1.Finetune {
- if finetuneJob.Spec.FineTune.Name == "" {
- finetuneJob.Spec.FineTune.Name = fmt.Sprintf("%s-%s", finetuneJob.Name, "finetune")
- }
- if finetuneJob.Spec.FineTune.FinetuneSpec.Node <= 0 {
- finetuneJob.Spec.FineTune.FinetuneSpec.Node = 2
- }
+ finetuneLabel := label.GenerateInstanceLabel(finetuneJob.Name)
finetune := &finetunev1beta1.Finetune{
ObjectMeta: metav1.ObjectMeta{
Name: finetuneJob.Spec.FineTune.Name,
Namespace: finetuneJob.Namespace,
+ Labels: finetuneLabel,
},
Spec: finetunev1beta1.FinetuneSpec{
Dataset: finetuneJob.Spec.FineTune.FinetuneSpec.Dataset,
@@ -48,23 +43,25 @@ func GenerateFinetune(finetuneJob *finetunev1beta1.FinetuneJob) *finetunev1beta1
if finetuneJob.Spec.FineTune.FinetuneSpec.Resource != nil {
finetune.Spec.Resource = finetuneJob.Spec.FineTune.FinetuneSpec.Resource
}
- if finetuneJob.Spec.FineTune.FinetuneSpec.Image.Name == zeroString {
- finetune.Spec.Image.Name = defaultFinetuneImage
+ if finetuneJob.Spec.FineTune.FinetuneSpec.Image.Name == "" {
+ finetune.Spec.Image.Name = config.GetBaseImage()
}
- if finetuneJob.Spec.FineTune.FinetuneSpec.Image.Path == zeroString {
+ if finetuneJob.Spec.FineTune.FinetuneSpec.Image.Path == "" {
finetune.Spec.Image.Path = defaultFinetuneCodePath
}
return finetune
}
-// todo(tigerK) add build image job
-func GenerateBuildImageJob(name, namespace, filePath string) *batchv1.Job {
+func GenerateBuildImageJob(filePath, imageName, imageTag string, finetuneJobInstance *finetunev1beta1.FinetuneJob) *batchv1.Job {
privileged := true
directory := corev1.HostPathDirectory
+ buildImageJobName := fmt.Sprintf("%s-buildimage", finetuneJobInstance.Name)
+ jobLabel := label.GenerateInstanceLabel(finetuneJobInstance.Name)
return &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
- Name: name,
- Namespace: namespace,
+ Name: buildImageJobName,
+ Namespace: finetuneJobInstance.Namespace,
+ Labels: jobLabel,
},
Spec: batchv1.JobSpec{
Template: corev1.PodTemplateSpec{
@@ -113,17 +110,21 @@ func GenerateBuildImageJob(name, namespace, filePath string) *batchv1.Job {
Name: "USERNAME",
Value: config.GetUserName(),
},
+ {
+ Name: "BASE_IMAGE",
+ Value: config.GetBaseImage(),
+ },
{
Name: "PASSWORD",
Value: config.GetPassword(),
},
{
Name: "IMAGE_NAME",
- Value: config.GetImageName(),
+ Value: imageName,
},
{
Name: "IMAGE_TAG",
- Value: config.GetImageTag(),
+ Value: imageTag,
},
},
VolumeMounts: []corev1.VolumeMount{
@@ -164,6 +165,11 @@ func GenerateRayService(name, namespace, importPath, runtimeEnv, deploymentName
workReplicas := int32(1)
minWorkReplicas := int32(1)
maxWorkReplicas := int32(1)
+ if finetuneJob.Spec.ServeConfig.NodeSelector == nil {
+ finetuneJob.Spec.ServeConfig.NodeSelector = map[string]string{
+ "nvidia.com/gpu": "present",
+ }
+ }
return &rayv1.RayService{
ObjectMeta: metav1.ObjectMeta{
Name: name,
@@ -300,10 +306,12 @@ func GenerateRayService(name, namespace, importPath, runtimeEnv, deploymentName
Limits: map[corev1.ResourceName]resource.Quantity{
corev1.ResourceCPU: resource.MustParse("8"),
corev1.ResourceMemory: resource.MustParse("64Gi"),
+ "nvidia.com/gpu": resource.MustParse("1"),
},
Requests: map[corev1.ResourceName]resource.Quantity{
corev1.ResourceCPU: resource.MustParse("4"),
corev1.ResourceMemory: resource.MustParse("32Gi"),
+ "nvidia.com/gpu": resource.MustParse("1"),
},
},
},
@@ -327,12 +335,6 @@ func GenerateBuiltInScoring(name, namespace, inference string) *extensionv1beta1
Namespace: namespace,
},
Spec: extensionv1beta1.ScoringSpec{
- Questions: []extensionv1beta1.Question{
- {
- Question: "天王盖地虎",
- Reference: "小鸡炖蘑菇",
- },
- },
InferenceService: inference,
},
}
diff --git a/pkg/util/util.go b/pkg/util/util.go
index c0a1aa8..b5466a4 100644
--- a/pkg/util/util.go
+++ b/pkg/util/util.go
@@ -1,6 +1,13 @@
package util
-import "strings"
+import (
+ "io/ioutil"
+ "os"
+ "strconv"
+ "strings"
+
+ "github.com/DataTunerX/utility-server/logging"
+)
func RemoveBucketName(path, bucketName string) string {
parts := strings.Split(path, "/")
@@ -13,3 +20,23 @@ func RemoveBucketName(path, bucketName string) string {
func GenerateName() {
}
+
+// ParseScore converts s to an int score. Input that fails to parse
+// (empty or non-numeric) deliberately falls back to 0 instead of
+// surfacing the error to the caller.
+func ParseScore(s string) int {
+	score, err := strconv.Atoi(s)
+	if err != nil {
+		return 0
+	}
+	return score
+}
+
+// GetOperatorNamespace returns the namespace the operator runs in, read
+// from the in-cluster service-account mount. If the file cannot be read
+// it falls back to the development namespace instead of returning an
+// empty string.
+func GetOperatorNamespace() string {
+	nsBytes, err := ioutil.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace")
+	if err != nil {
+		logging.ZLogger.Errorf("unable to read file, %v", err)
+		if os.IsNotExist(err) {
+			return "datatunerx-dev"
+		}
+		// BUG FIX: previously any other read error (e.g. permission
+		// denied) fell through and returned an empty namespace; use the
+		// default namespace for those errors too.
+		return "datatunerx-dev"
+	}
+	ns := strings.TrimSpace(string(nsBytes))
+	return ns
+}