Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to set shared memory (shm) size for Task API #2132

Merged
merged 4 commits into from
Apr 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/workloads/task/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
conda: <string> # relative path to conda-packages.txt (default: conda-packages.txt)
shell: <string> # relative path to a shell script for system package installation (default: dependencies.sh)
python_path: <string> # path to the root of your Python folder that will be appended to PYTHONPATH (default: folder containing cortex.yaml)
shm_size: <string> # size of shared memory (/dev/shm) for sharing data between multiple processes, e.g. 64Mi or 1Gi (default: Null)
image: <string> # docker image to use for the Task (default: quay.io/cortexlabs/python-handler-cpu:master, quay.io/cortexlabs/python-handler-gpu:master-cuda10.2-cudnn8, or quay.io/cortexlabs/python-handler-inf:master based on compute)
env: <string: string> # dictionary of environment variables
log_level: <string> # log level that can be "debug", "info", "warning" or "error" (default: "info")
Expand Down
16 changes: 16 additions & 0 deletions pkg/operator/operator/k8s.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,22 @@ func TaskContainers(api *spec.API) ([]kcore.Container, []kcore.Volume) {
}
}

if api.TaskDefinition.ShmSize != nil {
volumes = append(volumes, kcore.Volume{
Name: "dshm",
VolumeSource: kcore.VolumeSource{
EmptyDir: &kcore.EmptyDirVolumeSource{
Medium: kcore.StorageMediumMemory,
SizeLimit: k8s.QuantityPtr(api.TaskDefinition.ShmSize.Quantity),
},
},
})
apiPodVolumeMounts = append(apiPodVolumeMounts, kcore.VolumeMount{
Name: "dshm",
MountPath: "/dev/shm",
})
}

containers = append(containers, kcore.Container{
Name: APIContainerName,
Image: api.TaskDefinition.Image,
Expand Down
4 changes: 2 additions & 2 deletions pkg/types/spec/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,10 @@ func ErrorSurgeAndUnavailableBothZero() error {
})
}

func ErrorShmSizeCannotExceedMem(shmSize k8s.Quantity, mem k8s.Quantity) error {
func ErrorShmSizeCannotExceedMem(parentFieldName string, shmSize k8s.Quantity, mem k8s.Quantity) error {
return errors.WithStack(&errors.Error{
Kind: ErrShmSizeCannotExceedMem,
Message: fmt.Sprintf("handler.shm_size (%s) cannot exceed compute.mem (%s)", shmSize.UserString, mem.UserString),
Message: fmt.Sprintf("%s.shm_size (%s) cannot exceed compute.mem (%s)", parentFieldName, shmSize.UserString, mem.UserString),
})
}

Expand Down
16 changes: 15 additions & 1 deletion pkg/types/spec/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ func taskDefinitionValidation() *cr.StructFieldValidation {
DockerImageOrEmpty: true,
},
},
{
StructField: "ShmSize",
StringPtrValidation: &cr.StringPtrValidation{
Default: nil,
AllowExplicitNull: true,
},
Parser: k8s.QuantityParser(&k8s.QuantityValidation{}),
},
{
StructField: "LogLevel",
StringValidation: &cr.StringValidation{
Expand Down Expand Up @@ -803,7 +811,13 @@ func ValidateAPI(

if api.Handler != nil && api.Handler.ShmSize != nil && api.Compute.Mem != nil {
if api.Handler.ShmSize.Cmp(api.Compute.Mem.Quantity) > 0 {
return ErrorShmSizeCannotExceedMem(*api.Handler.ShmSize, *api.Compute.Mem)
return ErrorShmSizeCannotExceedMem(userconfig.HandlerKey, *api.Handler.ShmSize, *api.Compute.Mem)
}
}

if api.TaskDefinition != nil && api.TaskDefinition.ShmSize != nil && api.Compute.Mem != nil {
if api.TaskDefinition.ShmSize.Cmp(api.Compute.Mem.Quantity) > 0 {
return ErrorShmSizeCannotExceedMem(userconfig.TaskDefinitionKey, *api.TaskDefinition.ShmSize, *api.Compute.Mem)
}
}

Expand Down
6 changes: 5 additions & 1 deletion pkg/types/userconfig/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type TaskDefinition struct {
Path string `json:"path" yaml:"path"`
PythonPath *string `json:"python_path" yaml:"python_path"`
Image string `json:"image" yaml:"image"`
ShmSize *k8s.Quantity `json:"shm_size" yaml:"shm_size"`
LogLevel LogLevel `json:"log_level" yaml:"log_level"`
Config map[string]interface{} `json:"config" yaml:"config"`
Env map[string]string `json:"env" yaml:"env"`
Expand Down Expand Up @@ -397,6 +398,9 @@ func (task *TaskDefinition) UserStr() string {
sb.WriteString(fmt.Sprintf("%s: %s\n", PythonPathKey, *task.PythonPath))
}
sb.WriteString(fmt.Sprintf("%s: %s\n", ImageKey, task.Image))
if task.ShmSize != nil {
sb.WriteString(fmt.Sprintf("%s: %s\n", ShmSizeKey, task.ShmSize.String()))
}
sb.WriteString(fmt.Sprintf("%s: %s\n", LogLevelKey, task.LogLevel))
if len(task.Config) > 0 {
sb.WriteString(fmt.Sprintf("%s:\n", ConfigKey))
Expand Down Expand Up @@ -442,7 +446,7 @@ func (handler *Handler) UserStr() string {
sb.WriteString(fmt.Sprintf("%s: %s\n", ThreadsPerProcessKey, s.Int32(handler.ThreadsPerProcess)))

if handler.ShmSize != nil {
sb.WriteString(fmt.Sprintf("%s: %s\n", ShmSize, handler.ShmSize.UserString))
sb.WriteString(fmt.Sprintf("%s: %s\n", ShmSizeKey, handler.ShmSize.UserString))
}

if len(handler.Config) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/userconfig/config_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const (
TensorFlowServingImageKey = "tensorflow_serving_image"
ProcessesPerReplicaKey = "processes_per_replica"
ThreadsPerProcessKey = "threads_per_process"
ShmSize = "shm_size"
ShmSizeKey = "shm_size"
LogLevelKey = "log_level"
ConfigKey = "config"
EnvKey = "env"
Expand Down