diff --git a/ray-operator/controllers/ray/raycluster_controller.go b/ray-operator/controllers/ray/raycluster_controller.go index 047768a6e0..a8a4be5313 100644 --- a/ray-operator/controllers/ray/raycluster_controller.go +++ b/ray-operator/controllers/ray/raycluster_controller.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "reflect" + "regexp" "runtime" "strconv" "strings" @@ -1619,8 +1620,12 @@ func (r *RayClusterReconciler) updateRayClusterStatus(ctx context.Context, origi func sumGPUs(resources map[corev1.ResourceName]resource.Quantity) resource.Quantity { totalGPUs := resource.Quantity{} + // Define a regular expression to match valid GPU resource names + gpuPattern := regexp.MustCompile(`^(.*?/)?(gpu(-count)?|gpu-core\.percentage)$`) + for key, val := range resources { - if strings.HasSuffix(string(key), "gpu") && !val.IsZero() { + // Check if the resource name matches the GPU pattern and if the quantity is non-zero + if gpuPattern.MatchString(string(key)) && !val.IsZero() { totalGPUs.Add(val) } } diff --git a/ray-operator/controllers/ray/raycluster_controller_unit_test.go b/ray-operator/controllers/ray/raycluster_controller_unit_test.go index 4897f48290..9e0e0a7b06 100644 --- a/ray-operator/controllers/ray/raycluster_controller_unit_test.go +++ b/ray-operator/controllers/ray/raycluster_controller_unit_test.go @@ -2933,6 +2933,8 @@ func TestReconcile_NumOfHosts(t *testing.T) { func TestSumGPUs(t *testing.T) { nvidiaGPUResourceName := corev1.ResourceName("nvidia.com/gpu") googleTPUResourceName := corev1.ResourceName("google.com/tpu") + volcanoGPUCorePercentageResourceName := corev1.ResourceName("volcano.sh/gpu-core.percentage") + aliyunGPUCountResourceName := corev1.ResourceName("aliyun.com/gpu-count") tests := map[string]struct { input map[corev1.ResourceName]resource.Quantity @@ -2961,6 +2963,18 @@ func TestSumGPUs(t *testing.T) { }, resource.MustParse("5"), }, + "volcano GPU type specified": { + map[corev1.ResourceName]resource.Quantity{ + volcanoGPUCorePercentageResourceName: resource.MustParse("5"), + }, + resource.MustParse("5"), + }, + "aliyun GPU type specified": { + map[corev1.ResourceName]resource.Quantity{ + aliyunGPUCountResourceName: resource.MustParse("5"), + }, + resource.MustParse("5"), + }, } for name, tc := range tests {