diff --git a/docs/workloads/task/configuration.md b/docs/workloads/task/configuration.md index ca53d98db5..84c31312c6 100644 --- a/docs/workloads/task/configuration.md +++ b/docs/workloads/task/configuration.md @@ -12,6 +12,7 @@ conda: <string> # relative path to conda-packages.txt (default: conda-packages.txt) shell: <string> # relative path to a shell script for system package installation (default: dependencies.sh) python_path: <string> # path to the root of your Python folder that will be appended to PYTHONPATH (default: folder containing cortex.yaml) + shm_size: <string> # size of shared memory (/dev/shm) for sharing data between multiple processes, e.g. 64Mi or 1Gi (default: Null) image: <string> # docker image to use for the Task (default: quay.io/cortexlabs/python-handler-cpu:master, quay.io/cortexlabs/python-handler-gpu:master-cuda10.2-cudnn8, or quay.io/cortexlabs/python-handler-inf:master based on compute) env: <string: string> # dictionary of environment variables log_level: <string> # log level that can be "debug", "info", "warning" or "error" (default: "info") diff --git a/pkg/operator/operator/k8s.go b/pkg/operator/operator/k8s.go index df00a1dbe4..b16702f321 100644 --- a/pkg/operator/operator/k8s.go +++ b/pkg/operator/operator/k8s.go @@ -277,6 +277,22 @@ func TaskContainers(api *spec.API) ([]kcore.Container, []kcore.Volume) { } } + if api.TaskDefinition.ShmSize != nil { + volumes = append(volumes, kcore.Volume{ + Name: "dshm", + VolumeSource: kcore.VolumeSource{ + EmptyDir: &kcore.EmptyDirVolumeSource{ + Medium: kcore.StorageMediumMemory, + SizeLimit: k8s.QuantityPtr(api.TaskDefinition.ShmSize.Quantity), + }, + }, + }) + apiPodVolumeMounts = append(apiPodVolumeMounts, kcore.VolumeMount{ + Name: "dshm", + MountPath: "/dev/shm", + }) + } + containers = append(containers, kcore.Container{ Name: APIContainerName, Image: api.TaskDefinition.Image, diff --git a/pkg/types/spec/errors.go b/pkg/types/spec/errors.go index 16a36b9cbe..082438210b 100644 --- 
a/pkg/types/spec/errors.go +++ b/pkg/types/spec/errors.go @@ -232,10 +232,10 @@ func ErrorSurgeAndUnavailableBothZero() error { }) } -func ErrorShmSizeCannotExceedMem(shmSize k8s.Quantity, mem k8s.Quantity) error { +func ErrorShmSizeCannotExceedMem(parentFieldName string, shmSize k8s.Quantity, mem k8s.Quantity) error { return errors.WithStack(&errors.Error{ Kind: ErrShmSizeCannotExceedMem, - Message: fmt.Sprintf("handler.shm_size (%s) cannot exceed compute.mem (%s)", shmSize.UserString, mem.UserString), + Message: fmt.Sprintf("%s.shm_size (%s) cannot exceed compute.mem (%s)", parentFieldName, shmSize.UserString, mem.UserString), }) } diff --git a/pkg/types/spec/validations.go b/pkg/types/spec/validations.go index 2c3b42b02a..000313aede 100644 --- a/pkg/types/spec/validations.go +++ b/pkg/types/spec/validations.go @@ -310,6 +310,14 @@ func taskDefinitionValidation() *cr.StructFieldValidation { DockerImageOrEmpty: true, }, }, + { + StructField: "ShmSize", + StringPtrValidation: &cr.StringPtrValidation{ + Default: nil, + AllowExplicitNull: true, + }, + Parser: k8s.QuantityParser(&k8s.QuantityValidation{}), + }, { StructField: "LogLevel", StringValidation: &cr.StringValidation{ @@ -803,7 +811,13 @@ func ValidateAPI( if api.Handler != nil && api.Handler.ShmSize != nil && api.Compute.Mem != nil { if api.Handler.ShmSize.Cmp(api.Compute.Mem.Quantity) > 0 { - return ErrorShmSizeCannotExceedMem(*api.Handler.ShmSize, *api.Compute.Mem) + return ErrorShmSizeCannotExceedMem(userconfig.HandlerKey, *api.Handler.ShmSize, *api.Compute.Mem) + } + } + + if api.TaskDefinition != nil && api.TaskDefinition.ShmSize != nil && api.Compute.Mem != nil { + if api.TaskDefinition.ShmSize.Cmp(api.Compute.Mem.Quantity) > 0 { + return ErrorShmSizeCannotExceedMem(userconfig.TaskDefinitionKey, *api.TaskDefinition.ShmSize, *api.Compute.Mem) } } diff --git a/pkg/types/userconfig/api.go b/pkg/types/userconfig/api.go index 152c592a72..a38e09de1b 100644 --- a/pkg/types/userconfig/api.go +++ 
b/pkg/types/userconfig/api.go @@ -69,6 +69,7 @@ type TaskDefinition struct { Path string `json:"path" yaml:"path"` PythonPath *string `json:"python_path" yaml:"python_path"` Image string `json:"image" yaml:"image"` + ShmSize *k8s.Quantity `json:"shm_size" yaml:"shm_size"` LogLevel LogLevel `json:"log_level" yaml:"log_level"` Config map[string]interface{} `json:"config" yaml:"config"` Env map[string]string `json:"env" yaml:"env"` @@ -397,6 +398,9 @@ func (task *TaskDefinition) UserStr() string { sb.WriteString(fmt.Sprintf("%s: %s\n", PythonPathKey, *task.PythonPath)) } sb.WriteString(fmt.Sprintf("%s: %s\n", ImageKey, task.Image)) + if task.ShmSize != nil { + sb.WriteString(fmt.Sprintf("%s: %s\n", ShmSizeKey, task.ShmSize.UserString)) + } sb.WriteString(fmt.Sprintf("%s: %s\n", LogLevelKey, task.LogLevel)) if len(task.Config) > 0 { sb.WriteString(fmt.Sprintf("%s:\n", ConfigKey)) @@ -442,7 +446,7 @@ func (handler *Handler) UserStr() string { sb.WriteString(fmt.Sprintf("%s: %s\n", ThreadsPerProcessKey, s.Int32(handler.ThreadsPerProcess))) if handler.ShmSize != nil { - sb.WriteString(fmt.Sprintf("%s: %s\n", ShmSize, handler.ShmSize.UserString)) + sb.WriteString(fmt.Sprintf("%s: %s\n", ShmSizeKey, handler.ShmSize.UserString)) } if len(handler.Config) > 0 { diff --git a/pkg/types/userconfig/config_key.go b/pkg/types/userconfig/config_key.go index 10b6fbf7b2..70113ae0b7 100644 --- a/pkg/types/userconfig/config_key.go +++ b/pkg/types/userconfig/config_key.go @@ -42,7 +42,7 @@ const ( TensorFlowServingImageKey = "tensorflow_serving_image" ProcessesPerReplicaKey = "processes_per_replica" ThreadsPerProcessKey = "threads_per_process" - ShmSize = "shm_size" + ShmSizeKey = "shm_size" LogLevelKey = "log_level" ConfigKey = "config" EnvKey = "env"