diff --git a/pkg/common/v1alpha2/common.go b/pkg/common/v1alpha2/common.go
index 4671c4af30f..44a909bf98b 100644
--- a/pkg/common/v1alpha2/common.go
+++ b/pkg/common/v1alpha2/common.go
@@ -24,3 +24,24 @@ func GetSupportedJobList() []schema.GroupVersionKind {
 	}
 	return supportedJobList
 }
+
+// GetJobLabelMap returns the labels keyed for the given jobKind, with
+// trialName as the job-name value. "TFJob" and "PyTorchJob" get their
+// kind-specific name label plus a "master" role label; any other jobKind
+// gets only the generic "job-name" label.
+func GetJobLabelMap(jobKind string, trialName string) map[string]string {
+	switch jobKind {
+	case "TFJob":
+		return map[string]string{
+			"tf-job-name": trialName,
+			"tf-job-role": "master",
+		}
+	case "PyTorchJob":
+		return map[string]string{
+			"pytorch-job-name": trialName,
+			"pytorch-job-role": "master",
+		}
+	default:
+		return map[string]string{"job-name": trialName}
+	}
+}
diff --git a/pkg/util/v1alpha2/metricscollector/metricscollector.go b/pkg/util/v1alpha2/metricscollector/metricscollector.go
index ec67914c6b1..7ed4119ebb4 100644
--- a/pkg/util/v1alpha2/metricscollector/metricscollector.go
+++ b/pkg/util/v1alpha2/metricscollector/metricscollector.go
@@ -14,6 +14,7 @@ import (
 	"sigs.k8s.io/controller-runtime/pkg/client/config"
 
 	v1alpha2 "github.com/kubeflow/katib/pkg/api/v1alpha2"
+	commonv1alpha2 "github.com/kubeflow/katib/pkg/common/v1alpha2"
 )
 
 type MetricsCollector struct {
@@ -36,11 +37,7 @@ func NewMetricsCollector() (*MetricsCollector, error) {
 }
 
 func (d *MetricsCollector) CollectObservationLog(tId string, jobKind string, metrics []string, namespace string) (*v1alpha2.ObservationLog, error) {
-	labelMap := make(map[string]string)
-
-	// TODO: Add labels for TFJob and PytorchJob
-	labelMap["job-name"] = tId
-
+	labelMap := commonv1alpha2.GetJobLabelMap(jobKind, tId)
 	pl, err := d.clientset.CoreV1().Pods(namespace).List(metav1.ListOptions{LabelSelector: labels.Set(labelMap).String(), IncludeUninitialized: true})
 	if err != nil {
 		return nil, err