Skip to content

Commit

Permalink
Add happy-path test cases for Model and ApiClient (#16)
Browse files Browse the repository at this point in the history
Refactor `ApiClient` to go from `Map` to `Map` instead of bytes to
`String`. This makes it easier to write a stub client which is agnostic
to the ordering of values in the maps. Move the responsibilities for
UTF8 and JSON encoding to the `HttpApiClient`.

Eagerly call nested `.toJson()` methods for child fields in `.toJson()`
implementations. This allows deep equality checking for the stub client
without implementing `operator ==` for the data model classes.

Add a `createModelWithClient` method in `src/model.dart` which will not
be exported to the package public API. Use it in the test to create a
model with a stubbed client.

Add tests for each call in `GenerativeModel`. Tests validate
- Encoding from arguments to the JSON schema.
- Building the correct `Uri` for the task and model.
- Parsing the result to the Dart data model.

Add tests for unary and streaming calls in `HttpClient`. Tests validate
- Headers are set correctly.
- Encoding and decoding UTF8 and JSON.
- Using streaming with `alt=sse` query parameter and parsing of `data: `
  lines in the response.

Add Matcher helper methods for classes checked during tests that don't
have `operator ==`.
  • Loading branch information
natebosch authored Jan 30, 2024
1 parent 616a97c commit 7863b48
Show file tree
Hide file tree
Showing 9 changed files with 460 additions and 57 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/google_generative_ai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ jobs:
- run: dart format --output=none --set-exit-if-changed .
if: ${{matrix.run-tests}}

# - run: dart test
# if: ${{matrix.run-tests}}
- run: dart test
if: ${{matrix.run-tests}}
34 changes: 20 additions & 14 deletions pkgs/google_generative_ai/lib/src/client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,49 @@

import 'dart:async';
import 'dart:convert';
import 'dart:typed_data';

import 'package:http/http.dart' as http;

abstract interface class ApiClient {
Future<String> makeRequest(Uri uri, Uint8List body);
Stream<String> streamRequest(Uri uri, Uint8List body);
Future<Map<String, Object?>> makeRequest(Uri uri, Map<String, Object?> body);
Stream<Map<String, Object?>> streamRequest(
Uri uri, Map<String, Object?> body);
}

const _packageVersion = '0.0.1';
const _clientName = 'genai-dart/$_packageVersion';
const clientName = 'genai-dart/$_packageVersion';

final class HttpApiClient implements ApiClient {
final String _apiKey;
final http.Client? _httpClient;
late final _headers = {
'x-goog-api-key': _apiKey,
'x-goog-api-client': _clientName
'x-goog-api-client': clientName
};

HttpApiClient(
{required String model, required String apiKey, http.Client? httpClient})
HttpApiClient({required String apiKey, http.Client? httpClient})
: _apiKey = apiKey,
_httpClient = httpClient ?? http.Client();

@override
Future<String> makeRequest(Uri uri, Uint8List body) async {
final response = await http.post(uri,
headers: {..._headers, 'Content-Type': 'application/json'}, body: body);
return utf8.decode(response.bodyBytes);
Future<Map<String, Object?>> makeRequest(
Uri uri, Map<String, Object?> body) async {
final response = await http.post(
uri,
headers: {..._headers, 'Content-Type': 'application/json'},
body: utf8.encode(jsonEncode(body)),
);
return jsonDecode(utf8.decode(response.bodyBytes)) as Map<String, Object?>;
}

@override
Stream<String> streamRequest(Uri uri, Uint8List body) {
final controller = StreamController<String>();
Stream<Map<String, Object?>> streamRequest(
Uri uri, Map<String, Object?> body) {
final controller = StreamController<Map<String, Object?>>();
() async {
uri = uri.replace(queryParameters: {'alt': 'sse'});
final request = http.Request('POST', uri)
..bodyBytes = body
..bodyBytes = utf8.encode(jsonEncode(body))
..headers.addAll(_headers)
..headers['Content-Type'] = 'application/json';
final response = _httpClient == null
Expand All @@ -63,6 +67,8 @@ final class HttpApiClient implements ApiClient {
.transform(const LineSplitter())
.where((line) => line.startsWith('data: '))
.map((line) => line.substring(6))
.map(jsonDecode)
.cast<Map<String, Object?>>()
.pipe(controller);
}();
return controller.stream;
Expand Down
11 changes: 9 additions & 2 deletions pkgs/google_generative_ai/lib/src/content.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ final class Content {
static Content multi(Iterable<Part> parts) => Content('user', [...parts]);
static Content model(Iterable<Part> parts) => Content('model', [...parts]);

Map toJson() => {if (role case final role?) 'role': role, 'parts': parts};
Map toJson() => {
if (role case final role?) 'role': role,
'parts': parts.map((p) => p.toJson()).toList()
};
}

Content parseContent(Object jsonObject) {
Expand All @@ -52,18 +55,22 @@ Part _parsePart(Object? jsonObject) {
}

/// A datatype containing media that is part of a multi-part [Content] message.
sealed class Part {}
sealed class Part {
Object toJson();
}

final class TextPart implements Part {
final String text;
TextPart(this.text);
@override
Object toJson() => {'text': text};
}

final class DataPart implements Part {
final String mimeType;
final Uint8List bytes;
DataPart(this.mimeType, this.bytes);
@override
Object toJson() => {
'inlineData': {'data': base64Encode(bytes), 'mimeType': mimeType}
};
Expand Down
98 changes: 60 additions & 38 deletions pkgs/google_generative_ai/lib/src/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

import 'dart:async';
import 'dart:convert';

import 'package:http/http.dart' as http;

Expand Down Expand Up @@ -44,36 +43,35 @@ final class GenerativeModel {
final List<SafetySetting> _safetySettings;
final GenerationConfig? _generationConfig;
final ApiClient _client;
GenerativeModel(
{required String model,
required String apiKey,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
http.Client? httpClient})
:

factory GenerativeModel({
required String model,
required String apiKey,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig,
http.Client? httpClient,
}) =>
GenerativeModel._withClient(
client: HttpApiClient(apiKey: apiKey, httpClient: httpClient),
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig);

GenerativeModel._withClient({
required ApiClient client,
required String model,
required List<SafetySetting> safetySettings,
required GenerationConfig? generationConfig,
}) :
// TODO: Allow `models/` prefix and strip it.
// https://github.com/google/generative-ai-js/blob/2be48f8e5427f2f6191f24bcb8000b450715a0de/packages/main/src/models/generative-model.ts#L59
_model = model,
_safetySettings = safetySettings,
_generationConfig = generationConfig,
_client =
HttpApiClient(model: model, apiKey: apiKey, httpClient: httpClient);
_client = client;

Future<Object> _makeRequest(
Task task, Map<String, Object?> parameters) async {
final uri = _baseUrl.resolveUri(
Uri(pathSegments: [_apiVersion, 'models', '$_model:${task._name}']));
final body = utf8.encode(jsonEncode(parameters));
final jsonString = await _client.makeRequest(uri, body);
return jsonDecode(jsonString) as Object;
}

Stream<Map<String, Object?>> _streamRequest(Task task, Object parameters) {
final uri = _baseUrl.resolveUri(
Uri(pathSegments: [_apiVersion, 'models', '$_model:${task._name}']));
final body = utf8.encode(jsonEncode(parameters));
return _client.streamRequest(uri, body).map(jsonDecode).cast();
}
Uri _taskUri(Task task) => _baseUrl.resolveUri(
Uri(pathSegments: [_apiVersion, 'models', '$_model:${task._name}']));

/// Returns content responding to [prompt].
///
Expand All @@ -84,11 +82,14 @@ final class GenerativeModel {
Future<GenerateContentResponse> generateContent(
Iterable<Content> prompt) async {
final parameters = {
'contents': prompt.toList(),
if (_safetySettings.isNotEmpty) 'safetySettings': _safetySettings,
if (_generationConfig case final config?) 'generationConfig': config,
'contents': prompt.map((p) => p.toJson()).toList(),
if (_safetySettings.isNotEmpty)
'safetySettings': _safetySettings.map((s) => s.toJson()).toList(),
if (_generationConfig case final config?)
'generationConfig': config.toJson(),
};
final response = await _makeRequest(Task.generateContent, parameters);
final response =
await _client.makeRequest(_taskUri(Task.generateContent), parameters);
return parseGenerateContentResponse(response);
}

Expand All @@ -103,11 +104,14 @@ final class GenerativeModel {
Stream<GenerateContentResponse> generateContentStream(
Iterable<Content> prompt) {
final parameters = <String, Object?>{
'contents': prompt.toList(),
if (_safetySettings.isNotEmpty) 'safetySettings': _safetySettings,
if (_generationConfig case final config?) 'generationConfig': config,
'contents': prompt.map((p) => p.toJson()).toList(),
if (_safetySettings.isNotEmpty)
'safetySettings': _safetySettings.map((s) => s.toJson()).toList(),
if (_generationConfig case final config?)
'generationConfig': config.toJson(),
};
final response = _streamRequest(Task.streamGenerateContent, parameters);
final response =
_client.streamRequest(_taskUri(Task.streamGenerateContent), parameters);
return response.map(parseGenerateContentResponse);
}

Expand All @@ -124,8 +128,11 @@ final class GenerativeModel {
/// print(response.text);
/// }
Future<CountTokensResponse> countTokens(Iterable<Content> content) async {
final parameters = <String, Object?>{'contents': content.toList()};
final response = await _makeRequest(Task.countTokens, parameters);
final parameters = <String, Object?>{
'contents': content.map((c) => c.toJson()).toList()
};
final response =
await _client.makeRequest(_taskUri(Task.countTokens), parameters);
return parseCountTokensResponse(response);
}

Expand All @@ -138,11 +145,26 @@ final class GenerativeModel {
Future<EmbedContentResponse> embedContent(Content content,
{TaskType? taskType, String? title}) async {
final parameters = <String, Object?>{
'content': content,
if (taskType != null) 'taskType': taskType,
'content': content.toJson(),
if (taskType != null) 'taskType': taskType.toJson(),
if (title != null) 'title': title
};
final response = await _makeRequest(Task.embedContent, parameters);
final response =
await _client.makeRequest(_taskUri(Task.embedContent), parameters);
return parseEmbedContentReponse(response);
}
}

/// Create a model with an overridden [ApiClient] for testing.
///
/// Package private test only method.
GenerativeModel createModelwithClient(
{required String model,
required ApiClient client,
List<SafetySetting> safetySettings = const [],
GenerationConfig? generationConfig}) =>
GenerativeModel._withClient(
client: client,
model: model,
safetySettings: safetySettings,
generationConfig: generationConfig);
3 changes: 2 additions & 1 deletion pkgs/google_generative_ai/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ environment:
sdk: ^3.2.0

dependencies:
crypt: ^4.3.0
http: ^1.1.0
pool: ^1.5.0

dev_dependencies:
collection: ^1.18.0
lints: ^3.0.0
matcher: ^0.12.16
test: ^1.24.0
Loading

0 comments on commit 7863b48

Please sign in to comment.