[Add] support for log probs in chat completion creation
anasfik committed Feb 22, 2024
1 parent e3973a7 commit dfce769
Showing 7 changed files with 159 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -1,4 +1,4 @@
{
"cSpell.words": ["Epoches", "openai"],
"cSpell.words": ["Epoches", "openai", "Probs"],
"editor.acceptSuggestionOnEnter": "off"
}
45 changes: 45 additions & 0 deletions example/lib/chat_completion_with_log_probs_example.dart
@@ -0,0 +1,45 @@
import 'package:dart_openai/dart_openai.dart';

import 'env/env.dart';

void main() async {
// Set the OpenAI API key from the .env file.
OpenAI.apiKey = Env.apiKey;

final systemMessage = OpenAIChatCompletionChoiceMessageModel(
content: [
OpenAIChatCompletionChoiceMessageContentItemModel.text(
"Return any message you are given as JSON.",
),
],
role: OpenAIChatMessageRole.system,
);

final userMessage = OpenAIChatCompletionChoiceMessageModel(
content: [
OpenAIChatCompletionChoiceMessageContentItemModel.text(
"Hello, I am a chatbot created by OpenAI. How are you today?",
),
],
role: OpenAIChatMessageRole.user,
name: "anas",
);

final requestMessages = [
systemMessage,
userMessage,
];

OpenAIChatCompletionModel chatCompletion = await OpenAI.instance.chat.create(
model: "gpt-3.5-turbo-1106",
responseFormat: {"type": "json_object"},
seed: 6,
messages: requestMessages,
temperature: 0.2,
maxTokens: 500,
logprobs: true,
topLogprobs: 2,
);

print(chatCompletion.choices.first.logprobs?.content.first.bytes); // UTF-8 bytes of the first generated token.
}
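
For readers trying the example, here is a minimal, hypothetical sketch of how the returned log probs could be inspected further. `inspectLogProbs` is not part of the library; it assumes the response above carried a non-null `logprobs` object. `exp` comes from `dart:math` and `utf8` from `dart:convert`:

import 'dart:convert';
import 'dart:math';

import 'package:dart_openai/dart_openai.dart';

void inspectLogProbs(OpenAIChatCompletionModel chatCompletion) {
  final logProbs = chatCompletion.choices.first.logprobs;

  if (logProbs == null) {
    return;
  }

  for (final item in logProbs.content) {
    // logprob is a natural logarithm; exp() recovers the raw probability.
    final probability = exp(item.logprob ?? double.negativeInfinity);

    // bytes holds the UTF-8 byte representation of the token, when present.
    final text = item.bytes != null ? utf8.decode(item.bytes!) : item.token;

    print("$text -> $probability");
  }
}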
9 changes: 9 additions & 0 deletions lib/src/core/models/chat/sub_models/choices/choices.dart
@@ -1,3 +1,4 @@
import 'sub_models/log_probs/log_probs.dart';
import 'sub_models/message.dart';

/// {@template openai_chat_completion_choice}
@@ -15,6 +16,9 @@ final class OpenAIChatCompletionChoiceModel {
/// The [finishReason] of the choice.
final String? finishReason;

/// The log probability information for the choice, if it was requested.
final OpenAIChatCompletionChoiceLogProbsModel? logprobs;

/// Whether the choice has a finish reason.
bool get haveFinishReason => finishReason != null;

@@ -28,6 +32,7 @@ final class OpenAIChatCompletionChoiceModel {
required this.index,
required this.message,
required this.finishReason,
required this.logprobs,
});

/// This is used to convert a [Map<String, dynamic>] object to an [OpenAIChatCompletionChoiceModel] object.
@@ -39,6 +44,9 @@ final class OpenAIChatCompletionChoiceModel {
: int.tryParse(json['index'].toString()) ?? json['index'],
message: OpenAIChatCompletionChoiceMessageModel.fromMap(json['message']),
finishReason: json['finish_reason'],
logprobs: json['logprobs'] != null
? OpenAIChatCompletionChoiceLogProbsModel.fromMap(json['logprobs'])
: null,
);
}

@@ -48,6 +56,7 @@ final class OpenAIChatCompletionChoiceModel {
"index": index,
"message": message.toMap(),
"finish_reason": finishReason,
"logprobs": logprobs?.toMap(),
};
}

31 changes: 31 additions & 0 deletions lib/src/core/models/chat/sub_models/choices/sub_models/log_probs/log_probs.dart
@@ -0,0 +1,31 @@
// ignore_for_file: public_member_api_docs, sort_constructors_first
import 'sub_models/content.dart';

/// Represents the log probability information of a chat completion choice.
class OpenAIChatCompletionChoiceLogProbsModel {
OpenAIChatCompletionChoiceLogProbsModel({
required this.content,
});

/// A list of message content tokens with log probability information.
final List<OpenAIChatCompletionChoiceLogProbsContentModel> content;

factory OpenAIChatCompletionChoiceLogProbsModel.fromMap(
Map<String, dynamic> json,
) {
return OpenAIChatCompletionChoiceLogProbsModel(
content: json["content"] != null
? List<OpenAIChatCompletionChoiceLogProbsContentModel>.from(
json["content"].map(
(x) =>
OpenAIChatCompletionChoiceLogProbsContentModel.fromMap(x),
),
)
: [],
);
}

Map<String, dynamic> toMap() {
return {
"content": content.map((x) => x.toMap()).toList(),
};
}
}
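
For context, this model parses the `logprobs` object of a chat completion choice. An illustrative excerpt of that JSON (values made up; the shape follows the OpenAI API reference):

{
  "content": [
    {
      "token": "Hello",
      "logprob": -0.31725305,
      "bytes": [72, 101, 108, 108, 111],
      "top_logprobs": [
        { "token": "Hello", "logprob": -0.31725305, "bytes": [72, 101, 108, 108, 111] },
        { "token": "Hi", "logprob": -1.3190403, "bytes": [72, 105] }
      ]
    }
  ]
}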
41 changes: 41 additions & 0 deletions lib/src/core/models/chat/sub_models/choices/sub_models/log_probs/sub_models/content.dart
@@ -0,0 +1,41 @@
import 'top_prob.dart';

/// Represents the log probability information of a single generated token.
class OpenAIChatCompletionChoiceLogProbsContentModel {
/// The token.
final String? token;

/// The log probability of this token.
final double? logprob;

/// The UTF-8 byte representation of the token, if applicable.
final List<int>? bytes;

/// The most likely tokens at this position and their log probabilities.
final List<OpenAIChatCompletionChoiceTopLogProbsContentModel>? topLogprobs;

OpenAIChatCompletionChoiceLogProbsContentModel({
this.token,
this.logprob,
this.bytes,
this.topLogprobs,
});

factory OpenAIChatCompletionChoiceLogProbsContentModel.fromMap(
Map<String, dynamic> map,
) {
return OpenAIChatCompletionChoiceLogProbsContentModel(
token: map['token'],
logprob: map['logprob'],
bytes: map['bytes'] != null ? List<int>.from(map['bytes']) : null,
topLogprobs: map['top_logprobs'] != null
? List<OpenAIChatCompletionChoiceTopLogProbsContentModel>.from(
map['top_logprobs'].map(
(x) => OpenAIChatCompletionChoiceTopLogProbsContentModel.fromMap(x),
),
)
: null,
);
}

Map<String, dynamic> toMap() {
return {
'token': token,
'logprob': logprob,
'bytes': bytes,
'top_logprobs': topLogprobs?.map((x) => x.toMap()).toList(),
};
}
}
28 changes: 28 additions & 0 deletions lib/src/core/models/chat/sub_models/choices/sub_models/log_probs/sub_models/top_prob.dart
@@ -0,0 +1,28 @@
import 'content.dart';

/// Represents one of the most likely alternative tokens, with its log probability.
class OpenAIChatCompletionChoiceTopLogProbsContentModel
extends OpenAIChatCompletionChoiceLogProbsContentModel {
OpenAIChatCompletionChoiceTopLogProbsContentModel({
super.token,
super.logprob,
super.bytes,
});

factory OpenAIChatCompletionChoiceTopLogProbsContentModel.fromMap(
Map<String, dynamic> map,
) {
return OpenAIChatCompletionChoiceTopLogProbsContentModel(
token: map['token'],
logprob: map['logprob'],
bytes: map['bytes'] != null ? List<int>.from(map['bytes']) : null,
);
}

Map<String, dynamic> toMap() {
return {
'token': token,
'logprob': logprob,
'bytes': bytes,
};
}
}
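
As a design note, each top log prob entry reuses the content model's fields but never nests further alternatives. A hypothetical usage sketch, assuming these model classes are visible to the caller (`printTopAlternatives` is illustrative, not part of the library):

void printTopAlternatives(
  OpenAIChatCompletionChoiceLogProbsContentModel item,
) {
  // Each generated token may carry up to `top_logprobs` alternatives,
  // as requested via the topLogprobs request parameter.
  final alternatives = item.topLogprobs ??
      const <OpenAIChatCompletionChoiceTopLogProbsContentModel>[];

  for (final alt in alternatives) {
    print("${alt.token}: ${alt.logprob}");
  }
}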
4 changes: 4 additions & 0 deletions lib/src/instance/chat/chat.dart
@@ -82,6 +82,8 @@ interface class OpenAIChat implements OpenAIChatBase {
String? user,
Map<String, String>? responseFormat,
int? seed,
bool? logprobs,
int? topLogprobs,
http.Client? client,
}) async {
return await OpenAINetworkingClient.post(
@@ -103,6 +105,8 @@ interface class OpenAIChat implements OpenAIChatBase {
if (user != null) "user": user,
if (seed != null) "seed": seed,
if (responseFormat != null) "response_format": responseFormat,
if (logprobs != null) "logprobs": logprobs,
if (topLogprobs != null) "top_logprobs": topLogprobs,
},
onSuccess: (Map<String, dynamic> response) {
return OpenAIChatCompletionModel.fromMap(response);
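
With the two new parameters set, the request body sent to the chat completions endpoint would include fields along these lines (illustrative excerpt; other fields elided):

{
  "model": "gpt-3.5-turbo-1106",
  "messages": [ ... ],
  "logprobs": true,
  "top_logprobs": 2
}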
