47 changes: 47 additions & 0 deletions packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart
@@ -57,6 +57,8 @@ class _BidiPageState extends State<BidiPage> {
StreamController<bool> _stopController = StreamController<bool>();
final AudioOutput _audioOutput = AudioOutput();
final AudioInput _audioInput = AudioInput();
int? _inputTranscriptionMessageIndex;
int? _outputTranscriptionMessageIndex;

@override
void initState() {
@@ -67,6 +69,8 @@ class _BidiPageState extends State<BidiPage> {
responseModalities: [
ResponseModalities.audio,
],
inputAudioTranscription: AudioTranscriptionConfig(),
outputAudioTranscription: AudioTranscriptionConfig(),
);

// ignore: deprecated_member_use
@@ -353,6 +357,49 @@ class _BidiPageState extends State<BidiPage> {
if (message.modelTurn != null) {
await _handleLiveServerContent(message);
}

int? _handleTranscription(
Transcription? transcription,
int? messageIndex,
String prefix,
bool fromUser,
) {
int? currentIndex = messageIndex;
if (transcription?.text != null) {
if (currentIndex != null) {
_messages[currentIndex] = _messages[currentIndex].copyWith(
text: '${_messages[currentIndex].text}${transcription!.text!}',
);
} else {
_messages.add(
MessageData(
text: '$prefix${transcription!.text!}',
fromUser: fromUser,
),
);
currentIndex = _messages.length - 1;
}
if (transcription.finished ?? false) {
currentIndex = null;
}
setState(_scrollDown);
}
return currentIndex;
}

_inputTranscriptionMessageIndex = _handleTranscription(
message.inputTranscription,
_inputTranscriptionMessageIndex,
'Input transcription: ',
true,
);
_outputTranscriptionMessageIndex = _handleTranscription(
message.outputTranscription,
_outputTranscriptionMessageIndex,
'Output transcription: ',
false,
);

if (message.interrupted != null && message.interrupted!) {
developer.log('Interrupted: $response');
}
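Taken together, the example changes above turn transcription on in the live generation config and stream the recognized text into the chat list. Below is a minimal sketch of the same flow outside the widget; `liveModel` is assumed to be a LiveGenerativeModel created with the config shown, and the print statements are only illustrative.

import 'package:firebase_ai/firebase_ai.dart';

final transcribingConfig = LiveGenerationConfig(
  responseModalities: [ResponseModalities.audio],
  inputAudioTranscription: AudioTranscriptionConfig(),
  outputAudioTranscription: AudioTranscriptionConfig(),
);

Future<void> listenForTranscriptions(LiveGenerativeModel liveModel) async {
  final session = await liveModel.connect();
  await for (final response in session.receive()) {
    final message = response.message;
    if (message is LiveServerContent) {
      // Transcription chunks stream independently of the model turn.
      final heard = message.inputTranscription?.text;
      final said = message.outputTranscription?.text;
      if (heard != null) print('Input transcription: $heard');
      if (said != null) print('Output transcription: $said');
    }
  }
}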
@@ -22,6 +22,21 @@ class MessageData {
this.fromUser,
this.isThought = false,
});

MessageData copyWith({
Uint8List? imageBytes,
String? text,
bool? fromUser,
bool? isThought,
}) {
return MessageData(
imageBytes: imageBytes ?? this.imageBytes,
text: text ?? this.text,
fromUser: fromUser ?? this.fromUser,
isThought: isThought ?? this.isThought,
);
}

final Uint8List? imageBytes;
final String? text;
final bool? fromUser;
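The new copyWith is what lets the transcription handler in the example append streamed text to an existing message without mutating it; a small illustrative sketch with made-up values:

final original = MessageData(text: 'Input transcription: Hel', fromUser: true);
// Only `text` is replaced; all other fields carry over unchanged.
final updated = original.copyWith(text: '${original.text}lo');
assert(updated.text == 'Input transcription: Hello');
assert(updated.fromUser == true);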
4 changes: 3 additions & 1 deletion packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
@@ -93,11 +93,13 @@ export 'src/live_api.dart'
show
LiveGenerationConfig,
SpeechConfig,
AudioTranscriptionConfig,
LiveServerMessage,
LiveServerContent,
LiveServerToolCall,
LiveServerToolCallCancellation,
LiveServerResponse;
LiveServerResponse,
Transcription;
export 'src/live_session.dart' show LiveSession;
export 'src/schema.dart' show Schema, SchemaType;
export 'src/tool.dart'
69 changes: 67 additions & 2 deletions packages/firebase_ai/firebase_ai/lib/src/live_api.dart
@@ -71,11 +71,19 @@ class SpeechConfig {
};
}

/// The audio transcription configuration.
class AudioTranscriptionConfig {
// ignore: public_member_api_docs
Map<String, Object?> toJson() => {};
}

/// Configures live generation settings.
final class LiveGenerationConfig extends BaseGenerationConfig {
// ignore: public_member_api_docs
LiveGenerationConfig({
this.speechConfig,
this.inputAudioTranscription,
this.outputAudioTranscription,
super.responseModalities,
super.maxOutputTokens,
super.temperature,
@@ -88,6 +96,13 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
/// The speech configuration.
final SpeechConfig? speechConfig;

/// Configuration for transcribing the input audio. The transcription is in
/// the same language as the input audio.
final AudioTranscriptionConfig? inputAudioTranscription;

/// Configuration for transcribing the output audio. The transcription uses
/// the language code specified for the output audio.
final AudioTranscriptionConfig? outputAudioTranscription;

@override
Map<String, Object?> toJson() => {
...super.toJson(),
@@ -109,14 +124,33 @@ sealed class LiveServerMessage {}
/// with the live server has finished successfully.
class LiveServerSetupComplete implements LiveServerMessage {}

/// Audio transcription message.
class Transcription {
// ignore: public_member_api_docs
const Transcription({this.text, this.finished});

/// Transcription text.
final String? text;

/// Whether this is the end of the transcription.
final bool? finished;
}

/// Content generated by the model in a live stream.
class LiveServerContent implements LiveServerMessage {
/// Creates a [LiveServerContent] instance.
///
/// [modelTurn] (optional): The content generated by the model.
/// [turnComplete] (optional): Indicates if the turn is complete.
/// [interrupted] (optional): Indicates if the generation was interrupted.
LiveServerContent({this.modelTurn, this.turnComplete, this.interrupted});
/// [inputTranscription] (optional): The input transcription.
/// [outputTranscription] (optional): The output transcription.
LiveServerContent(
{this.modelTurn,
this.turnComplete,
this.interrupted,
this.inputTranscription,
this.outputTranscription});

// TODO(cynthia): Add accessor for media content
/// The content generated by the model.
@@ -129,6 +163,18 @@ class LiveServerContent implements LiveServerMessage {
/// Whether generation was interrupted. If true, a client message has
/// interrupted the current model generation.
final bool? interrupted;

/// The input transcription.
///
/// The transcription is independent of the model turn; it does not imply any
/// ordering between the transcription and the model turn.
final Transcription? inputTranscription;

/// The output transcription.
///
/// The transcription is independent of the model turn; it does not imply any
/// ordering between the transcription and the model turn.
final Transcription? outputTranscription;
}

/// A tool call in a live stream.
@@ -344,7 +390,26 @@ LiveServerMessage _parseServerMessage(Object jsonObject) {
if (serverContentJson.containsKey('turnComplete')) {
turnComplete = serverContentJson['turnComplete'] as bool;
}
return LiveServerContent(modelTurn: modelTurn, turnComplete: turnComplete);
final interrupted = serverContentJson['interrupted'] as bool?;
Transcription? _parseTranscription(String key) {
if (serverContentJson.containsKey(key)) {
final transcriptionJson =
serverContentJson[key] as Map<String, dynamic>;
return Transcription(
text: transcriptionJson['text'] as String?,
finished: transcriptionJson['finished'] as bool?,
);
}
return null;
}

return LiveServerContent(
modelTurn: modelTurn,
turnComplete: turnComplete,
interrupted: interrupted,
inputTranscription: _parseTranscription('inputTranscription'),
outputTranscription: _parseTranscription('outputTranscription'),
);
} else if (json.containsKey('toolCall')) {
final toolContentJson = json['toolCall'] as Map<String, dynamic>;
List<FunctionCall> functionCalls = [];
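Because inputTranscription and outputTranscription arrive as incremental chunks with an optional finished flag, callers will typically accumulate text until finished is reported. A minimal sketch of such a buffer (the helper is illustrative and not part of this API):

import 'package:firebase_ai/firebase_ai.dart';

final _transcriptBuffer = StringBuffer();

/// Appends a transcription chunk and returns the full text once it finishes.
String? collectTranscription(Transcription? transcription) {
  final text = transcription?.text;
  if (text != null) _transcriptBuffer.write(text);
  if (transcription?.finished ?? false) {
    final completed = _transcriptBuffer.toString();
    _transcriptBuffer.clear();
    return completed;
  }
  return null; // still streaming
}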
11 changes: 9 additions & 2 deletions packages/firebase_ai/firebase_ai/lib/src/live_model.dart
@@ -101,11 +101,18 @@ final class LiveGenerativeModel extends BaseModel {
final setupJson = {
'setup': {
'model': modelString,
if (_liveGenerationConfig != null)
'generation_config': _liveGenerationConfig.toJson(),
if (_systemInstruction != null)
'system_instruction': _systemInstruction.toJson(),
if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(),
if (_liveGenerationConfig != null) ...{
'generation_config': _liveGenerationConfig.toJson(),
if (_liveGenerationConfig.inputAudioTranscription != null)
'input_audio_transcription':
_liveGenerationConfig.inputAudioTranscription!.toJson(),
if (_liveGenerationConfig.outputAudioTranscription != null)
'output_audio_transcription':
_liveGenerationConfig.outputAudioTranscription!.toJson(),
},
}
};

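Based on the setup code above, enabling both transcription configs adds two keys next to generation_config in the setup payload. A rough sketch of the resulting map; the model string is a placeholder, the nested generation_config contents depend on the rest of the LiveGenerationConfig, and each AudioTranscriptionConfig serializes to an empty object:

final exampleSetupJson = {
  'setup': {
    'model': '<model resource name>',
    'generation_config': <String, Object?>{
      // other LiveGenerationConfig fields (modalities, temperature, ...)
    },
    'input_audio_transcription': <String, Object?>{},
    'output_audio_transcription': <String, Object?>{},
  },
};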
37 changes: 37 additions & 0 deletions packages/firebase_ai/firebase_ai/test/live_test.dart
@@ -240,5 +240,42 @@ void main() {
expect(() => parseServerResponse(jsonObject),
throwsA(isA<FirebaseAISdkException>()));
});

test(
'LiveGenerationConfig with transcriptions toJson() returns correct JSON',
() {
final liveGenerationConfig = LiveGenerationConfig(
inputAudioTranscription: AudioTranscriptionConfig(),
outputAudioTranscription: AudioTranscriptionConfig(),
);
// The two transcription configs should intentionally not appear in toJson().
expect(liveGenerationConfig.toJson(), {});
});

test('parseServerMessage parses serverContent with transcriptions', () {
final jsonObject = {
'serverContent': {
'modelTurn': {
'parts': [
{'text': 'Hello, world!'}
]
},
'turnComplete': true,
'inputTranscription': {'text': 'input', 'finished': true},
'outputTranscription': {'text': 'output', 'finished': false}
}
};
final response = parseServerResponse(jsonObject);
expect(response.message, isA<LiveServerContent>());
final contentMessage = response.message as LiveServerContent;
expect(contentMessage.turnComplete, true);
expect(contentMessage.modelTurn, isA<Content>());
expect(contentMessage.inputTranscription, isA<Transcription>());
expect(contentMessage.inputTranscription?.text, 'input');
expect(contentMessage.inputTranscription?.finished, true);
expect(contentMessage.outputTranscription, isA<Transcription>());
expect(contentMessage.outputTranscription?.text, 'output');
expect(contentMessage.outputTranscription?.finished, false);
});
});
}