Skip to content

Commit

Permalink
[OpenAI] Add sanitizer for multipart header and record without reques…
Browse files Browse the repository at this point in the history
…t body for multipart/form-data (Azure#36987)
  • Loading branch information
mssfang authored Sep 28, 2023
1 parent 78ec09d commit e17f1d9
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 10 deletions.
2 changes: 1 addition & 1 deletion sdk/openai/azure-ai-openai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "java",
"TagPrefix": "java/openai/azure-ai-openai",
"Tag": "java/openai/azure-ai-openai_2290060af1"
"Tag": "java/openai/azure-ai-openai_a191d5c4de"
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

/**
* Helper class for marshaling {@link AudioTranscriptionOptions} and {@link AudioTranslationOptions} objects to be used
Expand Down Expand Up @@ -53,9 +54,7 @@ public class MultipartDataHelper {
* Default constructor used in the code. The boundary is a random value.
*/
public MultipartDataHelper() {
// TODO: We can't use randomly generated UUIDs for now. Generating a test session record won't match the
// newly generated UUID for the test run instance this(UUID.randomUUID().toString().substring(0, 16));
this("29580623-3d02-4a");
this(UUID.randomUUID().toString().substring(0, 16));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.azure.core.exception.HttpResponseException;
import com.azure.core.http.HttpClient;
import com.azure.core.http.rest.RequestOptions;
import com.azure.core.test.annotation.RecordWithoutRequestBody;
import com.azure.core.util.BinaryData;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
Expand Down Expand Up @@ -331,6 +332,7 @@ public void testCompletionContentFiltering(HttpClient httpClient, OpenAIServiceV

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -348,6 +350,7 @@ public void testGetAudioTranscriptionJson(HttpClient httpClient, OpenAIServiceVe

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionVerboseJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -365,6 +368,7 @@ public void testGetAudioTranscriptionVerboseJson(HttpClient httpClient, OpenAISe

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionTextPlain(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -383,6 +387,7 @@ public void testGetAudioTranscriptionTextPlain(HttpClient httpClient, OpenAIServ

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionSrt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -405,6 +410,7 @@ public void testGetAudioTranscriptionSrt(HttpClient httpClient, OpenAIServiceVer

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionVtt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand Down Expand Up @@ -470,6 +476,7 @@ public void testGetAudioTranscriptionJsonWrongFormats(HttpClient httpClient, Ope

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -487,6 +494,7 @@ public void testGetAudioTranslationJson(HttpClient httpClient, OpenAIServiceVers

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationVerboseJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -504,6 +512,7 @@ public void testGetAudioTranslationVerboseJson(HttpClient httpClient, OpenAIServ

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationTextPlain(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -521,6 +530,7 @@ public void testGetAudioTranslationTextPlain(HttpClient httpClient, OpenAIServic

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationSrt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand All @@ -543,6 +553,7 @@ public void testGetAudioTranslationSrt(HttpClient httpClient, OpenAIServiceVersi

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationVtt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAIAsyncClient(httpClient);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.azure.core.http.HttpClient;
import com.azure.core.http.rest.RequestOptions;
import com.azure.core.http.rest.Response;
import com.azure.core.test.annotation.RecordWithoutRequestBody;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -285,6 +286,7 @@ public void testCompletionContentFiltering(HttpClient httpClient, OpenAIServiceV

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -300,6 +302,7 @@ public void testGetAudioTranscriptionJson(HttpClient httpClient, OpenAIServiceVe

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionVerboseJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -315,6 +318,7 @@ public void testGetAudioTranscriptionVerboseJson(HttpClient httpClient, OpenAISe

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionTextPlain(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -331,6 +335,7 @@ public void testGetAudioTranscriptionTextPlain(HttpClient httpClient, OpenAIServ

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionSrt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -351,6 +356,7 @@ public void testGetAudioTranscriptionSrt(HttpClient httpClient, OpenAIServiceVer

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionVtt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand Down Expand Up @@ -416,6 +422,7 @@ public void testGetAudioTranscriptionJsonWrongFormats(HttpClient httpClient, Ope

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -431,6 +438,7 @@ public void testGetAudioTranslationJson(HttpClient httpClient, OpenAIServiceVers

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationVerboseJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -446,6 +454,7 @@ public void testGetAudioTranslationVerboseJson(HttpClient httpClient, OpenAIServ

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationTextPlain(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -461,6 +470,7 @@ public void testGetAudioTranslationTextPlain(HttpClient httpClient, OpenAIServic

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationSrt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand All @@ -481,6 +491,7 @@ public void testGetAudioTranslationSrt(HttpClient httpClient, OpenAIServiceVersi

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationVtt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getNonAzureOpenAISyncClient(httpClient);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.azure.core.exception.ResourceNotFoundException;
import com.azure.core.http.HttpClient;
import com.azure.core.http.rest.RequestOptions;
import com.azure.core.test.annotation.RecordWithoutRequestBody;
import com.azure.core.util.BinaryData;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
Expand Down Expand Up @@ -409,9 +410,6 @@ public void testCompletionStreamContentFiltering(HttpClient httpClient, OpenAISe
// The last stream message is empty with all the filters set to null
assertEquals(1, completions.getChoices().size());
Choice choice = completions.getChoices().get(0);
// TODO: service issue: we could have "length" as the finish reason.
// Non-Streaming happens less frequency than streaming API.
// https://github.com/Azure/azure-sdk-for-java/issues/36894
assertEquals(CompletionsFinishReason.fromString("stop"), choice.getFinishReason());
assertNotNull(choice.getText());

Expand Down Expand Up @@ -484,6 +482,7 @@ public void testChatCompletionsStreamingBasicSearchExtension(HttpClient httpClie

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -501,6 +500,7 @@ public void testGetAudioTranscriptionJson(HttpClient httpClient, OpenAIServiceVe

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionVerboseJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -518,6 +518,7 @@ public void testGetAudioTranscriptionVerboseJson(HttpClient httpClient, OpenAISe

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionTextPlain(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -536,6 +537,7 @@ public void testGetAudioTranscriptionTextPlain(HttpClient httpClient, OpenAIServ

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionSrt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -558,6 +560,7 @@ public void testGetAudioTranscriptionSrt(HttpClient httpClient, OpenAIServiceVer

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranscriptionVtt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand Down Expand Up @@ -623,6 +626,7 @@ public void testGetAudioTranscriptionJsonWrongFormats(HttpClient httpClient, Ope

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -640,6 +644,7 @@ public void testGetAudioTranslationJson(HttpClient httpClient, OpenAIServiceVers

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationVerboseJson(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -657,6 +662,7 @@ public void testGetAudioTranslationVerboseJson(HttpClient httpClient, OpenAIServ

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationTextPlain(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -674,6 +680,7 @@ public void testGetAudioTranslationTextPlain(HttpClient httpClient, OpenAIServic

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationSrt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand All @@ -696,6 +703,7 @@ public void testGetAudioTranslationSrt(HttpClient httpClient, OpenAIServiceVersi

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
@RecordWithoutRequestBody
public void testGetAudioTranslationVtt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
client = getOpenAIAsyncClient(httpClient, serviceVersion);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ OpenAIClientBuilder getNonAzureOpenAIClientBuilder(HttpClient httpClient) {
private void addTestRecordCustomSanitizers() {
interceptorManager.addSanitizers(Arrays.asList(
new TestProxySanitizer("$..key", null, "REDACTED", TestProxySanitizerType.BODY_KEY),
new TestProxySanitizer("$..endpoint", null, "https://REDACTED", TestProxySanitizerType.BODY_KEY)
new TestProxySanitizer("$..endpoint", null, "https://REDACTED", TestProxySanitizerType.BODY_KEY),
new TestProxySanitizer("Content-Type", "(^multipart\\/form-data; boundary=[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{2})",
"multipart\\/form-data; boundary=BOUNDARY", TestProxySanitizerType.HEADER)
));
}

Expand Down
Loading

0 comments on commit e17f1d9

Please sign in to comment.