Skip to content

Commit

Permalink
[api] Allows to configure custom batchifer (#3559)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Dec 17, 2024
1 parent 77d2ded commit c0d3a37
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
8 changes: 7 additions & 1 deletion api/src/main/java/ai/djl/translate/Batchifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.translate;

import ai.djl.ndarray.NDList;
import ai.djl.util.ClassLoaderUtils;

import java.io.Serializable;
import java.util.Arrays;
Expand Down Expand Up @@ -47,7 +48,12 @@ static Batchifier fromString(String name) {
case "none":
return null;
default:
throw new IllegalArgumentException("Invalid batchifier name");
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
Batchifier b = ClassLoaderUtils.initClass(cl, Batchifier.class, name);
if (b == null) {
throw new IllegalArgumentException("Invalid batchifier name: " + name);
}
return b;
}
}

Expand Down
11 changes: 11 additions & 0 deletions api/src/test/java/ai/djl/translate/BatchifierTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ public void testBatchifier() {
Batchifier batchifier = Batchifier.fromString("stack");
Assert.assertEquals(batchifier, Batchifier.STACK);

batchifier = Batchifier.fromString("none");
Assert.assertNull(batchifier);

batchifier = Batchifier.fromString("padding");
Assert.assertNotNull(batchifier);
Assert.assertEquals(batchifier.getClass(), SimplePaddingStackBatchifier.class);

batchifier = Batchifier.fromString("ai.djl.translate.SimplePaddingStackBatchifier");
Assert.assertNotNull(batchifier);
Assert.assertEquals(batchifier.getClass(), SimplePaddingStackBatchifier.class);

Assert.assertThrows(() -> Batchifier.fromString("invalid"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public void testBulkEqualsNonBulk() throws IOException, TranslateException {
.optUsage(Dataset.Usage.TEST)
.optRepository(repository)
.setSampling(32, false)
.optLabelBatchifier(new StackBatchifier() {})
.optLabelBatchifier(new StackBatchifier())
.build();

try (Trainer trainer = model.newTrainer(config)) {
Expand Down

0 comments on commit c0d3a37

Please sign in to comment.