Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Device Name parsing #1490

Merged
merged 1 commit into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions api/src/main/java/ai/djl/Device.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code
Expand All @@ -33,6 +35,8 @@ public final class Device {
private static final Device CPU = new Device(Type.CPU, -1);
private static final Device GPU = Device.of(Type.GPU, 0);

private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");

private String deviceType;
private int deviceId;

Expand Down Expand Up @@ -63,6 +67,46 @@ public static Device of(String deviceType, int deviceId) {
return CACHE.computeIfAbsent(key, k -> new Device(deviceType, deviceId));
}

/**
* Parses a deviceName string into a device.
*
* <p>The main format of a device name string is "cpu", "gpu0", or "nc1". This is simply
* deviceType concatenated with the deviceId. If no deviceId is used, -1 will be assumed.
*
* <p>There are also several simplified formats. "-1", "", and null deviceNames correspond to
* cpu. Non-negative integer deviceNames such as "0", "1", or "2" correspond to gpus with those
* deviceIds.
*
* @param deviceName deviceName string
* @return the device
*/
public static Device fromName(String deviceName) {
if (deviceName == null || deviceName.isEmpty()) {
return Device.cpu();
}

Matcher matcher = DEVICE_NAME.matcher(deviceName);
if (matcher.matches()) {
String deviceType = matcher.group(1);
int deviceId = -1;
if (!matcher.group(2).isEmpty()) {
deviceId = Integer.parseInt(matcher.group(2));
}
return Device.of(deviceType, deviceId);
}

try {
int deviceId = Integer.parseInt(deviceName);

if (deviceId < 0) {
return Device.cpu();
}
return Device.gpu(deviceId);
} catch (NumberFormatException ignored) {
}
throw new IllegalArgumentException("Failed to parse device name: " + deviceName);
}

/**
* Returns the device type of the Device.
*
Expand Down
15 changes: 15 additions & 0 deletions api/src/test/java/ai/djl/DeviceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,19 @@ public void testDevice() {

Engine.debugEnvironment();
}

@Test
public void testDeviceName() {
Assert.assertEquals(Device.fromName("cpu"), Device.cpu());
Assert.assertEquals(Device.fromName(""), Device.cpu());
Assert.assertEquals(Device.fromName("-1"), Device.cpu());
Assert.assertEquals(Device.fromName(null), Device.cpu());

Assert.assertEquals(Device.fromName("gpu0"), Device.gpu());
Assert.assertEquals(Device.fromName("0"), Device.gpu());
Assert.assertEquals(Device.fromName("1"), Device.gpu(1));

Assert.assertEquals(Device.fromName("nc1"), Device.of("nc", 1));
Assert.assertEquals(Device.fromName("a999"), Device.of("a", 999));
}
}