Skip to content

Commit

Permalink
Fix #1740: Handle OPTIONS in management API (#1774)
Browse files Browse the repository at this point in the history
* Fix #1740: Handle OPTIONS in management API

* Add unit tests

* Add unit test resource file

* Unit test changes

Co-authored-by: Aaqib <maaquib@gmail.com>
  • Loading branch information
xyang16 and maaquib authored Aug 15, 2022
1 parent 358f97e commit a1a0031
Show file tree
Hide file tree
Showing 5 changed files with 1,320 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.pytorch.serve.http.messages.ListModelsResponse;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.job.RestJob;
import org.pytorch.serve.openapi.OpenApiUtils;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.JsonUtils;
Expand Down Expand Up @@ -92,6 +93,15 @@ public void handleRequest(
}
} else if (HttpMethod.DELETE.equals(method)) {
handleUnregisterModel(ctx, segments[2], modelVersion);
} else if (HttpMethod.OPTIONS.equals(method)) {
ModelManager modelManager = ModelManager.getInstance();
Model model = modelManager.getModel(segments[2], modelVersion);
if (model == null) {
throw new ModelNotFoundException("Model not found: " + segments[2]);
}

String resp = OpenApiUtils.getModelManagementApi(model);
NettyUtils.sendJsonResponse(ctx, resp);
} else {
throw new MethodNotAllowedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ public static String getModelApi(Model model) {
return JsonUtils.GSON_PRETTY.toJson(openApi);
}

public static String getModelManagementApi(Model model) {
String modelName = model.getModelName();
OpenApi openApi = new OpenApi();
Info info = new Info();
info.setTitle("RESTful Management API for: " + modelName);
ConfigManager config = ConfigManager.getInstance();
info.setVersion(config.getProperty("version", null));
openApi.setInfo(info);

openApi.addPath("/models/{model_name}", getModelManagerPath(false));
openApi.addPath("/models/{model_name}/{model_version}", getModelManagerPath(true));
openApi.addPath("/models/{model_name}/{model_version}/set-default", getSetDefaultPath());

return JsonUtils.GSON_PRETTY.toJson(openApi);
}

private static Path getApiDescriptionPath(String operationID, boolean legacy) {
Schema schema = new Schema("object");
schema.addProperty("openapi", new Schema("string"), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class ModelServerTest {
private String listManagementApisResult;
private String listMetricsApisResult;
private String noopApiResult;
private String noopManagementApiResult;

static {
TestUtils.init();
Expand Down Expand Up @@ -99,6 +100,11 @@ public void beforeSuite()
try (InputStream is = new FileInputStream("src/test/resources/describe_api.json")) {
noopApiResult = IOUtils.toString(is, StandardCharsets.UTF_8.name());
}

try (InputStream is = new FileInputStream("src/test/resources/model_management_api.json")) {
noopManagementApiResult =
String.format(IOUtils.toString(is, StandardCharsets.UTF_8.name()), version);
}
}

@AfterClass
Expand Down Expand Up @@ -193,6 +199,21 @@ public void testDescribeApi() throws InterruptedException {
@Test(
alwaysRun = true,
dependsOnMethods = {"testDescribeApi"})
public void testModelManagementApi() throws InterruptedException {
Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
TestUtils.describeModelManagementApi(channel, "noop");
TestUtils.getLatch().await();

Assert.assertEquals(
TestUtils.getResult().replaceAll("(\\\\r|\r\n|\n|\n\r)", "\r"),
noopManagementApiResult.replaceAll("(\\\\r|\r\n|\n|\n\r)", "\r"));
}

@Test(
alwaysRun = true,
dependsOnMethods = {"testModelManagementApi"})
public void testInitialWorkers() throws InterruptedException {
Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
Expand Down
10 changes: 10 additions & 0 deletions frontend/server/src/test/java/org/pytorch/serve/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@ public final class TestUtils {
private TestUtils() {}

public static void init() {
System.out.println("init");
// set up system properties for local IDE debug
System.out.println("tsConfigFile1: " + System.getProperty("tsConfigFile"));
if (System.getProperty("tsConfigFile") == null) {
System.setProperty("tsConfigFile", "src/test/resources/config.properties");
}
System.out.println("tsConfigFile2: " + System.getProperty("tsConfigFile"));
if (System.getProperty("METRICS_LOCATION") == null) {
System.setProperty("METRICS_LOCATION", "build/logs");
}
Expand Down Expand Up @@ -260,6 +263,13 @@ public static void describeModelApi(Channel channel, String modelName) {
channel.writeAndFlush(req);
}

public static void describeModelManagementApi(Channel channel, String modelName) {
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/models/" + modelName);
channel.writeAndFlush(req);
}

public static void describeModel(
Channel channel, String modelName, String version, boolean customized) {
String requestURL = "/models/" + modelName;
Expand Down
Loading

0 comments on commit a1a0031

Please sign in to comment.