Skip to content

Commit

Permalink
If all devices are pre-volta, skip setting set_default_device pinned …
Browse files Browse the repository at this point in the history
…mem_limit and set_default_active_thread_percentage.
  • Loading branch information
bcc829 committed Sep 30, 2024
1 parent 71c1fa7 commit 09dcb77
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions cmd/mps-control-daemon/mps/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,21 @@ func (d *Daemon) Start() error {
return err
}

for index, limit := range d.perDevicePinnedDeviceMemoryLimits() {
_, err := d.EchoPipeToControl(fmt.Sprintf("set_default_device_pinned_mem_limit %s %s", index, limit))
if err != nil {
return fmt.Errorf("error setting pinned memory limit for device %v: %w", index, err)
if isAllDevicesPreVolta := d.isAllDevicesPreVolta(); !isAllDevicesPreVolta {
for index, limit := range d.perDevicePinnedDeviceMemoryLimits() {
_, err := d.EchoPipeToControl(fmt.Sprintf("set_default_device_pinned_mem_limit %s %s", index, limit))
if err != nil {
return fmt.Errorf("error setting pinned memory limit for device %v: %w", index, err)
}
}
}
if threadPercentage := d.activeThreadPercentage(); threadPercentage != "" {
_, err := d.EchoPipeToControl(fmt.Sprintf("set_default_active_thread_percentage %s", threadPercentage))
if err != nil {
return fmt.Errorf("error setting active thread percentage: %w", err)
if threadPercentage := d.activeThreadPercentage(); threadPercentage != "" {
_, err := d.EchoPipeToControl(fmt.Sprintf("set_default_active_thread_percentage %s", threadPercentage))
if err != nil {
return fmt.Errorf("error setting active thread percentage: %w", err)
}
}
}

statusFile, err := os.Create(d.startedFile())
if err != nil {
return err
Expand Down Expand Up @@ -278,3 +280,17 @@ func (m *Daemon) activeThreadPercentage() string {

return fmt.Sprintf("%d", 100/replicasPerDevice)
}

func (m *Daemon) isAllDevicesPreVolta() bool {
if len(m.Devices()) == 0 {
return false
}

for _, device := range m.Devices() {
if isVoltaDevice := (*mpsDevice)(device).isAtLeastVolta(); isVoltaDevice {
return false
}
}

return true
}

0 comments on commit 09dcb77

Please sign in to comment.