Skip to content

Commit

Permalink
[batchapi] add list batch
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Apr 18, 2024
1 parent 6c5b46b commit 5baeb53
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/main/java/io/github/stefanbratanov/jvm/openai/BatchClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -71,4 +73,28 @@ public Batch cancelBatch(String batchId) {
HttpResponse<byte[]> httpResponse = sendHttpRequest(httpRequest);
return deserializeResponse(httpResponse.body(), Batch.class);
}

/**
* List your organization's batches
*
* @param after A cursor for use in pagination. after is an object ID that defines your place in
* the list.
* @param limit A limit on the number of objects to be returned.
* @throws OpenAIException in case of API errors
*/
public PaginatedBatches listBatches(Optional<String> after, Optional<String> limit) {
String queryParameters =
createQueryParameters(
Map.of(Constants.LIMIT_QUERY_PARAMETER, limit, Constants.AFTER_QUERY_PARAMETER, after));
HttpRequest httpRequest =
newHttpRequestBuilder()
.uri(baseUrl.resolve(Endpoint.BATCHES.getPath() + queryParameters))
.GET()
.build();
HttpResponse<byte[]> httpResponse = sendHttpRequest(httpRequest);
return deserializeResponse(httpResponse.body(), PaginatedBatches.class);
}

public record PaginatedBatches(
List<Batch> data, String firstId, String lastId, boolean hasMore) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,16 @@ void testBatchClient() {
assertThat(batch.inputFileId()).isEqualTo(inputFile.id());
assertThat(batch.errors()).isNull();

BatchClient.PaginatedBatches paginatedBatches =
batchClient.listBatches(Optional.empty(), Optional.empty());

assertThat(paginatedBatches.data()).isNotEmpty();
assertThat(paginatedBatches.firstId()).isNotNull();
assertThat(paginatedBatches.lastId()).isNotNull();
// assert that the batch we just created is listed
assertThat(paginatedBatches.data())
.anySatisfy(listedBatch -> assertThat(listedBatch.id()).isEqualTo(batch.id()));

// immediately cancel the batch, because can't wait for batch to finish in tests
Batch cancelledBatch = batchClient.cancelBatch(batch.id());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.RepeatedTest;

class OpenApiSpecificationValidationTest {
Expand Down Expand Up @@ -145,6 +146,16 @@ void validateBatch() {
request,
response,
"Instance type (integer) does not match any allowed primitive type (allowed: [\"string\"]");

BatchClient.PaginatedBatches paginatedBatches = testDataUtil.randomPaginatedBatches();

Response listBatchesResponse = createResponseWithBody(serializeObject(paginatedBatches));

validate(
"/" + Endpoint.BATCHES.getPath(),
Method.GET,
listBatchesResponse,
"Instance type (integer) does not match any allowed primitive type (allowed: [\"string\"]");
}

@RepeatedTest(50)
Expand Down Expand Up @@ -200,6 +211,7 @@ void validateModerations() {
validate(request, response);
}

@Disabled("V1 is legacy")
@RepeatedTest(50)
void validateAssistants() {
CreateAssistantRequest createAssistantRequest = testDataUtil.randomCreateAssistantRequest();
Expand Down Expand Up @@ -227,6 +239,7 @@ void validateAssistants() {
validate(request, response);
}

@Disabled("V1 is legacy")
@RepeatedTest(50)
void validateThreads() {
CreateThreadRequest createThreadRequest = testDataUtil.randomCreateThreadRequest();
Expand All @@ -252,6 +265,7 @@ void validateThreads() {
validate(request, response);
}

@Disabled("V1 is legacy")
@RepeatedTest(50)
void validateMessages() {
CreateMessageRequest createMessageRequest = testDataUtil.randomCreateMessageRequest();
Expand All @@ -278,6 +292,7 @@ void validateMessages() {
response);
}

@Disabled("V1 is legacy")
@RepeatedTest(50)
void validateRuns() {
CreateRunRequest createRunRequest = testDataUtil.randomCreateRunRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ public CreateBatchRequest randomCreateBatchRequest() {
.build();
}

public BatchClient.PaginatedBatches randomPaginatedBatches() {
return new BatchClient.PaginatedBatches(
listOf(randomInt(1, 20), this::randomBatch),
randomString(5),
randomString(5),
randomBoolean());
}

public File randomFile() {
return new File(
randomString(15),
Expand Down

0 comments on commit 5baeb53

Please sign in to comment.