
Commit 3332c22
feat(chat-models): Support seed, system_fingerprint and JSON Mode in ChatOpenAI (#205)

davidmigloz authored Nov 7, 2023
1 parent c31b679 commit 3332c22
Showing 7 changed files with 143 additions and 3 deletions.
25 changes: 25 additions & 0 deletions packages/langchain_openai/lib/src/chat_models/models/mappers.dart
@@ -3,6 +3,8 @@ import 'dart:convert';
import 'package:langchain/langchain.dart';
import 'package:openai_dart/openai_dart.dart';

import 'models.dart';

extension ChatMessageListMapper on List<ChatMessage> {
List<ChatCompletionMessage> toChatCompletionMessages() {
return map((final message) => message.toChatCompletionMessage())
@@ -48,6 +50,7 @@ extension CreateChatCompletionResponseMapper on CreateChatCompletionResponse {
modelOutput: {
'created': created,
'model': model,
'system_fingerprint': systemFingerprint,
},
);
}
@@ -88,6 +91,10 @@ extension _ChatCompletionMessageMapper on ChatCompletionMessage {
name: name ?? '',
content: content ?? '',
),
ChatCompletionMessageRole.tool => ChatMessage.function(
name: name ?? '',
content: content ?? '',
),
};
}
}
@@ -157,6 +164,7 @@ extension CreateChatCompletionStreamResponseMapper
'id': id,
'created': created,
'model': model,
'system_fingerprint': systemFingerprint,
},
streaming: true,
);
@@ -190,6 +198,10 @@ extension _ChatCompletionStreamResponseDeltaMapper
name: '',
content: content ?? '',
),
ChatCompletionMessageRole.tool => ChatMessage.function(
name: '',
content: content ?? '',
),
null => _handleNullRole(),
};
}
@@ -220,3 +232,16 @@ extension _ChatCompletionStreamMessageFunctionCallMapper
);
}
}

extension ChatOpenAIResponseFormatMapper on ChatOpenAIResponseFormat {
ChatCompletionResponseFormat toChatCompletionResponseFormat() {
return ChatCompletionResponseFormat(
type: switch (type) {
ChatOpenAIResponseFormatType.text =>
ChatCompletionResponseFormatType.text,
ChatOpenAIResponseFormatType.jsonObject =>
ChatCompletionResponseFormatType.jsonObject,
},
);
}
}
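The mappers above now surface `system_fingerprint` in `ChatResult.modelOutput` for both the non-streaming and streaming paths. A minimal sketch of reading it back (hedged: assumes a valid API key and the public `ChatOpenAI` API added later in this commit; the prompt is illustrative):

```dart
import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

Future<void> main() async {
  final chatModel = ChatOpenAI(apiKey: 'sk-...'); // hypothetical key
  final res = await chatModel.invoke(PromptValue.string('Hello!'));
  // The fingerprint identifies the backend configuration that served the
  // request; compare it across calls when relying on seed-based determinism.
  print(res.modelOutput?['system_fingerprint']);
}
```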
23 changes: 23 additions & 0 deletions packages/langchain_openai/lib/src/chat_models/models/models.dart
@@ -40,3 +40,26 @@ class ChatOpenAIOptions extends ChatModelOptions {
/// Ref: https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids
final String? user;
}

/// {@template chat_openai_response_format}
/// An object specifying the format that the model must output.
/// {@endtemplate}
class ChatOpenAIResponseFormat {
/// {@macro chat_openai_response_format}
const ChatOpenAIResponseFormat({
required this.type,
});

/// The format type.
final ChatOpenAIResponseFormatType type;
}

/// Types of response formats.
enum ChatOpenAIResponseFormatType {
/// Standard text mode.
text,

/// [ChatOpenAIResponseFormatType.jsonObject] enables JSON mode, which
/// guarantees the message the model generates is valid JSON.
jsonObject,
}
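For illustration, constructing the new option type (names exactly as defined above; whether these symbols are re-exported from the package's top-level library is an assumption):

```dart
import 'package:langchain_openai/langchain_openai.dart';

// Enable JSON mode: the model is constrained to emit valid JSON.
const jsonMode = ChatOpenAIResponseFormat(
  type: ChatOpenAIResponseFormatType.jsonObject,
);

// Standard text mode, the default behaviour made explicit.
const textMode = ChatOpenAIResponseFormat(
  type: ChatOpenAIResponseFormatType.text,
);
```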
33 changes: 33 additions & 0 deletions packages/langchain_openai/lib/src/chat_models/openai.dart
@@ -89,6 +89,8 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
/// - [ChatOpenAI.maxTokens]
/// - [ChatOpenAI.n]
/// - [ChatOpenAI.presencePenalty]
/// - [ChatOpenAI.responseFormat]
/// - [ChatOpenAI.seed]
/// - [ChatOpenAI.temperature]
/// - [ChatOpenAI.topP]
/// - [ChatOpenAI.user]
@@ -113,6 +115,8 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
this.maxTokens,
this.n = 1,
this.presencePenalty = 0,
this.responseFormat,
this.seed,
this.temperature = 1,
this.topP = 1,
this.user,
@@ -163,6 +167,32 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
/// See https://platform.openai.com/docs/api-reference/chat/create#chat-create-presence_penalty
final double presencePenalty;

/// An object specifying the format that the model must output.
///
/// Setting to [ChatOpenAIResponseFormatType.jsonObject] enables JSON mode,
/// which guarantees the message the model generates is valid JSON.
///
/// Important: when using JSON mode you must still instruct the model to
/// produce JSON yourself via some conversation message, for example via your
/// system message. If you don't do this, the model may generate an unending
/// stream of whitespace until the generation reaches the token limit, which
/// may take a lot of time and give the appearance of a "stuck" request.
/// Also note that the message content may be partial (i.e. cut off) if
/// `finish_reason="length"`, which indicates the generation exceeded
/// `max_tokens` or the conversation exceeded the max context length.
///
/// See https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
final ChatOpenAIResponseFormat? responseFormat;

/// This feature is in Beta. If specified, our system will make a best effort
/// to sample deterministically, such that repeated requests with the same
/// seed and parameters should return the same result. Determinism is not
/// guaranteed, and you should refer to the system_fingerprint response
/// parameter to monitor changes in the backend.
///
/// See https://platform.openai.com/docs/api-reference/chat/create#chat-create-seed
final int? seed;

/// What sampling temperature to use, between 0 and 2.
///
/// See https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature
@@ -214,6 +244,7 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
final messagesDtos = messages.toChatCompletionMessages();
final functionsDtos = options?.functions?.toChatCompletionFunctions();
final functionCall = options?.functionCall?.toChatCompletionFunctionCall();
final resFormat = responseFormat?.toChatCompletionResponseFormat();

final completion = await _client.createChatCompletion(
request: CreateChatCompletionRequest(
@@ -226,6 +257,8 @@
maxTokens: maxTokens,
n: n,
presencePenalty: presencePenalty,
responseFormat: resFormat,
seed: seed,
stop: options?.stop != null
? ChatCompletionStop.arrayString(options!.stop!)
: null,
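Putting the new parameters together, a sketch modelled on the tests in this commit (the API key is a placeholder; the system message explicitly asks for JSON, which the `responseFormat` docs above require):

```dart
import 'dart:convert';

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

Future<void> main() async {
  final llm = ChatOpenAI(
    apiKey: 'sk-...', // hypothetical key
    model: 'gpt-4-1106-preview',
    temperature: 0,
    seed: 9999, // best-effort determinism (Beta)
    responseFormat: const ChatOpenAIResponseFormat(
      type: ChatOpenAIResponseFormatType.jsonObject,
    ),
  );

  final prompt = PromptValue.chat([
    // JSON mode still requires instructing the model to emit JSON.
    ChatMessage.system('Reply with a JSON object listing three colors.'),
    ChatMessage.human('Go!'),
  ]);

  final res = await llm.invoke(prompt);
  final output =
      json.decode(res.generations.first.output.content) as Map<String, dynamic>;
  print(output);
}
```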
2 changes: 1 addition & 1 deletion packages/langchain_openai/lib/src/llms/models/mappers.dart
@@ -12,7 +12,7 @@ extension CreateCompletionResponseMapper on CreateCompletionResponse {
'id': id,
'created': created,
'model': model,
- 'systemFingerprint': systemFingerprint,
+ 'system_fingerprint': systemFingerprint,
},
streaming: streaming,
);
1 change: 1 addition & 0 deletions packages/langchain_openai/lib/src/llms/openai.dart
@@ -87,6 +87,7 @@ class OpenAI extends BaseLLM<OpenAIOptions> {
/// - [OpenAI.maxTokens]
/// - [OpenAI.n]
/// - [OpenAI.presencePenalty]
/// - [OpenAI.seed]
/// - [OpenAI.suffix]
/// - [OpenAI.temperature]
/// - [OpenAI.topP]
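Only the doc list changes here, so the sketch below assumes `OpenAI.seed` already exists on the completions-based LLM with the same semantics as `ChatOpenAI.seed` (the parameter name is inferred from the doc reference above):

```dart
import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

Future<void> main() async {
  final llm = OpenAI(
    apiKey: 'sk-...', // hypothetical key
    temperature: 0,
    seed: 9999, // assumed parameter, mirroring ChatOpenAI.seed
  );
  final res = await llm.invoke(PromptValue.string('Say hello.'));
  // After the mapper fix above, the key is snake_case.
  print(res.modelOutput?['system_fingerprint']);
}
```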
58 changes: 58 additions & 0 deletions packages/langchain_openai/test/chat_models/openai_test.dart
@@ -295,5 +295,63 @@ void main() {
expect(lastResult['setup'], isNotEmpty);
expect(lastResult['punchline'], isNotEmpty);
});

test('Test response seed', () async {
final prompt = PromptValue.string('How are you?');
final llm = ChatOpenAI(
apiKey: openaiApiKey,
temperature: 0,
seed: 9999,
);

final res1 = await llm.invoke(prompt);
expect(res1.generations, hasLength(1));
final generation1 = res1.generations.first;

final res2 = await llm.invoke(prompt);
expect(res2.generations, hasLength(1));
final generation2 = res2.generations.first;

expect(
res1.modelOutput?['system_fingerprint'],
res2.modelOutput?['system_fingerprint'],
);
expect(generation1.output, generation2.output);
});

test('Test JSON mode', () async {
final prompt = PromptValue.chat([
ChatMessage.system(
"Extract the 'name' and 'origin' of any companies mentioned in the "
'following statement. Return a JSON list.',
),
ChatMessage.human(
'Google was founded in the USA, while Deepmind was founded in the UK',
),
]);
final llm = ChatOpenAI(
apiKey: openaiApiKey,
model: 'gpt-4-1106-preview',
temperature: 0,
seed: 9999,
responseFormat: const ChatOpenAIResponseFormat(
type: ChatOpenAIResponseFormatType.jsonObject,
),
);

final res = await llm.invoke(prompt);
expect(res.generations, hasLength(1));
final outputMsg = res.generations.first.output;
final outputJson = json.decode(outputMsg.content) as Map<String, dynamic>;
expect(outputJson['companies'], isNotNull);
final companies = outputJson['companies'] as List<dynamic>;
expect(companies, hasLength(2));
final firstCompany = companies.first as Map<String, dynamic>;
expect(firstCompany['name'], 'Google');
expect(firstCompany['origin'], 'USA');
final secondCompany = companies.last as Map<String, dynamic>;
expect(secondCompany['name'], 'Deepmind');
expect(secondCompany['origin'], 'UK');
});
});
}
4 changes: 2 additions & 2 deletions packages/langchain_openai/test/llms/openai_test.dart
@@ -137,8 +137,8 @@ void main() {
final generation2 = res2.generations.first;

expect(
- res1.modelOutput?['systemFingerprint'],
- res2.modelOutput?['systemFingerprint'],
+ res1.modelOutput?['system_fingerprint'],
+ res2.modelOutput?['system_fingerprint'],
);
expect(generation1.output, generation2.output);
});
