From 0a6eb84e48c3a02f9781ad3af3fa5f7cdba23ec7 Mon Sep 17 00:00:00 2001
From: tigerK
Date: Tue, 23 Jan 2024 10:17:43 +0800
Subject: [PATCH] feat(finetune-experiment-controller): add finetune-controller

1. add finetune-controller code to this repo
---
 .../app/controller_manager.go             |  11 +
 .../finetune/finetune_controller.go       | 786 ++++++++++++++++++
 pkg/tuning/Dockerfile                     |   8 +
 pkg/tuning/README.md                      |   3 +
 pkg/tuning/build_image.sh                 |   1 +
 pkg/tuning/callback.py                    | 166 ++++
 pkg/tuning/ds_config.json                 |  13 +
 pkg/tuning/parser.py                      | 266 ++++
 pkg/tuning/prometheus/__init__.py         |   0
 pkg/tuning/prometheus/metrics.py          | 135 +++
 pkg/tuning/prometheus/prometheus.proto    |  81 ++
 pkg/tuning/prometheus/prometheus_pb2.py   | 138 +++
 pkg/tuning/requirements.txt               |  10 +
 pkg/tuning/template.py                    | 620 ++++++++++++++
 pkg/tuning/train.py                       | 393 +++++++++
 pkg/tuning/trainer.py                     | 507 +++++++++
 16 files changed, 3138 insertions(+)
 create mode 100644 internal/controller/finetune/finetune_controller.go
 create mode 100644 pkg/tuning/Dockerfile
 create mode 100644 pkg/tuning/README.md
 create mode 100644 pkg/tuning/build_image.sh
 create mode 100644 pkg/tuning/callback.py
 create mode 100644 pkg/tuning/ds_config.json
 create mode 100644 pkg/tuning/parser.py
 create mode 100644 pkg/tuning/prometheus/__init__.py
 create mode 100644 pkg/tuning/prometheus/metrics.py
 create mode 100644 pkg/tuning/prometheus/prometheus.proto
 create mode 100644 pkg/tuning/prometheus/prometheus_pb2.py
 create mode 100644 pkg/tuning/requirements.txt
 create mode 100644 pkg/tuning/template.py
 create mode 100644 pkg/tuning/train.py
 create mode 100644 pkg/tuning/trainer.py

diff --git a/cmd/controller-manager/app/controller_manager.go b/cmd/controller-manager/app/controller_manager.go
index a3de0f6..9ab90ea 100644
--- a/cmd/controller-manager/app/controller_manager.go
+++ b/cmd/controller-manager/app/controller_manager.go
@@ -18,6 +18,7 @@ import (
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/apimachinery/pkg/types"
 	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+	"k8s.io/client-go/kubernetes"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 	_ "k8s.io/client-go/plugin/pkg/client/auth"
 	ctrl "sigs.k8s.io/controller-runtime"
@@ -149,6 +150,16 @@ func NewControllerManager() (manager.Manager, error) {
 		logging.ZLogger.Errorf("Unable to create FinetuneJob controller, %v", err)
 		return nil, err
 	}
+	if err = (&finetune.FinetuneReconciler{
+		Log:       logging.ZLogger,
+		Client:    mgr.GetClient(),
+		Scheme:    mgr.GetScheme(),
+		Clientset: kubernetes.NewForConfigOrDie(ctrl.GetConfigOrDie()),
+		Config:    ctrl.GetConfigOrDie(),
+	}).SetupWithManager(mgr); err != nil {
+		logging.ZLogger.Errorf("Unable to create Finetune controller, %v", err)
+		return nil, err
+	}
 	//+kubebuilder:scaffold:builder

 	if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil {
diff --git a/internal/controller/finetune/finetune_controller.go b/internal/controller/finetune/finetune_controller.go
new file mode 100644
index 0000000..756519b
--- /dev/null
+++ b/internal/controller/finetune/finetune_controller.go
@@ -0,0 +1,786 @@
+/*
+Copyright 2023.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package finetune + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "time" + + corev1beta1 "github.com/DataTunerX/meta-server/api/core/v1beta1" + extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1" + finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1" + "github.com/DataTunerX/utility-server/logging" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/common" + batchv1 "k8s.io/api/batch/v1" + "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/rand" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/remotecommand" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" +) + +const ( + RayVersion = "v2.7.1" + DefaultRequeueDuration = 3 * time.Second + CheckpointPath = "/home/ray/checkpoint_path" +) + +var metricsExportAddress = os.Getenv("METRICS_EXPORT_ADDRESS") +var storagePath = os.Getenv("STORAGE_PATH") + +// FinetuneReconciler reconciles a Finetune object +type FinetuneReconciler struct { + client.Client + Scheme *runtime.Scheme + Log logging.Logger + Clientset *kubernetes.Clientset + Config *rest.Config +} + +//+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunes,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunes/status,verbs=get;update;patch +//+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunes/finalizers,verbs=update + +// Reconcile is part of the main kubernetes reconciliation loop which aims to +// move the current state of the cluster closer to the desired state. +// TODO(user): Modify the Reconcile function to compare the state specified by +// the Finetune object against the actual cluster state, and then +// perform operations to make the cluster state reflect the state specified by +// the user. +// +// For more details, check Reconcile and its Result here: +// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.14.1/pkg/reconcile +func (r *FinetuneReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + r.Log.Infof("Reconciling Finetune: %+v", req.NamespacedName) + finetuneInstance := &finetunev1beta1.Finetune{} + + err := r.Get(ctx, req.NamespacedName, finetuneInstance) + + if err != nil { + if apierrors.IsNotFound(err) { + r.Log.Infof("Finetune: %+v not found. Ignoring since object must be deleted", req.NamespacedName) + return ctrl.Result{}, nil + } + // Error reading the object - requeue the request. 
+ r.Log.Error("Failed to get Finetune") + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + + if finetuneInstance.ObjectMeta.DeletionTimestamp.IsZero() { + if !controllerutil.ContainsFinalizer(finetuneInstance, finetunev1beta1.FinetuneGroupFinalizer) { + controllerutil.AddFinalizer(finetuneInstance, finetunev1beta1.FinetuneGroupFinalizer) + if err := r.Update(context.Background(), finetuneInstance); err != nil { + r.Log.Errorf("Failed to update Finetune with finalizer: %v", err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + } + } else { + r.Log.Infof("Finetune: %+v is being deleted", req.NamespacedName) + controllerutil.RemoveFinalizer(finetuneInstance, finetunev1beta1.FinetuneGroupFinalizer) + if err := r.Update(context.Background(), finetuneInstance); err != nil { + r.Log.Errorf("Failed to update Finetune without finalizer: %v", err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, nil + } + + if finetuneInstance.Status.State == finetunev1beta1.FinetuneSuccessful { + r.Log.Infof("Finetune: %+v is Successful.", req.NamespacedName) + return ctrl.Result{}, nil + } + + if finetuneInstance.Status.State == finetunev1beta1.FinetuneFailed { + r.Log.Infof("Finetune: %+v is Failed.", req.NamespacedName) + return ctrl.Result{}, nil + } + + if finetuneInstance.Status.State == "" { + if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneInit); err != nil { + r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + } + + rayJobInstance := &rayv1.RayJob{} + + err = r.Get(ctx, req.NamespacedName, rayJobInstance) + if err != nil { + if apierrors.IsNotFound(err) { + r.Log.Info("RayJob not found. 
Create a new one.") + err = r.createRayJob(ctx, finetuneInstance) + if err != nil { + r.Log.Errorf("Failed to create RayJob: %v", err) + if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetunePending); err != nil { + r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, nil + } else { + r.Log.Error("Failed to get RayJob") + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + } + + if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneRunning); err != nil { + r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + + if rayJobInstance.Spec.ShutdownAfterJobFinishes == false { + rayJobInstance.Spec.ShutdownAfterJobFinishes = true + r.Log.Infof("RayJob %s/%s set ShutdownAfterJobFinishes true", rayJobInstance.Namespace, rayJobInstance.Name) + if err := r.Update(ctx, rayJobInstance); err != nil { + r.Log.Errorf("Failed to update RayJob %v: %v", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + } + + // check rayjob status + r.Log.Infof("RayJob %s/%s status is %s", rayJobInstance.Namespace, rayJobInstance.Name, rayJobInstance.Status.JobStatus) + if rayJobInstance.Status.JobStatus == "" { + return ctrl.Result{RequeueAfter: DefaultRequeueDuration * 10}, nil + } + + // update rayjob info + if finetuneInstance.Status.RayJobInfo == nil { + rajJobInfo, err := r.getRayJobPodInfo(ctx, rayJobInstance) + if err != nil { + r.Log.Errorf("getRayJobPodInfo err: %s", err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + finetuneInstance.Status.RayJobInfo = rajJobInfo + + if err = r.Status().Update(ctx, finetuneInstance); err != nil { + r.Log.Errorf("Failed to update Finetune status %v: %v", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + } + + if isJobPendingOrRunning(rayJobInstance.Status.JobStatus) { + return ctrl.Result{RequeueAfter: DefaultRequeueDuration * 10}, nil + } + + if isJobStoppedOrFailed(rayJobInstance.Status.JobStatus) { + if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneFailed); err != nil { + r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, nil + } + + if finetuneInstance.Status.LLMCheckpoint == nil { + headPod, err := r.getRayClusterHeadPod(ctx, rayJobInstance) + if err != nil { + r.Log.Errorf("getRayClusterHeadPod err: %s", err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + + checkpointPath, err := r.fetchPodFile(ctx, headPod, CheckpointPath) + if err != nil { + r.Log.Errorf("fetchPodFile err: %s", err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + + finetuneInstance.Status.LLMCheckpoint = &finetunev1beta1.Checkpoint{ + LLMCheckpointRef: GenerateLLMCheckpointName(finetuneInstance.Name), + CheckpointPath: checkpointPath, + } + if err = r.Status().Update(ctx, finetuneInstance); err != nil { + r.Log.Errorf("Failed to update Finetune status %v: %v", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + + } + + // create llmcheckpoit + if err := 
r.reconcileLLMCheckpoint(ctx, finetuneInstance); err != nil { + r.Log.Errorf("reconcileLLMCheckpoint err: %s", err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + + if err = r.updateFinetuneState(ctx, finetuneInstance, finetunev1beta1.FinetuneSuccessful); err != nil { + r.Log.Errorf("Finetune %v update state error: %s", req.NamespacedName, err) + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + + return ctrl.Result{}, nil +} + +func (r *FinetuneReconciler) getRayJobPodInfo(ctx context.Context, rayJob *rayv1.RayJob) (*finetunev1beta1.RayJobInfo, error) { + job := &batchv1.Job{} + if err := r.Get(ctx, client.ObjectKey{Namespace: rayJob.Namespace, Name: rayJob.Name}, job); err != nil { + return nil, err + } + + jobPods := &v1.PodList{} + err := r.List(ctx, jobPods, client.InNamespace(rayJob.Namespace), client.MatchingLabels{"batch.kubernetes.io/job-name": rayJob.Name}) + if err != nil { + return nil, err + } + if len(jobPods.Items) == 0 { + return nil, fmt.Errorf("RayJob: %s/%s has no pod", rayJob.Namespace, rayJob.Name) + } + var firstPod string + var podStartTime *metav1.Time + for _, pod := range jobPods.Items { + if podStartTime == nil || podStartTime.Time.After(pod.CreationTimestamp.Time) { + firstPod = pod.Name + podStartTime = &pod.CreationTimestamp + } + } + + return &finetunev1beta1.RayJobInfo{RayJobPodName: firstPod, RayJobPodContainerName: "ray-job-submitter"}, nil +} + +func (r *FinetuneReconciler) getRayClusterHeadPod(ctx context.Context, rayJob *rayv1.RayJob) (*v1.Pod, error) { + headPods := &v1.PodList{} + filterLabels := client.MatchingLabels{common.RayClusterLabelKey: rayJob.Status.RayClusterName, common.RayNodeTypeLabelKey: string(rayv1.HeadNode)} + if err := r.List(ctx, headPods, client.InNamespace(rayJob.Namespace), filterLabels); err != nil { + return nil, err + } + if len(headPods.Items) == 0 { + return nil, fmt.Errorf("RayCluster: %s/%s has no head node", rayJob.Namespace, rayJob.Status.RayClusterName) + } + + return &headPods.Items[0], nil +} + +func (r *FinetuneReconciler) fetchPodFile(ctx context.Context, pod *v1.Pod, filePath string) (string, error) { + execRequest := r.Clientset.CoreV1().RESTClient().Post(). + Resource("pods"). + Name(pod.Name). + Namespace(pod.Namespace). + SubResource("exec"). 
+ VersionedParams(&v1.PodExecOptions{ + Container: pod.Spec.Containers[0].Name, + Command: []string{"cat", filePath}, + Stdout: true, + Stderr: true, + }, scheme.ParameterCodec) + + exec, err := remotecommand.NewSPDYExecutor(r.Config, "POST", execRequest.URL()) + if err != nil { + return "", err + } + var stdout bytes.Buffer + var stderr bytes.Buffer + err = exec.StreamWithContext(ctx, remotecommand.StreamOptions{ + Stdout: &stdout, + Stderr: &stderr, + }) + if err != nil { + return "", err + } + return stdout.String(), nil +} + +func (r *FinetuneReconciler) reconcileLLMCheckpoint(ctx context.Context, finetune *finetunev1beta1.Finetune) error { + ns := finetune.Namespace + checkpointNamespacedName := types.NamespacedName{ + Namespace: ns, + Name: finetune.Status.LLMCheckpoint.LLMCheckpointRef, + } + + checkpoint := &corev1beta1.LLMCheckpoint{} + err := r.Get(ctx, checkpointNamespacedName, checkpoint) + if err == nil { + return nil + } + + if apierrors.IsNotFound(err) { + r.Log.Infof("LLMCheckpoint: %s/%s is not found", ns, finetune.Name) + + datasetInstance, err := r.getDataset(ctx, types.NamespacedName{ns, finetune.Spec.Dataset}) + if err != nil { + return err + } + + hyperparameterInstance, err := r.getHyperparameter(ctx, types.NamespacedName{ns, finetune.Spec.Hyperparameter.HyperparameterRef}) + if err != nil { + return err + } + + llmInstance, err := r.getLLM(ctx, types.NamespacedName{ns, finetune.Spec.LLM}) + if err != nil { + return err + } + + llmCheckpointInstance, err := generateLLMCheckpoint(checkpointNamespacedName, finetune, datasetInstance, hyperparameterInstance, llmInstance) + if err != nil { + return err + } + + // Set controller reference + //if err := controllerutil.SetControllerReference(finetune, llmCheckpointInstance, r.Scheme); err != nil { + // return err + //} + + return r.Create(ctx, llmCheckpointInstance) + + } + + return err +} + +func (r *FinetuneReconciler) getLLM(ctx context.Context, namespacedName types.NamespacedName) (*corev1beta1.LLM, error) { + llmInstance := &corev1beta1.LLM{} + err := r.Get(ctx, namespacedName, llmInstance) + if err != nil { + r.Log.Errorf("Failed to get LLM %v", namespacedName) + return nil, err + } + return llmInstance, nil +} + +func (r *FinetuneReconciler) getDataset(ctx context.Context, namespacedName types.NamespacedName) (*extensionv1beta1.Dataset, error) { + datasetInstance := &extensionv1beta1.Dataset{} + err := r.Get(ctx, namespacedName, datasetInstance) + if err != nil { + r.Log.Errorf("Failed to get Dataset %v", namespacedName) + return nil, err + } + return datasetInstance, nil +} + +func (r *FinetuneReconciler) getHyperparameter(ctx context.Context, namespacedName types.NamespacedName) (*corev1beta1.Hyperparameter, error) { + hyperparameterInstance := &corev1beta1.Hyperparameter{} + err := r.Get(ctx, namespacedName, hyperparameterInstance) + if err != nil { + r.Log.Errorf("Failed to get Hyperparameter %v", namespacedName) + return nil, err + } + return hyperparameterInstance, nil +} + +// createRayJob will create the rayjob +func (r *FinetuneReconciler) createRayJob(ctx context.Context, finetune *finetunev1beta1.Finetune) error { + ns := finetune.Namespace + + datasetInstance, err := r.getDataset(ctx, types.NamespacedName{ns, finetune.Spec.Dataset}) + if err != nil { + finetune.Status.State = finetunev1beta1.FinetunePending + if err := r.Status().Update(ctx, finetune); err != nil { + return err + } + return err + } + + hyperparameterInstance, err := r.getHyperparameter(ctx, types.NamespacedName{ns, 
finetune.Spec.Hyperparameter.HyperparameterRef}) + if err != nil { + finetune.Status.State = finetunev1beta1.FinetunePending + if err := r.Status().Update(ctx, finetune); err != nil { + return err + } + return err + } + + newParameters := updateHyperparameters(&hyperparameterInstance.Spec.Parameters, finetune.Spec.Hyperparameter.Overrides) + r.Log.Debugf("newParameters: %+v", newParameters) + rayJobEntrypoint, err := getRayJobEntrypoint(ctx, finetune, datasetInstance, newParameters) + if err != nil { + return err + } + + r.Log.Info("create ray cluster") + rayJobInstance, err := generateRayJob(ctx, &types.NamespacedName{ns, finetune.Name}, rayJobEntrypoint, finetune) + if err != nil { + return err + } + + // Set controller reference + if err := controllerutil.SetControllerReference(finetune, rayJobInstance, r.Scheme); err != nil { + return err + } + + return r.Create(ctx, rayJobInstance) +} + +// updateFinetuneState is a method of the FinetuneReconciler struct. +// It updates the state of a Finetune instance and logs the new state. +// +// Parameters: +// ctx: The context within which the function is called. Used for timeout and cancellation signals. +// instance: The Finetune instance whose state is to be updated. +// finetuneState: The new state to be set for the Finetune instance. +// +// Returns: +// error: An error object that describes an error that occurred during the function's execution. Returns nil if the function executed successfully. +func (r *FinetuneReconciler) updateFinetuneState(ctx context.Context, instance *finetunev1beta1.Finetune, finetuneState finetunev1beta1.FinetuneState) error { + // If the current state is the same as the new state, return nil + if instance.Status.State == finetuneState { + return nil + } + // Update the state of the Finetune instance + instance.Status.State = finetuneState + // Log the new state + r.Log.Infof("Update Finetune CR Status.State: %s", finetuneState) + // Update the status of the Finetune instance in the Kubernetes API and return any error that occurs + return r.Status().Update(ctx, instance) +} + +func getRayJobEntrypoint(ctx context.Context, finetune *finetunev1beta1.Finetune, dataset *extensionv1beta1.Dataset, parameters *corev1beta1.Parameters) (string, error) { + // TODO check parameters include blank + replicas := int32(finetune.Spec.Node) + if replicas <= 0 { + replicas = 1 + } + entrypoint := []string{"python"} + entrypoint = append(entrypoint, "/tuning/train.py") + + finetunePath := finetune.Spec.Image.Path + if finetunePath == "" { + return "", fmt.Errorf("%s/%s: finetune.Spec.Image.Path is required", finetune.Namespace, finetune.Name) + } + entrypoint = append(entrypoint, "--model_name_or_path", finetunePath) + + entrypoint = append(entrypoint, "--train_path", dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Train.File) + + if dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Validate != nil && dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Validate.File != "" { + entrypoint = append(entrypoint, "--evaluation_path", dataset.Spec.DatasetMetadata.DatasetInfo.Subsets[0].Splits.Validate.File) + } + + featuresMapJson, err := getFeaturesMapJson(dataset.Spec.DatasetMetadata.DatasetInfo.Features) + if err != nil { + return "", err + } + if featuresMapJson != "" { + entrypoint = append(entrypoint, "--columns", strconv.Quote(featuresMapJson)) + } + + entrypoint = append(entrypoint, "--output_dir", "result") + entrypoint = append(entrypoint, "--deepspeed", "/tuning/ds_config.json") + entrypoint = 
append(entrypoint, "--lora_target", "q_proj,v_proj") + entrypoint = append(entrypoint, "--lr_scheduler_type", string(parameters.Scheduler)) + entrypoint = append(entrypoint, "--optim", string(parameters.Optimizer)) + + quantization := "" + if parameters.Int8 { + quantization = "int8" + } else if parameters.Int4 { + quantization = "int4" + } + if quantization != "" { + entrypoint = append(entrypoint, "--quantization", quantization) + } + + entrypoint = append(entrypoint, "--lora_r", parameters.LoRA_R) + entrypoint = append(entrypoint, "--lora_alpha", parameters.LoRA_Alpha) + entrypoint = append(entrypoint, "--lora_dropout", parameters.LoRA_Dropout) + entrypoint = append(entrypoint, "--learning_rate", parameters.LearningRate) + entrypoint = append(entrypoint, "--num_train_epochs", fmt.Sprintf("%d", parameters.Epochs)) + entrypoint = append(entrypoint, "--block_size", fmt.Sprintf("%d", parameters.BlockSize)) + entrypoint = append(entrypoint, "--per_device_train_batch_size ", fmt.Sprintf("%d", parameters.BatchSize)) + entrypoint = append(entrypoint, "--warmup_ratio", parameters.WarmupRatio) + entrypoint = append(entrypoint, "--weight_decay", parameters.WeightDecay) + entrypoint = append(entrypoint, "--gradient_accumulation_steps", fmt.Sprintf("%d", parameters.GradAccSteps)) + entrypoint = append(entrypoint, "--fp16", fmt.Sprintf("%t", parameters.FP16)) + entrypoint = append(entrypoint, "--num_workers", fmt.Sprintf("%d", replicas)) + entrypoint = append(entrypoint, "--storage_path", storagePath) + + if metricsExportAddress != "" { + entrypoint = append(entrypoint, "--metrics_export_address", metricsExportAddress) + entrypoint = append(entrypoint, "--uid", fmt.Sprintf("%s", finetune.UID)) + + } + return strings.Join(entrypoint, " "), nil +} + +func generateRayJob(ctx context.Context, namespacedName *types.NamespacedName, entrypoint string, finetune *finetunev1beta1.Finetune) (*rayv1.RayJob, error) { + replicas := int32(finetune.Spec.Node) + if replicas <= 0 { + replicas = 1 + } + if finetune.Spec.Image.Name == "" { + return nil, fmt.Errorf("%s/%s: finetune.Spec.Image.Name is required", finetune.Namespace, finetune.Name) + } + + var rayJobInstance = &rayv1.RayJob{ + TypeMeta: metav1.TypeMeta{ + Kind: "RayJob", + APIVersion: "ray.io/v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: namespacedName.Name, + Namespace: namespacedName.Namespace, + }, + Spec: rayv1.RayJobSpec{ + //ShutdownAfterJobFinishes: true, + Entrypoint: entrypoint, + RayClusterSpec: &rayv1.RayClusterSpec{ + RayVersion: RayVersion, + HeadGroupSpec: rayv1.HeadGroupSpec{ + ServiceType: "NodePort", + HeadService: nil, + EnableIngress: nil, + Replicas: nil, + RayStartParams: map[string]string{ + "dashboard-host": "0.0.0.0", + "num-gpus": "0", + }, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "ray-head", + Image: finetune.Spec.Image.Name, + ImagePullPolicy: finetune.Spec.Image.ImagePullPolicy, + Ports: []v1.ContainerPort{ + v1.ContainerPort{ + Name: "gcs-server", + ContainerPort: 6379, + }, + v1.ContainerPort{ + Name: "dashboard", + ContainerPort: 8265, + }, + v1.ContainerPort{ + Name: "client", + ContainerPort: 10001, + }, + }, + }, + }, + }, + }, + }, + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{rayv1.WorkerGroupSpec{ + GroupName: "finetune-group", + Replicas: &replicas, + RayStartParams: map[string]string{}, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "ray-worker", + Image: finetune.Spec.Image.Name, + 
ImagePullPolicy: finetune.Spec.Image.ImagePullPolicy, + Lifecycle: &v1.Lifecycle{ + PreStop: &v1.LifecycleHandler{ + Exec: &v1.ExecAction{ + Command: []string{ + "/bin/sh", "-c", "ray stop", + }, + }, + }, + }, + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + "nvidia.com/gpu": resource.MustParse("1"), + }, + Limits: v1.ResourceList{ + "nvidia.com/gpu": resource.MustParse("1"), + }, + }, + }, + }, + }, + }, + ScaleStrategy: rayv1.ScaleStrategy{}, + }}, + EnableInTreeAutoscaling: nil, + AutoscalerOptions: nil, + HeadServiceAnnotations: nil, + }, + ClusterSelector: nil, + Suspend: false, + }, + } + return rayJobInstance, nil +} + +func generateLLMCheckpoint(checkpointNamespacedName types.NamespacedName, finetune *finetunev1beta1.Finetune, dataset *extensionv1beta1.Dataset, hyperparameter *corev1beta1.Hyperparameter, llm *corev1beta1.LLM) (*corev1beta1.LLMCheckpoint, error) { + if finetune.Status.LLMCheckpoint == nil || finetune.Status.LLMCheckpoint.CheckpointPath == "" { + return nil, fmt.Errorf("CheckpointPath is nil") + } + + var llmCheckpointInstance = &corev1beta1.LLMCheckpoint{ + TypeMeta: metav1.TypeMeta{ + Kind: "LLMCheckpoint", + APIVersion: "core.datatunerx.io/v1beta1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: checkpointNamespacedName.Name, + Namespace: checkpointNamespacedName.Namespace, + }, + Spec: corev1beta1.LLMCheckpointSpec{ + LLM: &corev1beta1.LLMRefSpec{ + LLMRef: finetune.Spec.LLM, + Spec: llm.Spec.DeepCopy(), + }, + Dataset: &corev1beta1.DatasetRefSpec{ + DatasetRef: finetune.Spec.Dataset, + Spec: dataset.Spec.DeepCopy(), + }, + Hyperparameter: &corev1beta1.HyperparameterRefSpec{ + HyperparameterRef: finetune.Spec.Hyperparameter.HyperparameterRef, + Spec: hyperparameter.Spec.DeepCopy(), + }, + Image: &finetune.Spec.Image, + Checkpoint: finetune.Status.LLMCheckpoint.CheckpointPath, + }, + } + return llmCheckpointInstance, nil +} + +func getFeaturesMapJson(features []extensionv1beta1.Feature) (string, error) { + if features == nil { + return "", nil + } + featuresMap := make(map[string]string) + for _, feature := range features { + if feature.Name == "instruction" { + featuresMap["instruction"] = feature.MapTo + continue + } + if feature.Name == "response" { + featuresMap["response"] = feature.MapTo + } + } + + if len(featuresMap) == 0 { + return "", nil + } + + jsonData, err := json.Marshal(featuresMap) + if err != nil { + return "", err + } + + return string(jsonData), nil +} + +func updateHyperparameters(parameters *corev1beta1.Parameters, overrides *finetunev1beta1.Parameters) *corev1beta1.Parameters { + newParameters := parameters.DeepCopy() + + if overrides == nil { + return newParameters + } + + if overrides.Scheduler != "" { + newParameters.Scheduler = overrides.Scheduler + } + + if overrides.Optimizer != "" { + newParameters.Optimizer = overrides.Optimizer + } + + if overrides.Int4 != nil { + newParameters.Int4 = *overrides.Int4 + } + + if overrides.Int8 != nil { + newParameters.Int8 = *overrides.Int8 + } + + if overrides.LoRA_R != nil { + newParameters.LoRA_R = *overrides.LoRA_R + } + + if overrides.LoRA_Alpha != nil { + newParameters.LoRA_Alpha = *overrides.LoRA_Alpha + } + + if overrides.LoRA_Dropout != nil { + newParameters.LoRA_Dropout = *overrides.LoRA_Dropout + } + + if overrides.LearningRate != nil { + newParameters.LearningRate = *overrides.LearningRate + } + + if overrides.Epochs != 0 { + newParameters.Epochs = overrides.Epochs + } + + if overrides.BlockSize != 0 { + newParameters.BlockSize = overrides.BlockSize + } + + if 
overrides.BatchSize != 0 { + newParameters.BatchSize = overrides.BatchSize + } + + if overrides.WarmupRatio != nil { + newParameters.WarmupRatio = *overrides.WarmupRatio + } + + if overrides.WeightDecay != nil { + newParameters.WeightDecay = *overrides.WeightDecay + } + + if overrides.GradAccSteps != 0 { + newParameters.GradAccSteps = overrides.GradAccSteps + } + + if overrides.TrainerType != nil { + newParameters.TrainerType = *overrides.TrainerType + } + + if overrides.PEFT != nil { + newParameters.PEFT = *overrides.PEFT + } + + if overrides.FP16 != nil { + newParameters.FP16 = *overrides.FP16 + } + + return newParameters +} + +// isJobPendingOrRunning indicates whether the job is running. +func isJobPendingOrRunning(status rayv1.JobStatus) bool { + return (status == rayv1.JobStatusPending) || (status == rayv1.JobStatusRunning) +} + +// isJobPendingOrRunning indicates whether the job is running. +func isJobStoppedOrFailed(status rayv1.JobStatus) bool { + return (status == rayv1.JobStatusStopped) || (status == rayv1.JobStatusFailed) +} + +// GenerateLLMCheckpointName generates a LLMCheckpoint name from Finetune name +func GenerateLLMCheckpointName(finetuneName string) string { + return fmt.Sprintf("%s-%s", finetuneName, rand.String(5)) +} + +// SetupWithManager sets up the controller with the Manager. +func (r *FinetuneReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&finetunev1beta1.Finetune{}). + Owns(&rayv1.RayJob{}). + Owns(&corev1beta1.LLMCheckpoint{}). + WithOptions(controller.Options{CacheSyncTimeout: 10 * time.Second}). + Complete(r) +} diff --git a/pkg/tuning/Dockerfile b/pkg/tuning/Dockerfile new file mode 100644 index 0000000..786fbfe --- /dev/null +++ b/pkg/tuning/Dockerfile @@ -0,0 +1,8 @@ +FROM rayproject/ray271-py39-gpu-llama2-7b-inference:20231220 + +WORKDIR /tuning + +COPY requirements.txt . +RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + +COPY . . \ No newline at end of file diff --git a/pkg/tuning/README.md b/pkg/tuning/README.md new file mode 100644 index 0000000..48dbd0e --- /dev/null +++ b/pkg/tuning/README.md @@ -0,0 +1,3 @@ +# Dataset +instruction +response \ No newline at end of file diff --git a/pkg/tuning/build_image.sh b/pkg/tuning/build_image.sh new file mode 100644 index 0000000..26d3be0 --- /dev/null +++ b/pkg/tuning/build_image.sh @@ -0,0 +1 @@ +docker build . 
-t rayproject/ray271-llama2-7b-finetune:20231220 \ No newline at end of file diff --git a/pkg/tuning/callback.py b/pkg/tuning/callback.py new file mode 100644 index 0000000..cedf15a --- /dev/null +++ b/pkg/tuning/callback.py @@ -0,0 +1,166 @@ +import os +import json +import time +from typing import TYPE_CHECKING +from datetime import timedelta + +from transformers import TrainerCallback +from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR + +from prometheus.metrics import export_train_metrics, export_eval_metrics + +if TYPE_CHECKING: + from transformers import TrainingArguments, TrainerState, TrainerControl + +LOG_FILE_NAME = "trainer_log.jsonl" + + +class LogCallback(TrainerCallback): + + def __init__(self, runner=None, metrics_export_address=None, uid=None): + self.runner = runner + self.in_training = False + self.start_time = time.time() + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + self.metrics_export_address = metrics_export_address + self.uid = uid + + def timing(self): + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 + remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. + """ + if state.is_local_process_zero: + self.in_training = True + self.start_time = time.time() + self.max_steps = state.max_steps + if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: + print("Previous log file in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if state.is_local_process_zero: + self.in_training = False + self.cur_steps = 0 + self.max_steps = 0 + + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of an substep during gradient accumulation. + """ + if state.is_local_process_zero and self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of a training step. + """ + if state.is_local_process_zero: + self.cur_steps = state.global_step + self.timing() + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after an evaluation phase. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs): + r""" + Event called after a successful prediction. 
+ """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: + r""" + Event called after logging the last logs. + """ + if not state.is_local_process_zero: + return + + print('log_history: ', state.log_history[-1]) # add 看看返回的 key + if "eval_loss" in state.log_history[-1].keys(): + eval_log = dict( + uid=self.uid, + current_steps=self.cur_steps, + total_steps=self.max_steps, + eval_loss=state.log_history[-1].get("eval_loss", None), + eval_perplexity=state.log_history[-1].get("eval_perplexity", None), + eval_rouge_1=state.log_history[-1].get("eval_rouge-1", None), + eval_rouge_2=state.log_history[-1].get("eval_rouge-2", None), + eval_rouge_l=state.log_history[-1].get("eval_rouge-l", None), + eval_bleu_4=state.log_history[-1].get("eval_bleu-4", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time + ) + else: + logs = dict( + uid=self.uid, + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=state.log_history[-1].get("loss", None), + eval_loss=state.log_history[-1].get("eval_loss", None), + val_perplexity=state.log_history[-1].get("eval_perplexity", None), + eval_rouge_1=state.log_history[-1].get("eval_rouge-1", None), + eval_rouge_2=state.log_history[-1].get("eval_rouge-2", None), + eval_rouge_l=state.log_history[-1].get("eval_rouge-l", None), + eval_bleu_4=state.log_history[-1].get("eval_bleu-4", None), + predict_loss=state.log_history[-1].get("predict_loss", None), + reward=state.log_history[-1].get("reward", None), + learning_rate=state.log_history[-1].get("learning_rate", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time + ) + if self.runner is not None: + print("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( + logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0 + )) + + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, 'watch'), exist_ok=True) + if "eval_loss" in state.log_history[-1].keys(): + with open(os.path.join(args.output_dir, 'watch', "eval_log.jsonl"), "a", encoding="utf-8") as f: + f.write(json.dumps(eval_log) + "\n") + if self.metrics_export_address: + export_eval_metrics(self.metrics_export_address, eval_log) + else: + with open(os.path.join(args.output_dir, 'watch', "trainer_log.jsonl"), "a", encoding="utf-8") as f: + f.write(json.dumps(logs) + "\n") + if self.metrics_export_address: + export_train_metrics(self.metrics_export_address, logs) + + def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a prediction step. 
+ """ + eval_dataloader = kwargs.pop("eval_dataloader", None) + if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training: + if self.max_steps == 0: + self.max_steps = len(eval_dataloader) + self.cur_steps += 1 + self.timing() diff --git a/pkg/tuning/ds_config.json b/pkg/tuning/ds_config.json new file mode 100644 index 0000000..5cf1c15 --- /dev/null +++ b/pkg/tuning/ds_config.json @@ -0,0 +1,13 @@ +{ + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 0 + } +} \ No newline at end of file diff --git a/pkg/tuning/parser.py b/pkg/tuning/parser.py new file mode 100644 index 0000000..a464712 --- /dev/null +++ b/pkg/tuning/parser.py @@ -0,0 +1,266 @@ +import json +import logging +from dataclasses import field, dataclass +from typing import Optional, Dict, Any, Tuple, Literal + +import transformers +from transformers import Seq2SeqTrainingArguments, HfArgumentParser + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + r""" + Arguments pertaining to which model/config/tokenizer we are going to fine-tune. + """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} + ) + use_fast_tokenizer: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} + ) + split_special_tokens: Optional[bool] = field( + default=False, + metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} + ) + use_auth_token: Optional[bool] = field( + default=False, + metadata={"help": "Will use the token generated when running `huggingface-cli login`."} + ) + model_revision: Optional[str] = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} + ) + quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the model."} + ) + quantization_type: Optional[Literal["fp4", "nf4"]] = field( + default="nf4", + metadata={"help": "Quantization data type to use in int4 training."} + ) + double_quantization: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use double quantization in int4 training or not."} + ) + quantization: Optional[str] = field( + default=None, + metadata={"help": "quantize the model, int4, int8, or None."} + ) + + rope_scaling: Optional[Literal["linear", "dynamic"]] = field( + default=None, + metadata={"help": "Adopt scaled rotary positional embeddings."} + ) + checkpoint_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} + ) + flash_attn: Optional[bool] = field( + default=False, + metadata={"help": "Enable FlashAttention-2 for faster training."} + ) + shift_attn: Optional[bool] = field( + default=False, + metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory containing the checkpoints of the reward model."} + ) + 
plot_loss: Optional[bool] = field( + default=False, + metadata={"help": "Whether to plot the training loss after fine-tuning or not."} + ) + hf_auth_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."} + ) + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."} + ) + + def __post_init__(self): + self.compute_dtype = None + self.model_max_length = None + + if self.split_special_tokens and self.use_fast_tokenizer: + raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") + + if self.checkpoint_dir is not None: # support merging multiple lora weights + self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] + + if self.quantization_bit is not None: + assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." + + if self.quantization is not None: + assert self.quantization in ["int4", "int8"], "We only accept int4 or int8 quantization." + + if self.use_auth_token == True and self.hf_auth_token is not None: + from huggingface_hub.hf_api import HfFolder # lazy load + HfFolder.save_token(self.hf_auth_token) + + +@dataclass +class FinetuningArguments: + r""" + Arguments pertaining to which techniques we are going to fine-tuning with. + """ + stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."} + ) + finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( + default="lora", + metadata={"help": "Which fine-tuning method to use."} + ) + num_layer_trainable: Optional[int] = field( + default=3, + metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} + ) + name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( + default="mlp", + metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + LLaMA choices: [\"mlp\", \"self_attn\"], \ + BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \ + Qwen choices: [\"mlp\", \"attn\"], \ + Phi-1.5 choices: [\"mlp\", \"mixer\"], \ + LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."} + ) + lora_rank: Optional[int] = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} + ) + lora_alpha: Optional[float] = field( + default=32.0, + metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."} + ) + lora_dropout: Optional[float] = field( + default=0.1, + metadata={"help": "Dropout rate for the LoRA fine-tuning."} + ) + lora_target: Optional[str] = field( + default=None, + metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. 
\ + LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ + BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ + Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ + Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ + Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} + ) + additional_target: Optional[str] = field( + default=None, + metadata={ + "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."} + ) + resume_lora_training: Optional[bool] = field( + default=True, + metadata={ + "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} + ) + ppo_score_norm: Optional[bool] = field( + default=False, + metadata={"help": "Use score normalization in PPO training."} + ) + ppo_logger: Optional[str] = field( + default=None, + metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."} + ) + ppo_target: Optional[float] = field( + default=6.0, + metadata={"help": "Target KL value for adaptive KL control in PPO training."} + ) + dpo_beta: Optional[float] = field( + default=0.1, + metadata={"help": "The beta parameter for the DPO loss."} + ) + upcast_layernorm: Optional[bool] = field( + default=False, + metadata={"help": "Whether to upcast the layernorm weights in fp32."} + ) + neft_alpha: Optional[float] = field( + default=0, + metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."} + ) + num_workers: Optional[int] = field( + default=1, + metadata={"help": "Number of worker."} + ) + storage_path: Optional[str] = field( + default=None, + metadata={"help": "storage_path is used to storage checkpoint."} + ) + metrics_export_address: Optional[str] = field( + default=None, + metadata={"help": "address to export train metrics."} + ) + uid: Optional[str] = field( + default=None, + metadata={"help": "finetune crd uid."} + ) + + def __post_init__(self): + if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA + self.lora_target = [target.strip() for target in self.lora_target.split(",")] + + if isinstance(self.additional_target, str): + self.additional_target = [target.strip() for target in self.additional_target.split(",")] + + assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method." 
+ + if not self.storage_path: + raise ValueError("--storage_path must be specified") + + +@dataclass +class DataArguments: + train_path: Optional[str] = field( + default=None, + metadata={"help": "Path to train dataset"} + ) + + evaluation_path: Optional[str] = field( + default=None, + metadata={"help": "Path to evaluation dataset"} + ) + + columns: Optional[str] = field( + default=None, + metadata={"help": "columns map for dataset"} + ) + block_size: Optional[int] = field( + default=1024, + metadata={"help": "length of input."} + ) + + def __post_init__(self): + if self.train_path is None: + raise ValueError("--train_path must be specified") + + +def get_train_args() -> Tuple[Seq2SeqTrainingArguments, FinetuningArguments, ModelArguments, DataArguments]: + parser = HfArgumentParser((Seq2SeqTrainingArguments, FinetuningArguments, ModelArguments, DataArguments)) + + training_args, finetuning_args, model_args, data_args = parser.parse_args_into_dataclasses() + + training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning + + # Log on each process the small summary: + logger.info( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" + + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + # Set seed before initializing model. + transformers.set_seed(training_args.seed) + + return training_args, finetuning_args, model_args, data_args diff --git a/pkg/tuning/prometheus/__init__.py b/pkg/tuning/prometheus/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkg/tuning/prometheus/metrics.py b/pkg/tuning/prometheus/metrics.py new file mode 100644 index 0000000..4e5f794 --- /dev/null +++ b/pkg/tuning/prometheus/metrics.py @@ -0,0 +1,135 @@ +from typing import List, Dict +from datetime import datetime +from urllib.parse import urljoin +from .prometheus_pb2 import ( + WriteRequest, + TimeSeries +) +import calendar +import logging +import requests +import snappy + + +def dt2ts(dt): + """Converts a datetime object to UTC timestamp + naive datetime will be considered UTC. 
+ """ + return calendar.timegm(dt.utctimetuple()) + + +def write(address: str, series: List[TimeSeries]): + write_request = WriteRequest() + write_request.timeseries.extend(series) + + uncompressed = write_request.SerializeToString() + compressed = snappy.compress(uncompressed) + + url = urljoin(address, "/api/v1/write") + headers = { + "Content-Encoding": "snappy", + "Content-Type": "application/x-protobuf", + "X-Prometheus-Remote-Write-Version": "0.1.0", + "User-Agent": "metrics-worker" + } + try: + response = requests.post(url, headers=headers, data=compressed) + print(response) + except Exception as e: + print(e) + + +def export_train_metrics(address: str, metrics: Dict): + series = TimeSeries() + label = series.labels.add() + label.name = "__name__" + label.value = "train_metrics" + + label = series.labels.add() + label.name = "uid" + label.value = str(metrics["uid"]) + + label = series.labels.add() + label.name = "total_steps" + label.value = str(metrics.get("total_steps", "")) + + label = series.labels.add() + label.name = "current_steps" + label.value = str(metrics.get("current_steps", "")) + + label = series.labels.add() + label.name = "loss" + label.value = str(metrics.get("loss", "")) + + label = series.labels.add() + label.name = "learning_rate" + label.value = str(metrics.get("learning_rate", "")) + + label = series.labels.add() + label.name = "epoch" + label.value = str(metrics.get("epoch", "")) + + sample = series.samples.add() + sample.value = 1 + sample.timestamp = dt2ts(datetime.utcnow()) * 1000 + + write(address, [series]) + + +def export_eval_metrics(address: str, metrics: Dict): + series = TimeSeries() + label = series.labels.add() + label.name = "__name__" + label.value = "eval_metrics" + + label = series.labels.add() + label.name = "uid" + label.value = str(metrics["uid"]) + + label = series.labels.add() + label.name = "total_steps" + label.value = str(metrics.get("total_steps", "")) + + label = series.labels.add() + label.name = "current_steps" + label.value = str(metrics.get("current_steps", "")) + + label = series.labels.add() + label.name = "eval_loss" + label.value = str(metrics.get("eval_loss", "")) + + label = series.labels.add() + label.name = "eval_perplexity" + label.value = str(metrics.get("eval_perplexity", "")) + + label = series.labels.add() + label.name = "epoch" + label.value = str(metrics.get("epoch", "")) + + sample = series.samples.add() + sample.value = 1 + sample.timestamp = dt2ts(datetime.utcnow()) * 1000 + + write(address, [series]) + + +if __name__ == '__main__': + train_metrics = { + "uid": "1", + "current_steps": 10, + "total_steps": 84, + "loss": 3.088, + "learning_rate": 4.404761904761905e-05, + "epoch": 0.71 + } + export_train_metrics("http://10.33.1.10:30722", train_metrics) + + eval_metrics = { + "uid": "1", + "current_steps": 10, + "total_steps": 84, + "eval_loss": 3.088, + "eval_perplexity": 4.404761904761905e-05, + "epoch": 0.71 + } + export_eval_metrics("http://10.33.1.10:30722", eval_metrics) diff --git a/pkg/tuning/prometheus/prometheus.proto b/pkg/tuning/prometheus/prometheus.proto new file mode 100644 index 0000000..38eed5b --- /dev/null +++ b/pkg/tuning/prometheus/prometheus.proto @@ -0,0 +1,81 @@ +// Copyright 2016 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; +package prometheus; + +option go_package = "prompb"; + +message WriteRequest { + repeated prometheus.TimeSeries timeseries = 1; +} + +message ReadRequest { + repeated Query queries = 1; +} + +message ReadResponse { + // In same order as the request's queries. + repeated QueryResult results = 1; +} + +message Query { + int64 start_timestamp_ms = 1; + int64 end_timestamp_ms = 2; + repeated prometheus.LabelMatcher matchers = 3; + prometheus.ReadHints hints = 4; +} + +message QueryResult { + // Samples within a time series must be ordered by time. + repeated prometheus.TimeSeries timeseries = 1; +} + +message Sample { + double value = 1; + int64 timestamp = 2; +} + +message TimeSeries { + repeated Label labels = 1; + repeated Sample samples = 2; +} + +message Label { + string name = 1; + string value = 2; +} + +message Labels { + repeated Label labels = 1; +} + +// Matcher specifies a rule, which can match or set of labels or not. +message LabelMatcher { + enum Type { + EQ = 0; + NEQ = 1; + RE = 2; + NRE = 3; + } + Type type = 1; + string name = 2; + string value = 3; +} + +message ReadHints { + int64 step_ms = 1; // Query step size in milliseconds. + string func = 2; // String representation of surrounding function or aggregation. + int64 start_ms = 3; // Start time in milliseconds. + int64 end_ms = 4; // End time in milliseconds. +} \ No newline at end of file diff --git a/pkg/tuning/prometheus/prometheus_pb2.py b/pkg/tuning/prometheus/prometheus_pb2.py new file mode 100644 index 0000000..d51c3eb --- /dev/null +++ b/pkg/tuning/prometheus/prometheus_pb2.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: prometheus.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10prometheus.proto\x12\nprometheus\":\n\x0cWriteRequest\x12*\n\ntimeseries\x18\x01 \x03(\x0b\x32\x16.prometheus.TimeSeries\"1\n\x0bReadRequest\x12\"\n\x07queries\x18\x01 \x03(\x0b\x32\x11.prometheus.Query\"8\n\x0cReadResponse\x12(\n\x07results\x18\x01 \x03(\x0b\x32\x17.prometheus.QueryResult\"\x8f\x01\n\x05Query\x12\x1a\n\x12start_timestamp_ms\x18\x01 \x01(\x03\x12\x18\n\x10\x65nd_timestamp_ms\x18\x02 \x01(\x03\x12*\n\x08matchers\x18\x03 \x03(\x0b\x32\x18.prometheus.LabelMatcher\x12$\n\x05hints\x18\x04 \x01(\x0b\x32\x15.prometheus.ReadHints\"9\n\x0bQueryResult\x12*\n\ntimeseries\x18\x01 \x03(\x0b\x32\x16.prometheus.TimeSeries\"*\n\x06Sample\x12\r\n\x05value\x18\x01 \x01(\x01\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"T\n\nTimeSeries\x12!\n\x06labels\x18\x01 \x03(\x0b\x32\x11.prometheus.Label\x12#\n\x07samples\x18\x02 \x03(\x0b\x32\x12.prometheus.Sample\"$\n\x05Label\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"+\n\x06Labels\x12!\n\x06labels\x18\x01 \x03(\x0b\x32\x11.prometheus.Label\"\x82\x01\n\x0cLabelMatcher\x12+\n\x04type\x18\x01 \x01(\x0e\x32\x1d.prometheus.LabelMatcher.Type\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\"(\n\x04Type\x12\x06\n\x02\x45Q\x10\x00\x12\x07\n\x03NEQ\x10\x01\x12\x06\n\x02RE\x10\x02\x12\x07\n\x03NRE\x10\x03\"L\n\tReadHints\x12\x0f\n\x07step_ms\x18\x01 \x01(\x03\x12\x0c\n\x04\x66unc\x18\x02 \x01(\t\x12\x10\n\x08start_ms\x18\x03 \x01(\x03\x12\x0e\n\x06\x65nd_ms\x18\x04 \x01(\x03\x42\x08Z\x06prompbb\x06proto3') + + + +_WRITEREQUEST = DESCRIPTOR.message_types_by_name['WriteRequest'] +_READREQUEST = DESCRIPTOR.message_types_by_name['ReadRequest'] +_READRESPONSE = DESCRIPTOR.message_types_by_name['ReadResponse'] +_QUERY = DESCRIPTOR.message_types_by_name['Query'] +_QUERYRESULT = DESCRIPTOR.message_types_by_name['QueryResult'] +_SAMPLE = DESCRIPTOR.message_types_by_name['Sample'] +_TIMESERIES = DESCRIPTOR.message_types_by_name['TimeSeries'] +_LABEL = DESCRIPTOR.message_types_by_name['Label'] +_LABELS = DESCRIPTOR.message_types_by_name['Labels'] +_LABELMATCHER = DESCRIPTOR.message_types_by_name['LabelMatcher'] +_READHINTS = DESCRIPTOR.message_types_by_name['ReadHints'] +_LABELMATCHER_TYPE = _LABELMATCHER.enum_types_by_name['Type'] +WriteRequest = _reflection.GeneratedProtocolMessageType('WriteRequest', (_message.Message,), { + 'DESCRIPTOR' : _WRITEREQUEST, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.WriteRequest) + }) +_sym_db.RegisterMessage(WriteRequest) + +ReadRequest = _reflection.GeneratedProtocolMessageType('ReadRequest', (_message.Message,), { + 'DESCRIPTOR' : _READREQUEST, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.ReadRequest) + }) +_sym_db.RegisterMessage(ReadRequest) + +ReadResponse = _reflection.GeneratedProtocolMessageType('ReadResponse', (_message.Message,), { + 'DESCRIPTOR' : _READRESPONSE, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.ReadResponse) + }) +_sym_db.RegisterMessage(ReadResponse) + 
+Query = _reflection.GeneratedProtocolMessageType('Query', (_message.Message,), { + 'DESCRIPTOR' : _QUERY, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.Query) + }) +_sym_db.RegisterMessage(Query) + +QueryResult = _reflection.GeneratedProtocolMessageType('QueryResult', (_message.Message,), { + 'DESCRIPTOR' : _QUERYRESULT, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.QueryResult) + }) +_sym_db.RegisterMessage(QueryResult) + +Sample = _reflection.GeneratedProtocolMessageType('Sample', (_message.Message,), { + 'DESCRIPTOR' : _SAMPLE, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.Sample) + }) +_sym_db.RegisterMessage(Sample) + +TimeSeries = _reflection.GeneratedProtocolMessageType('TimeSeries', (_message.Message,), { + 'DESCRIPTOR' : _TIMESERIES, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.TimeSeries) + }) +_sym_db.RegisterMessage(TimeSeries) + +Label = _reflection.GeneratedProtocolMessageType('Label', (_message.Message,), { + 'DESCRIPTOR' : _LABEL, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.Label) + }) +_sym_db.RegisterMessage(Label) + +Labels = _reflection.GeneratedProtocolMessageType('Labels', (_message.Message,), { + 'DESCRIPTOR' : _LABELS, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.Labels) + }) +_sym_db.RegisterMessage(Labels) + +LabelMatcher = _reflection.GeneratedProtocolMessageType('LabelMatcher', (_message.Message,), { + 'DESCRIPTOR' : _LABELMATCHER, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.LabelMatcher) + }) +_sym_db.RegisterMessage(LabelMatcher) + +ReadHints = _reflection.GeneratedProtocolMessageType('ReadHints', (_message.Message,), { + 'DESCRIPTOR' : _READHINTS, + '__module__' : 'prometheus_pb2' + # @@protoc_insertion_point(class_scope:prometheus.ReadHints) + }) +_sym_db.RegisterMessage(ReadHints) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'Z\006prompb' + _WRITEREQUEST._serialized_start=32 + _WRITEREQUEST._serialized_end=90 + _READREQUEST._serialized_start=92 + _READREQUEST._serialized_end=141 + _READRESPONSE._serialized_start=143 + _READRESPONSE._serialized_end=199 + _QUERY._serialized_start=202 + _QUERY._serialized_end=345 + _QUERYRESULT._serialized_start=347 + _QUERYRESULT._serialized_end=404 + _SAMPLE._serialized_start=406 + _SAMPLE._serialized_end=448 + _TIMESERIES._serialized_start=450 + _TIMESERIES._serialized_end=534 + _LABEL._serialized_start=536 + _LABEL._serialized_end=572 + _LABELS._serialized_start=574 + _LABELS._serialized_end=617 + _LABELMATCHER._serialized_start=620 + _LABELMATCHER._serialized_end=750 + _LABELMATCHER_TYPE._serialized_start=710 + _LABELMATCHER_TYPE._serialized_end=750 + _READHINTS._serialized_start=752 + _READHINTS._serialized_end=828 +# @@protoc_insertion_point(module_scope) diff --git a/pkg/tuning/requirements.txt b/pkg/tuning/requirements.txt new file mode 100644 index 0000000..4d8fc9d --- /dev/null +++ b/pkg/tuning/requirements.txt @@ -0,0 +1,10 @@ +bitsandbytes==0.41.3.post2 +datasets==2.14.5 +deepspeed==0.12.2 +evaluate==0.4.1 +peft==0.5.0 +protobuf==3.19.6 +python-snappy==0.6.1 +torch==2.1.0 +transformers==4.34.0 + diff --git a/pkg/tuning/template.py b/pkg/tuning/template.py new file mode 100644 index 0000000..f06a772 --- /dev/null +++ b/pkg/tuning/template.py @@ -0,0 +1,620 @@ 
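+# Prompt-template registry used by train.py. A Template describes how a chat format
+# wraps (system, history, query, response) into prompt/response token ids.
+# Minimal usage sketch (the model id below is only an example; any causal-LM
+# tokenizer you have locally works the same way):
+#
+#     from transformers import AutoTokenizer
+#     from template import get_template_and_fix_tokenizer
+#
+#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+#     template = get_template_and_fix_tokenizer("llama2", tokenizer)
+#     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, "Hello", "Hi!")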
+import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + +logger = logging.getLogger(__name__) + + +@dataclass +class Template: + + prefix: List[Union[str, Dict[str, str]]] + prompt: List[Union[str, Dict[str, str]]] + system: str + sep: List[Union[str, Dict[str, str]]] + stop_words: List[str] + use_history: bool + efficient_eos: bool + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None + ) -> Tuple[List[int], List[int]]: + r""" + Returns a single pair of token ids representing prompt and response respectively. + """ + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) + prompt_ids = [] + for query_ids, resp_ids in encoded_pairs[:-1]: + prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1] + return prompt_ids, answer_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) + return encoded_pairs + + def _format( + self, + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None + ) -> Tuple[str, List[Tuple[str, str]]]: + r""" + Aligns inputs to the standard format. + """ + system = system or self.system # use system if provided + history = history if (history and self.use_history) else [] + history = history + [(query, resp)] + return system, history + + def _get_special_ids( + self, + tokenizer: "PreTrainedTokenizer" + ) -> Tuple[List[int], List[int]]: + if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True): + bos_ids = [tokenizer.bos_token_id] + else: # baichuan, qwen and gpt2 models have no bos token + bos_ids = [] + + if tokenizer.eos_token_id is None: + raise ValueError("EOS token is required.") + + if self.efficient_eos: # used in baichuan, qwen, chatglm, etc. + eos_ids = [] + else: + eos_ids = [tokenizer.eos_token_id] + + return bos_ids, eos_ids + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + system: str, + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. 
+ Turn 0: bos + prefix + sep + query resp + eos + Turn t: sep + bos + query resp + eos + """ + bos_ids, eos_ids = self._get_special_ids(tokenizer) + sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) + if len(prefix_ids) != 0: # has prefix + prefix_ids = bos_ids + prefix_ids + sep_ids + else: + prefix_ids = bos_ids + else: + prefix_ids = sep_ids + bos_ids + + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx)) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) + return encoded_pairs + + def _convert_inputs_to_ids( + self, + tokenizer: "PreTrainedTokenizer", + context: List[Union[str, Dict[str, str]]], + system: Optional[str] = None, + query: Optional[str] = None, + idx: Optional[str] = None + ) -> List[int]: + r""" + Converts context to token ids. + """ + + kwargs = dict(add_special_tokens=False) + + token_ids = [] + for elem in context: + if isinstance(elem, str): + elem = elem.replace("{{system}}", system, 1) if system is not None else elem + elem = elem.replace("{{query}}", query, 1) if query is not None else elem + elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem + if len(elem) != 0: + token_ids = token_ids + tokenizer.encode(elem, **kwargs) + elif isinstance(elem, dict): + token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] + else: + raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem))) + + return token_ids + + +@dataclass +class Llama2Template(Template): + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + system: str, + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. 
+        Turn 0: bos + prefix + query resp + eos
+        Turn t: bos + query resp + eos
+        """
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0: # llama2 template has no sep_ids
+                query = self.prefix[0].replace("{{system}}", system) + query
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
+        return encoded_pairs
+
+
+templates: Dict[str, Template] = {}
+
+
+def register_template(
+    name: str,
+    prefix: List[Union[str, Dict[str, str]]],
+    prompt: List[Union[str, Dict[str, str]]],
+    system: str,
+    sep: List[Union[str, Dict[str, str]]],
+    stop_words: Optional[List[str]] = [],
+    use_history: Optional[bool] = True,
+    efficient_eos: Optional[bool] = False
+) -> None:
+    template_class = Llama2Template if "llama2" in name else Template
+    templates[name] = template_class(
+        prefix=prefix,
+        prompt=prompt,
+        system=system,
+        sep=sep,
+        stop_words=stop_words,
+        use_history=use_history,
+        efficient_eos=efficient_eos
+    )
+
+
+def get_template_and_fix_tokenizer(
+    name: str,
+    tokenizer: "PreTrainedTokenizer"
+) -> Template:
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+    if name is None:
+        return None
+
+    template = templates.get(name, None)
+    assert template is not None, "Template {} does not exist.".format(name)
+    tokenizer.add_special_tokens(
+        dict(additional_special_tokens=template.stop_words),
+        replace_additional_special_tokens=False
+    )
+    return template
+
+
+r"""
+Supports language model inference without histories.
+"""
+register_template(
+    name="vanilla",
+    prefix=[],
+    prompt=[
+        "{{query}}"
+    ],
+    system="",
+    sep=[],
+    use_history=False
+)
+
+
+r"""
+Default template.
+"""
+register_template(
+    name="default",
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}\nAssistant: "
+    ],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[
+        "\n"
+    ]
+)
+
+
+r"""
+Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
+"""
+register_template(
+    name="llama2",
+    prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe. "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information."
+ ), + sep=[] +) + + +r""" +Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b + https://huggingface.co/ziqingyang/chinese-alpaca-2-13b +""" +register_template( + name="llama2_zh", + prefix=[ + "<>\n{{system}}\n<>\n\n" + ], + prompt=[ + "[INST] {{query}} [/INST] " + ], + system="You are a helpful assistant. 你是一个乐于助人的助手。", + sep=[] +) + + +r""" +Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff +""" +register_template( + name="alpaca", + prefix=[ + "{{system}}" + ], + prompt=[ + "### Instruction:\n{{query}}\n\n### Response:\n" + ], + system=( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." + ), + sep=[ + "\n\n" + ] +) + + +r""" +Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5 + https://huggingface.co/lmsys/vicuna-13b-v1.5 +""" +register_template( + name="vicuna", + prefix=[ + "{{system}}" + ], + prompt=[ + "USER: {{query}} ASSISTANT:" + ], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + sep=[] +) + + +r""" +Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B +""" +register_template( + name="belle", + prefix=[ + "{{system}}" + ], + prompt=[ + "Human: {{query}}\n\nBelle: " + ], + system="", + sep=[ + "\n\n" + ] +) + + +r""" +Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 + https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1 + https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat +""" +register_template( + name="ziya", + prefix=[ + "{{system}}" + ], + prompt=[ + {"token": ""}, + ":{{query}}\n", + {"token": ""}, + ":" + ], + system="", + sep=[ + "\n" + ] +) + + +r""" +Supports: https://huggingface.co/BAAI/AquilaChat-7B + https://huggingface.co/BAAI/AquilaChat2-7B + https://huggingface.co/BAAI/AquilaChat2-34B +""" +register_template( + name="aquila", + prefix=[ + "{{system}}" + ], + prompt=[ + "Human: {{query}}###Assistant:" + ], + system=( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." 
+ ), + sep=[ + "###" + ], + stop_words=[ + "" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/internlm/internlm-chat-7b + https://huggingface.co/internlm/internlm-chat-20b +""" +register_template( + name="intern", + prefix=[ + "{{system}}" + ], + prompt=[ + "<|User|>:{{query}}", + {"token": ""}, + "\n<|Bot|>:" + ], + system="", + sep=[ + {"token": ""}, + "\n" + ], + stop_words=[ + "" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat +""" +register_template( + name="baichuan", + prefix=[ + "{{system}}" + ], + prompt=[ + {"token": ""}, # user token + "{{query}}", + {"token": ""} # assistant token + ], + system="", + sep=[], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat + https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat +""" +register_template( + name="baichuan2", + prefix=[ + "{{system}}" + ], + prompt=[ + {"token": ""}, # user token + "{{query}}", + {"token": ""} # assistant token + ], + system="", + sep=[], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha + https://huggingface.co/HuggingFaceH4/starchat-beta +""" +register_template( + name="starchat", + prefix=[ + {"token": "<|system|>"}, + "\n{{system}}", + ], + prompt=[ + {"token": "<|user|>"}, + "\n{{query}}", + {"token": "<|end|>"}, + "\n", + {"token": "<|assistant|>"} + ], + system="", + sep=[ + {"token": "<|end|>"}, + "\n" + ], + stop_words=[ + "<|end|>" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/Qwen/Qwen-7B-Chat + https://huggingface.co/Qwen/Qwen-14B-Chat +""" +register_template( + name="chatml", + prefix=[ + {"token": "<|im_start|>"}, + "system\n{{system}}" + ], + prompt=[ + {"token": "<|im_start|>"}, + "user\n{{query}}", + {"token": "<|im_end|>"}, + "\n", + {"token": "<|im_start|>"}, + "assistant\n" + ], + system="You are a helpful assistant.", + sep=[ + {"token": "<|im_end|>"}, + "\n" + ], + stop_words=[ + "<|im_end|>" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm2-6b +""" +register_template( + name="chatglm2", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + "{{system}}" + ], + prompt=[ + "[Round {{idx}}]\n\n问:{{query}}\n\n答:" + ], + system="", + sep=[ + "\n\n" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm3-6b +""" +register_template( + name="chatglm3", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + "{{system}}" + ], + prompt=[ + {"token": "<|user|>"}, + "\n", + "{{query}}", + {"token": "<|assistant|>"} + ], + system="", + sep=[], + stop_words=[ + "<|user|>", + "<|observation|>" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/openchat/openchat_v3.2_super +""" +register_template( + name="openchat", + prefix=[ + "{{system}}" + ], + prompt=[ + "GPT4 User: {{query}}", + {"token": "<|end_of_turn|>"}, + "GPT4 Assistant:" + ], + system="", + sep=[ + {"token": "<|end_of_turn|>"} + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/xverse/XVERSE-7B-Chat + https://huggingface.co/xverse/XVERSE-13B-Chat +""" +register_template( + name="xverse", + prefix=[ + "{{system}}" + ], + prompt=[ + "Human: {{query}}\n\nAssistant: " + ], + system="", + sep=[] +) diff --git a/pkg/tuning/train.py b/pkg/tuning/train.py new file mode 100644 index 0000000..cc4c28d --- /dev/null +++ b/pkg/tuning/train.py @@ -0,0 +1,393 @@ +import functools +import json +import logging +import math +import os +import sys 
+from types import MethodType + +import numpy as np +from dataclasses import dataclass +from typing import Sequence, Union, Tuple, Dict, List, Any, Generator, Optional + +import ray +import ray.data +import torch +import evaluate +from pandas import DataFrame +from peft import PeftModel, LoraConfig, TaskType, get_peft_model +from ray.air import ScalingConfig, RunConfig +from ray.train import Checkpoint +from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback +from ray.train.torch import TorchTrainer, get_device, TorchConfig +from torch import nn +from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, \ + AutoConfig, AutoModelForCausalLM, Seq2SeqTrainingArguments, BitsAndBytesConfig +from ray.train.huggingface import TransformersTrainer +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.tokenization_utils import PreTrainedTokenizer +from datasets import load_dataset, Dataset + +from callback import LogCallback +from parser import get_train_args +from template import get_template_and_fix_tokenizer +from trainer import SFTTrainer + +logging.basicConfig(level=logging.INFO) +# logger = logging.getLogger(__name__) +# formatter = logging.Formatter( +# fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", +# datefmt="%m/%d/%Y %H:%M:%S" +# ) +# handler = logging.StreamHandler(sys.stdout) +# handler.setFormatter(formatter) +# +# logger = logging.getLogger() +# logger.setLevel(logging.INFO) +# logger.addHandler(handler) + +cpus_per_worker = 8 +IGNORE_INDEX = -100 +cutoff_len = 1024 + + +def rename_columns(batch: DataFrame, columns): + return batch.rename(columns=columns) + + +def preprocess_dataset( + dataset: Union["Dataset", "IterableDataset"], + tokenizer: "PreTrainedTokenizer", + training_args: "Seq2SeqTrainingArguments" +) -> Union["Dataset", "IterableDataset"]: + template = get_template_and_fix_tokenizer("llama2", tokenizer) + + def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: + for i in range(len(examples["instruction"])): + query, response = examples["instruction"][i], examples["response"][i] + query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query + history = examples["history"][i] if "history" in examples else None + system = examples["system"][i] if "system" in examples else None + yield query, response, history, system + + def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
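+        # Schematic example with made-up token ids (single turn):
+        #   input_ids: [bos, q1, q2, q3, a1, a2, eos]
+        #   labels:    [-100, -100, -100, -100, a1, a2, eos]
+        # every prompt position is IGNORE_INDEX (-100), so only response tokens
+        # (and their eos) contribute to the loss; for multi-turn data the same
+        # masking is applied to each turn's prompt segment.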
+ # print(examples) + + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + + for query, response, history, system in construct_example(examples): + if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): + continue + + input_ids, labels = [], [] + for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( + tokenizer, query, response, history, system + )): + total_len = len(source_ids) + len(target_ids) + max_source_len = int(cutoff_len * (len(source_ids) / total_len)) + max_target_len = int(cutoff_len * (len(target_ids) / total_len)) + + if len(source_ids) > max_source_len: + source_ids = source_ids[:max_source_len] + if len(target_ids) > max_target_len: + target_ids = target_ids[:max_target_len] + + if turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + if len(input_ids) > cutoff_len: + input_ids = input_ids[:cutoff_len] + labels = labels[:cutoff_len] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + + return model_inputs + + def print_supervised_dataset_example(example): + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format( + tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) + )) + + preprocess_func = preprocess_supervised_dataset + print_function = print_supervised_dataset_example + new_dataset = dataset.map_batches(preprocess_func) + if training_args.should_log: + try: + print_function(new_dataset.take(1)[0]) + except StopIteration: + raise RuntimeError("Empty dataset!") + return new_dataset + + +def trainer_init_per_worker(config): + print("--- train_task, pid: ", os.getpid()) + + cuda_visible_device = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + print("CUDA_VISIBLE_DEVICES", os.environ["CUDA_VISIBLE_DEVICES"]) + local_rank = int(os.environ["LOCAL_RANK"]) + print("local_rank:", local_rank) + device_id = cuda_visible_device[local_rank] + print("device_id:", device_id) + os.environ["CUDA_VISIBLE_DEVICES"] = f"{device_id}" + torch.cuda.set_device(int(device_id)) + + # device setting + # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print("device:", torch.cuda.current_device()) + device_ids = torch._utils._get_all_device_indices() + print("device_ids:", device_ids) + if len(device_ids) <= 0: + print("invalid device_ids, exit") + return + + training_args = config.get("training_args", None) + finetuning_args = config.get("finetuning_args", None) + model_args = config.get("model_args", None) + data_args = config.get("data_args", None) + tokenizer = config.get("tokenizer", None) + + # read dataset + train_ds = ray.train.get_dataset_shard("train") + print(f"train_ds: {train_ds}") + + def train_gen(): + for row in train_ds.iter_rows(): + yield row + + train_dataset = Dataset.from_generator(train_gen) + print(train_dataset) + print('------') + print(train_dataset[0]) + + eval_ds = ray.train.get_dataset_shard("evaluation") + 
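+    # The "evaluation" shard only exists when main() passed an evaluation dataset
+    # to TorchTrainer(datasets=...); the code below treats a missing shard as falsy
+    # and in that case disables evaluation (evaluation_strategy="no").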
print(f"eval_ds: {eval_ds}") + + def eval_gen(): + for row in eval_ds.iter_rows(): + yield row + + eval_dataset = None + evaluation_strategy = "no" + if eval_ds: + eval_dataset = Dataset.from_generator(eval_gen) + print(eval_dataset) + evaluation_strategy = "steps" + + train_ds_len = len(list(train_ds.iter_batches(batch_size=1))) + steps_per_epoch = math.ceil(train_ds_len / training_args.per_device_train_batch_size) + print(f"train_ds_len: {train_ds_len}, steps_per_epoch: {steps_per_epoch}") + + new_training_args = Seq2SeqTrainingArguments( + training_args.output_dir, + logging_steps=10, + save_strategy="no", + evaluation_strategy=evaluation_strategy, + num_train_epochs=training_args.num_train_epochs, + learning_rate=training_args.learning_rate, + weight_decay=training_args.weight_decay, + warmup_steps=training_args.warmup_steps, + per_device_train_batch_size=training_args.per_device_train_batch_size, + per_device_eval_batch_size=training_args.per_device_eval_batch_size, + optim=training_args.optim, + lr_scheduler_type=training_args.lr_scheduler_type, + gradient_accumulation_steps=training_args.gradient_accumulation_steps, + push_to_hub=False, + report_to="none", + disable_tqdm=False, # declutter the output a little + fp16=training_args.fp16, + gradient_checkpointing=True, + deepspeed=training_args.deepspeed, + log_level="info", + ) + + print(f"new_training_args: {new_training_args}".replace("\n", " ")) + + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + compute_dtype = getattr(config, "torch_dtype", None) + + if model_args.quantization == "int4": + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=False, + ) + elif model_args.quantization == "int8": + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + else: + quantization_config = None + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + torch_dtype=compute_dtype, + quantization_config=quantization_config, + low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), + ) + + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model.gradient_checkpointing_enable() + model.config.use_cache = False # turn off when gradient checkpointing is enabled + print("Gradient checkpointing enabled.") + + output_layer_name = "lm_head" + + if hasattr(model, output_layer_name): + output_layer = getattr(model, output_layer_name) + if isinstance(output_layer, torch.nn.Linear): + def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor: + return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32) + + output_layer.forward = MethodType(forward_in_fp32, output_layer) + + target_modules = finetuning_args.lora_target + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=finetuning_args.lora_rank, + lora_alpha=finetuning_args.lora_alpha, + lora_dropout=finetuning_args.lora_dropout, + target_modules=target_modules, + modules_to_save=finetuning_args.additional_target + ) + model = get_peft_model(model, lora_config) + if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 + model.base_model.peft_config = 
model.peft_config
+    model.train()
+
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=tokenizer,
+        pad_to_multiple_of=4,  # for shift short attention
+        label_pad_token_id=IGNORE_INDEX
+    )
+
+    trainer = SFTTrainer(
+        model=model,
+        args=new_training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        callbacks=[LogCallback(metrics_export_address=finetuning_args.metrics_export_address, uid=finetuning_args.uid)],
+    )
+
+    trainer = prepare_trainer(trainer)
+    train_result = trainer.train()
+    trainer.save_model(training_args.output_dir)
+
+    checkpoint = None
+    if ray.train.get_context().get_world_rank() == 0:
+        checkpoint = Checkpoint.from_directory(training_args.output_dir)
+    ray.train.report(metrics=train_result.metrics, checkpoint=checkpoint)
+
+
+def main():
+    print("init")
+    ray.init()
+
+    training_args, finetuning_args, model_args, data_args = get_train_args()
+
+    print(f"training_args: {training_args}".replace("\n", " "))
+    print(finetuning_args)
+    print(model_args)
+    print(data_args)
+
+    model_path = model_args.model_name_or_path
+    use_gpu = True
+    num_workers = finetuning_args.num_workers
+
+    if data_args.block_size > 0:
+        global cutoff_len
+        cutoff_len = data_args.block_size
+
+    # read dataset
+    print("preprocess_dataset")
+    columns_map = {
+        "instruction": "instruction",
+        "output": "response"
+    }
+    if data_args.columns:
+        print(data_args.columns)
+        columns_map.update({v: k for k, v in json.loads(data_args.columns).items()})
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    train_dataset = ray.data.read_csv(data_args.train_path). \
+        map_batches(rename_columns, fn_args=[columns_map], batch_format="pandas")
+    print(train_dataset)
+    train_dataset = preprocess_dataset(train_dataset, tokenizer, training_args)
+
+    input_datasets = {"train": train_dataset}
+
+    if data_args.evaluation_path:
+        evaluation_dataset = ray.data.read_csv(data_args.evaluation_path). 
\ + map_batches(rename_columns, fn_args=[columns_map], batch_format="pandas") + print(evaluation_dataset) + evaluation_dataset = preprocess_dataset(evaluation_dataset, tokenizer, training_args) + input_datasets["evaluation"] = evaluation_dataset + + scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu, + resources_per_worker={"GPU": 1, "CPU": cpus_per_worker}, + trainer_resources={"GPU": 0} + ) + + ray_trainer = TorchTrainer( + train_loop_per_worker=trainer_init_per_worker, + train_loop_config={ + "training_args": training_args, + "finetuning_args": finetuning_args, + "model_args": model_args, + "data_args": data_args, + "tokenizer": tokenizer, + }, + scaling_config=scaling_config, + datasets=input_datasets, + run_config=RunConfig( + storage_path=finetuning_args.storage_path, + # checkpoint_config=ray.train.CheckpointConfig( + # num_to_keep=1, + # checkpoint_score_attribute="eval_loss", + # checkpoint_score_order="min", + # ), + ) + ) + result = ray_trainer.fit() + checkpoint_path = result.checkpoint.path + + print(f"result path {checkpoint_path}") + + file_path = "/home/ray/checkpoint_path" + + directory = os.path.dirname(file_path) + if not os.path.exists(directory): + os.makedirs(directory) + with open(file_path, 'w', encoding='utf-8') as file: + file.write(checkpoint_path) + + +if __name__ == '__main__': + main() diff --git a/pkg/tuning/trainer.py b/pkg/tuning/trainer.py new file mode 100644 index 0000000..6a168e0 --- /dev/null +++ b/pkg/tuning/trainer.py @@ -0,0 +1,507 @@ +import os +import json +import torch, time, math +from copy import deepcopy +from pathlib import Path + +import numpy as np +import torch.nn as nn +from torch.utils.data import Dataset +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Callable +from transformers import Seq2SeqTrainer, Trainer + +if TYPE_CHECKING: + from transformers.data.data_collator import DataCollator + from transformers.modeling_utils import PreTrainedModel + from transformers.tokenization_utils_base import PreTrainedTokenizerBase + from transformers.trainer_callback import TrainerCallback + from transformers.training_args import TrainingArguments + +from transformers.generation.configuration_utils import GenerationConfig +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.utils import logging +from transformers.trainer_utils import EvalPrediction, PredictionOutput, speed_metrics + +IGNORE_INDEX = -100 +logger = logging.get_logger(__name__) + + +class GenEvalSeq2SeqTrainer(Seq2SeqTrainer): + def __init__( + self, + model: Union["PreTrainedModel", nn.Module] = None, + args: "TrainingArguments" = None, + data_collator: Optional["DataCollator"] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + model_init: Optional[Callable[[], "PreTrainedModel"]] = None, + compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None, + callbacks: Optional[List["TrainerCallback"]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + gen_args: Optional[Any] = None, # ADD + ): + self.gen_args = gen_args + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + 
compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + # Override self.model.generation_config if a GenerationConfig is specified in args. + # Priority: args.generation_config > model.generation_config > default GenerationConfig. + if self.args.generation_config is not None: + gen_config = self.load_generation_config(self.args.generation_config) + self.model.generation_config = gen_config + print(f'trainer init : tokenizer{tokenizer}') + + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + **gen_kwargs, + ) -> Dict[str, float]: + + # force left pad + bak_tp = self.tokenizer.padding_side + bak_dc = self.data_collator + print(f'**** evaluate origin pad style: {self.tokenizer.padding_side} ****', '\n') + + if self.gen_args is not None and (gen_kwargs is None or len(gen_kwargs) == 0): + logger.info("*" * 5 + "Using Initial Trainer gen_args" + "*" * 5) + gen_kwargs = self.gen_args.copy() + else: + logger.info("*" * 5 + "Using Default Trainer gen_kwargs" + "*" * 5) + gen_kwargs = gen_kwargs.copy() + # 添加您自己的逻辑... + # 在这里扩展 gen_kwargs 或通过其他方法修改参数 + + # 调用父类的 evaluate 方法并返回结果 + self.data_collator = left_collat_fn(self.tokenizer) + # @ eval_dataset 的 datacollator 用的是 right + eval_metrics = super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, + **gen_kwargs) + + self.tokenizer.padding_side = bak_tp + self.data_collator = bak_dc + return eval_metrics + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + r""" + Removes the prompt part in the generated tokens. + + Subclass and override to inject custom behavior. + """ + labels = inputs["labels"].clone() if "labels" in inputs else None # backup labels + # force left pad + if self.args.predict_with_generate: + self.tokenizer.padding_side = "left" + assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." + prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) + if prompt_len > label_len: + inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) + if label_len > prompt_len: + inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs + + loss, generated_tokens, _ = super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + if generated_tokens is not None and self.args.predict_with_generate: + generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id + generated_tokens = generated_tokens.contiguous() + + return loss, generated_tokens, labels + + def _pad_tensors_to_target_len( + self, + src_tensor: torch.Tensor, + tgt_tensor: torch.Tensor + ) -> torch.Tensor: + r""" + Pads the tensor to the same length as the target tensor. + """ + assert self.tokenizer.pad_token_id is not None, "Pad token is required." 
+ padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) + padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding + return padded_tensor.contiguous() # in contiguous memory + + def save_predictions( + self, + predict_results: "PredictionOutput" + ) -> None: + r""" + Saves model predictions to `output_dir`. + A custom behavior that not contained in Seq2SeqTrainer. + 自定义行为 + """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + + preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, + self.tokenizer.pad_token_id) + labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, + self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, + clean_up_tokenization_spaces=True) + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for pred, label in zip(decoded_preds, decoded_labels): + res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + writer.write("\n".join(res)) + + +class SFTTrainer(Trainer): + def __init__( + self, + model: Union["PreTrainedModel", nn.Module] = None, + args: "TrainingArguments" = None, + data_collator: Optional["DataCollator"] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + model_init: Optional[Callable[[], "PreTrainedModel"]] = None, + compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None, + callbacks: Optional[List["TrainerCallback"]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Override self.model.generation_config if a GenerationConfig is specified in args. + # Priority: args.generation_config > model.generation_config > default GenerationConfig. + if self.args.generation_config is not None: + gen_config = self.load_generation_config(self.args.generation_config) + self.model.generation_config = gen_config + + @staticmethod + def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig: + """ + Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` arguments. + + Args: + gen_config_arg (`str` or [`~generation.GenerationConfig`]): + `Seq2SeqTrainingArguments.generation_config` argument. + + Returns: + A `~generation.GenerationConfig`. 
+ """ + + # GenerationConfig provided, nothing to do + if isinstance(gen_config_arg, GenerationConfig): + return deepcopy(gen_config_arg) + + # str or Path + pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg + config_file_name = None + + # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL + # This step is required in order to determine config_file_name + if pretrained_model_name.is_file(): + config_file_name = pretrained_model_name.name + pretrained_model_name = pretrained_model_name.parent + # dir path + elif pretrained_model_name.is_dir(): + pass + # model id or URL + else: + pretrained_model_name = gen_config_arg + + gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name) + return gen_config + + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + **gen_kwargs, + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. 
+ """ + + gen_kwargs = gen_kwargs.copy() + if ( + gen_kwargs.get("max_length") is None + and gen_kwargs.get("max_new_tokens") is None + and self.args.generation_max_length is not None + ): + gen_kwargs["max_length"] = self.args.generation_max_length + if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: + gen_kwargs["num_beams"] = self.args.generation_num_beams + self._gen_kwargs = gen_kwargs + + # **** Original Code Start **** + self._memory_tracker.start() + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + eval_dataloader, + description="Evaluation", + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + # # **** Original Code End **** + # @ compute perplexity + if 'eval_loss' in output.metrics.keys(): + mean_loss = output.metrics.get('eval_loss') + perplexity = math.exp(mean_loss) + output.metrics['eval_perplexity'] = perplexity + + # logger.info(output.metrics) + self.log(output.metrics) + # 加入到 state.log_history 控制台里面 + + # if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + # xm.master_print(met.metrics_report()) + # @ on_evaluate + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + + def predict( + self, + test_dataset: Dataset, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "test", + **gen_kwargs, + ) -> "PredictionOutput": + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. 
+ + + + If your predictions or labels have different sequence lengths (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + + + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + + gen_kwargs = gen_kwargs.copy() + + # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the + # training args + if ( + gen_kwargs.get("max_length") is None + and gen_kwargs.get("max_new_tokens") is None + and self.args.generation_max_length is not None + ): + gen_kwargs["max_length"] = self.args.generation_max_length + if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: + gen_kwargs["num_beams"] = self.args.generation_num_beams + self._gen_kwargs = gen_kwargs + + return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + **gen_kwargs, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + # ADD + has_labels = "labels" in inputs + inputs = self._prepare_inputs(inputs) + + # Priority (handled in generate): + # non-`None` gen_kwargs > model.generation_config > default GenerationConfig() + if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"): + gen_kwargs = self._gen_kwargs.copy() + if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None: + gen_kwargs.pop("num_beams") + if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None: + gen_kwargs.pop("max_length") + + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs["synced_gpus"] = ( + gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus + ) + + generation_inputs = inputs.copy() + # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate + # (otherwise, it would continue generating from the padded `decoder_input_ids`) + if ( + "labels" in generation_inputs + and "decoder_input_ids" in generation_inputs + and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape + ): + generation_inputs = { + k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask") + } + # @ 1 generated_tokens + print(gen_kwargs) + generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs) + + # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop + # TODO: remove this hack when the legacy code that initializes generation_config from a model config is + # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183 + if self.model.generation_config._from_model_config: + 
self.model.generation_config._from_model_config = False + + # Retrieves GenerationConfig from model.generation_config + gen_config = self.model.generation_config + # in case the batch is shorter than max length, the output should be padded + if generated_tokens.shape[-1] < gen_config.max_length: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) + elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) + + # @ 2 outputs loss + with torch.no_grad(): + if has_labels: + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if self.label_smoother is not None: + loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() + else: + loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() + else: + loss = None + + if self.args.prediction_loss_only: + return loss, None, None + + if has_labels: + labels = inputs["labels"] + if labels.shape[-1] < gen_config.max_length: + labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) + elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: + labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) + else: + labels = None + + return loss, generated_tokens, labels + + def _pad_tensors_to_max_len(self, tensor, max_length): + if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = ( + self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + ) + else: + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") + + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, : tensor.shape[-1]] = tensor + return padded_tensor
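+
+
+# Minimal usage sketch for SFTTrainer outside of Ray (model id, paths and
+# hyperparameters below are illustrative only, not the values used by train.py):
+#
+#     from transformers import (AutoModelForCausalLM, AutoTokenizer,
+#                               DataCollatorForSeq2Seq, Seq2SeqTrainingArguments)
+#
+#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+#     args = Seq2SeqTrainingArguments(output_dir="/tmp/sft", per_device_train_batch_size=1)
+#     trainer = SFTTrainer(
+#         model=model,
+#         args=args,
+#         tokenizer=tokenizer,
+#         data_collator=DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=IGNORE_INDEX),
+#         train_dataset=tokenized_train,  # datasets.Dataset with input_ids/attention_mask/labels
+#     )
+#     trainer.train()
+#     metrics = trainer.evaluate(tokenized_eval)  # also reports eval_perplexity = exp(eval_loss)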