OpenAI Functions #35748

Merged 26 commits on Jul 13, 2023

Commits
cdee6e0
Added preliminary support for OpenAI Functions
jpalvarezl Jul 5, 2023
cc124c2
WIP: serialization still fails
jpalvarezl Jul 7, 2023
f492635
Tests are passing, requests go through to nonAzure OAI
jpalvarezl Jul 8, 2023
b87df26
Moved gened classes with polymorphism to impl package
jpalvarezl Jul 8, 2023
0dfb11f
Added assertion for function call
jpalvarezl Jul 10, 2023
0573d73
Added test for usage of function not supplied in the request
jpalvarezl Jul 10, 2023
10c4696
Renamed test for clarity
jpalvarezl Jul 10, 2023
2bbb6ff
Addressed most of the PR comments
jpalvarezl Jul 10, 2023
c9f4280
Added sync version of the tests
jpalvarezl Jul 10, 2023
71e48bf
Renamed runner method
jpalvarezl Jul 10, 2023
a1340cb
Added Azure sync/async tests
jpalvarezl Jul 10, 2023
00d6762
Added docs for FunctionCall
jpalvarezl Jul 10, 2023
72101cb
Removed unused files from samples package
jpalvarezl Jul 10, 2023
abe5ff3
Preparing release notes
jpalvarezl Jul 10, 2023
c9a4fde
Renamed static member
jpalvarezl Jul 11, 2023
14a12af
Moved the exception handling one level up in the call stack
jpalvarezl Jul 11, 2023
aed004b
code regened
jpalvarezl Jul 11, 2023
a5eada6
Moved custom models under their own package
jpalvarezl Jul 12, 2023
2619f35
Updated test records
jpalvarezl Jul 12, 2023
19a9c2c
Addressed most of the style checks
jpalvarezl Jul 12, 2023
5334715
removed unused import
jpalvarezl Jul 12, 2023
671cf53
Merge branch 'main' into jpalvarezl/aoai_functions
jpalvarezl Jul 13, 2023
1411a63
Renamed type
jpalvarezl Jul 13, 2023
d827c17
Added spell checker exception for DALL-E
jpalvarezl Jul 13, 2023
4eca3e9
Updated commit hash and re-ran code gen
jpalvarezl Jul 13, 2023
3dff5f4
Renamed param
jpalvarezl Jul 13, 2023
1 change: 1 addition & 0 deletions .vscode/cspell.json
@@ -243,6 +243,7 @@
"creds",
"credscan",
"curr",
"DALL-E",
"databind",
"databricks",
"DAZURE",
4 changes: 4 additions & 0 deletions sdk/openai/azure-ai-openai/CHANGELOG.md
@@ -2,6 +2,10 @@

## 1.0.0-beta.3 (Unreleased)

- Added methods and models to support DALL-E
- Added methods and models to support Functions
- Added models supporting ResponsibleAI annotations

### Features Added

### Breaking Changes
2 changes: 1 addition & 1 deletion sdk/openai/azure-ai-openai/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "java",
"TagPrefix": "java/openai/azure-ai-openai",
"Tag": "java/openai/azure-ai-openai_2a6e71fe2e"
"Tag": "java/openai/azure-ai-openai_9fc7970110"
}
1 change: 1 addition & 0 deletions sdk/openai/azure-ai-openai/pom.xml
@@ -50,6 +50,7 @@
--add-exports com.azure.core/com.azure.core.implementation.util=ALL-UNNAMED
--add-opens com.azure.ai.openai/com.azure.ai.openai=ALL-UNNAMED
--add-opens com.azure.ai.openai/com.azure.ai.openai.implementation=com.fasterxml.jackson.databind
--add-opens com.azure.ai.openai/com.azure.ai.openai.functions=com.fasterxml.jackson.databind
</javaModulesSurefireArgLine>
<jacoco.skip>true</jacoco.skip>
</properties>
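The new `--add-opens` entry lets Jackson databind use reflection on the models in the new com.azure.ai.openai.functions package while the tests run on the module path. As a rough sketch (the SDK's actual module descriptor is not part of this diff, so the surrounding directives are assumptions), the surefire argument is the command-line equivalent of an `opens` directive:

// Hypothetical module-info.java sketch; only the package and Jackson module names come from the diff.
module com.azure.ai.openai {
    requires com.azure.core;

    // Same effect as the surefire --add-opens flag above, but baked into the module
    // descriptor: Jackson gets reflective access to the function test models.
    opens com.azure.ai.openai.functions to com.fasterxml.jackson.databind;
}

Keeping the flag in the surefire argLine instead means the package is only opened for the test JVM, not for consumers of the released jar.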
@@ -15,7 +15,10 @@ public enum OpenAIServiceVersion implements ServiceVersion {
V2023_05_15("2023-05-15"),

/** Enum value 2023-06-01-preview. */
V2023_06_01_PREVIEW("2023-06-01-preview");
V2023_06_01_PREVIEW("2023-06-01-preview"),

/** Enum value 2023-07-01-preview. */
V2023_07_01_PREVIEW("2023-07-01-preview");

private final String version;

@@ -35,6 +38,6 @@ public String getVersion() {
* @return The latest {@link OpenAIServiceVersion}.
*/
public static OpenAIServiceVersion getLatest() {
return V2023_06_01_PREVIEW;
return V2023_07_01_PREVIEW;
}
}
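As a minimal sketch of how the new enum value is consumed, the snippet below pins a client to 2023-07-01-preview. It assumes the usual Azure SDK builder setters (endpoint, credential, serviceVersion) on OpenAIClientBuilder; the endpoint and key values are placeholders.

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.OpenAIServiceVersion;
import com.azure.core.credential.AzureKeyCredential;

public final class ServiceVersionSketch {
    public static void main(String[] args) {
        // Pin the client to the newly added preview API version explicitly.
        OpenAIClient client = new OpenAIClientBuilder()
            .endpoint("https://<resource-name>.openai.azure.com")  // placeholder endpoint
            .credential(new AzureKeyCredential("<api-key>"))       // placeholder key
            .serviceVersion(OpenAIServiceVersion.V2023_07_01_PREVIEW)
            .buildClient();

        // getLatest() now resolves to the same version.
        System.out.println(OpenAIServiceVersion.getLatest().getVersion()); // 2023-07-01-preview
    }
}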
@@ -3,9 +3,6 @@

package com.azure.ai.openai.implementation;

import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.CompletionsOptions;
import com.azure.ai.openai.models.EmbeddingsOptions;
import com.azure.core.annotation.BodyParam;
import com.azure.core.annotation.ExpectedResponses;
import com.azure.core.annotation.HeaderParam;
@@ -28,8 +25,14 @@
import com.azure.core.util.Context;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.serializer.SerializerAdapter;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;

/**
* Implementation for calling Non-Azure OpenAI Service
*/
@@ -66,6 +69,11 @@ public SerializerAdapter getSerializerAdapter() {
*/
public static final String OPEN_AI_ENDPOINT = "https://api.openai.com/v1";

/**
* Mapper used to add the `modelId` into the request body for a non-Azure OpenAI request
*/
private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

/**
* Initializes an instance of OpenAIClient client.
*
@@ -289,20 +297,20 @@ public Mono<Response<BinaryData>> getEmbeddingsWithResponseAsync(String modelId,
BinaryData embeddingsOptions, RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData embeddingsOptionsUpdated = BinaryData.fromObject(
embeddingsOptions.toObject(EmbeddingsOptions.class)
.setModel(modelId)
);

return FluxUtil.withContext(
context ->
service.getEmbeddings(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
context));
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId);
return FluxUtil.withContext(
context ->
service.getEmbeddings(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
context));
} catch (JsonProcessingException e) {
return Mono.error(e);
}
}

/**
@@ -357,18 +365,18 @@ public Response<BinaryData> getEmbeddingsWithResponse(String modelId, BinaryData
RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData embeddingsOptionsUpdated = BinaryData.fromObject(
embeddingsOptions.toObject(EmbeddingsOptions.class)
.setModel(modelId)
);

return service.getEmbeddingsSync(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
Context.NONE);
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId);
return service.getEmbeddingsSync(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
Context.NONE);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

/**
@@ -458,20 +466,20 @@ public Mono<Response<BinaryData>> getCompletionsWithResponseAsync(String modelId
BinaryData completionsOptions, RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData completionsOptionsUpdated = BinaryData.fromObject(
completionsOptions.toObject(CompletionsOptions.class)
.setModel(modelId)
);

return FluxUtil.withContext(
context ->
service.getCompletions(
OPEN_AI_ENDPOINT,
accept,
completionsOptionsUpdated,
requestOptions,
context));
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId);
return FluxUtil.withContext(
context ->
service.getCompletions(
OPEN_AI_ENDPOINT,
accept,
completionsOptionsUpdated,
requestOptions,
context));
} catch (JsonProcessingException e) {
return Mono.error(e);
}
}

/**
@@ -559,11 +567,14 @@ public Response<BinaryData> getCompletionsWithResponse(String modelId, BinaryDat
RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData completionsOptionsUpdated = BinaryData.fromObject(
completionsOptions.toObject(CompletionsOptions.class)
.setModel(modelId)
);
// modelId is part of the request body in nonAzure OpenAI
BinaryData completionsOptionsUpdated = null;
try {
completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

return service.getCompletionsSync(
OPEN_AI_ENDPOINT,
accept,
@@ -650,20 +661,20 @@ public Mono<Response<BinaryData>> getChatCompletionsWithResponseAsync(String mod
BinaryData chatCompletionsOptions, RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData chatCompletionsOptionsUpdated = BinaryData.fromObject(
chatCompletionsOptions.toObject(ChatCompletionsOptions.class)
.setModel(modelId)
);

return FluxUtil.withContext(
context ->
service.getChatCompletions(
OPEN_AI_ENDPOINT,
accept,
chatCompletionsOptionsUpdated,
requestOptions,
context));
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId);
return FluxUtil.withContext(
context ->
service.getChatCompletions(
OPEN_AI_ENDPOINT,
accept,
chatCompletionsOptionsUpdated,
requestOptions,
context));
} catch (JsonProcessingException e) {
return Mono.error(e);
}
}

/**
@@ -743,11 +754,13 @@ public Response<BinaryData> getChatCompletionsWithResponse(String modelId, Binar
RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData chatCompletionsOptionsUpdated = BinaryData.fromObject(
chatCompletionsOptions.toObject(ChatCompletionsOptions.class)
.setModel(modelId)
);
// modelId is part of the request body in nonAzure OpenAI
BinaryData chatCompletionsOptionsUpdated = null;
try {
chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

return service.getChatCompletionsSync(
OPEN_AI_ENDPOINT,
@@ -870,4 +883,26 @@ public Response<BinaryData> generateImageWithResponse(
Context.NONE
);
}

/**
* This method injects the modelId into the request body for requests against non-Azure OpenAI. The non-Azure
* service expects the model ID in the body of the request, whereas Azure OpenAI passes it as part of the
* request path.
*
* @param inputJson JSON submitted by the client
* @param modelId The LLM model ID to be injected into the JSON
* @return The request body with the {@code model} field set to {@code modelId}
* @throws JsonProcessingException If {@code inputJson} cannot be parsed as JSON
*/
private static BinaryData addModelIdJson(BinaryData inputJson, String modelId) throws JsonProcessingException {
JsonNode jsonNode = JSON_MAPPER.readTree(inputJson.toString());
if (jsonNode instanceof ObjectNode) {
ObjectNode objectNode = (ObjectNode) jsonNode;
objectNode.put("model", modelId);
inputJson = BinaryData.fromBytes(
objectNode.toString()
.getBytes(StandardCharsets.UTF_8));
}

return inputJson;
}
}
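To make the behaviour of addModelIdJson concrete, here is a standalone sketch of the same transformation using Jackson directly. The request body and model name are illustrative; in the client itself a parse failure surfaces as Mono.error on the async paths and as an unchecked RuntimeException on the sync paths, as shown in the hunks above.

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

public final class AddModelIdSketch {
    private static final ObjectMapper MAPPER = new ObjectMapper();

    public static void main(String[] args) throws Exception {
        // Body as a caller would pass it to an embeddings call: no "model" field yet.
        String embeddingsOptions = "{\"input\":[\"Tell me about Azure\"]}";

        // Same steps as addModelIdJson: parse, set "model", re-serialize.
        JsonNode node = MAPPER.readTree(embeddingsOptions);
        if (node instanceof ObjectNode) {
            ((ObjectNode) node).put("model", "text-embedding-ada-002"); // placeholder model name
        }

        // Prints {"input":["Tell me about Azure"],"model":"text-embedding-ada-002"},
        // the shape the non-Azure endpoint expects in the request body.
        System.out.println(node);
    }
}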
@@ -548,7 +548,7 @@ public Response<BinaryData> getEmbeddingsWithResponse(
* int (Required)
* ]
* }
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* }
* ]
* usage (Required): {
@@ -646,7 +646,7 @@ public Mono<Response<BinaryData>> getCompletionsWithResponseAsync(
* int (Required)
* ]
* }
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* }
* ]
* usage (Required): {
@@ -693,10 +693,23 @@ public Response<BinaryData> getCompletionsWithResponse(
* {
* messages (Required): [
* (Required){
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* ]
* functions (Optional): [
* (Optional){
* name: String (Required)
* description: String (Optional)
* parameters: Object (Optional)
* }
* ]
* function_call: FunctionCallModelBase (Optional)
* max_tokens: Integer (Optional)
* temperature: Double (Optional)
* top_p: Double (Optional)
@@ -724,11 +737,16 @@ public Response<BinaryData> getCompletionsWithResponse(
* choices (Required): [
* (Required){
* message (Optional): {
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* index: int (Required)
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* delta (Optional): (recursive schema, see delta above)
* }
* ]
@@ -779,10 +797,23 @@ public Mono<Response<BinaryData>> getChatCompletionsWithResponseAsync(
* {
* messages (Required): [
* (Required){
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* ]
* functions (Optional): [
* (Optional){
* name: String (Required)
* description: String (Optional)
* parameters: Object (Optional)
* }
* ]
* function_call: FunctionCallModelBase (Optional)
* max_tokens: Integer (Optional)
* temperature: Double (Optional)
* top_p: Double (Optional)
@@ -810,11 +841,16 @@ public Mono<Response<BinaryData>> getChatCompletionsWithResponseAsync(
* choices (Required): [
* (Required){
* message (Optional): {
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* index: int (Required)
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* delta (Optional): (recursive schema, see delta above)
* }
* ]
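The Javadoc changes above add the functions array, the function role, the function_call message field, and the function_call finish reason to the documented chat completions schema. The sketch below exercises that schema through the protocol overload documented here, assuming it is exposed on the public client with the (deployment-or-model, BinaryData, RequestOptions) shape shown in these hunks; the GetWeather function, its parameters, and the deployment name are invented for illustration.

import com.azure.ai.openai.OpenAIClient;
import com.azure.core.http.rest.RequestOptions;
import com.azure.core.http.rest.Response;
import com.azure.core.util.BinaryData;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public final class ChatFunctionsSketch {
    public static void callWithFunctions(OpenAIClient client) throws Exception {
        // Request body following the documented schema: one user message plus one
        // function the model may choose to call.
        String body = "{"
            + "\"messages\":[{\"role\":\"user\",\"content\":\"What is the weather in Seattle?\"}],"
            + "\"functions\":[{"
            +     "\"name\":\"GetWeather\","
            +     "\"description\":\"Returns the current weather for a city\","
            +     "\"parameters\":{\"type\":\"object\","
            +         "\"properties\":{\"city\":{\"type\":\"string\"}},"
            +         "\"required\":[\"city\"]}"
            + "}]}";

        Response<BinaryData> response = client.getChatCompletionsWithResponse(
            "<deployment-or-model-name>",      // placeholder
            BinaryData.fromString(body),
            new RequestOptions());

        // Per the response schema above, a finish_reason of "function_call" means the
        // assistant message carries a function call to execute instead of plain content.
        JsonNode choice = new ObjectMapper()
            .readTree(response.getValue().toString())
            .get("choices").get(0);
        if ("function_call".equals(choice.get("finish_reason").asText())) {
            JsonNode functionCall = choice.get("message").get("function_call");
            System.out.println(functionCall.get("name").asText()
                + " -> " + functionCall.get("arguments").asText());
        }
    }
}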