Skip to content

Commit

Permalink
feat(android): added on-device model download
Browse files Browse the repository at this point in the history
mfkrause committed Nov 22, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 18b2b99 commit 91b89a5
Showing 6 changed files with 204 additions and 27 deletions.
21 changes: 20 additions & 1 deletion android/src/main/java/com/voicekit/VoiceKitModule.kt
Original file line number Diff line number Diff line change
@@ -68,7 +68,8 @@ class VoiceKitModule(reactContext: ReactApplicationContext) :
try {
voiceKitService.getSupportedLocales(reactApplicationContext) { locales ->
val writableArray = Arguments.createArray()
locales.forEach { writableArray.pushString(it) }
locales["installed"]?.forEach { writableArray.pushString(it) }
locales["supported"]?.forEach { writableArray.pushString(it) }
promise.resolve(writableArray)
}
} catch (e: Exception) {
@@ -77,6 +78,24 @@ class VoiceKitModule(reactContext: ReactApplicationContext) :
}
}

/**
 * Resolves [promise] with `true` when [locale] appears in the "installed" list reported by
 * [VoiceKitService.getSupportedLocales], and with `false` otherwise (including when the
 * service reports no "installed" entry at all).
 *
 * @param locale Locale string to look up, compared verbatim against the reported entries.
 * @param promise React Native promise resolved with the boolean result.
 */
@ReactMethod
fun isOnDeviceModelInstalled(locale: String, promise: Promise) {
    voiceKitService.getSupportedLocales(reactApplicationContext) { localesByKind ->
        // A missing "installed" key is treated the same as an empty list.
        val installedLocales = localesByKind["installed"].orEmpty()
        promise.resolve(locale in installedLocales)
    }
}

/**
 * Asks [VoiceKitService] to download the on-device recognition model for [locale] and
 * resolves [promise] with a map containing `status` (string) and `progressAvailable` (boolean).
 *
 * The service throws synchronously when it refuses to start a download (a download is already
 * running, or the Android version is too old). Those exceptions are caught here and turned into
 * a promise rejection so the JS caller gets an error instead of a promise that never settles.
 *
 * @param locale Locale whose model should be downloaded.
 * @param promise React Native promise resolved with the download status, or rejected on failure.
 */
@ReactMethod
fun downloadOnDeviceModel(locale: String, promise: Promise) {
    try {
        voiceKitService.downloadOnDeviceModel(locale) { result ->
            // This callback runs later, outside the try/catch below, so use safe casts rather
            // than `as`: a malformed result must not throw on the callback thread.
            val response = Arguments.createMap().apply {
                putString("status", result["status"] as? String)
                putBoolean("progressAvailable", (result["progressAvailable"] as? Boolean) ?: false)
            }
            promise.resolve(response)
        }
    } catch (e: Exception) {
        // NOTE(review): rejects with the generic throwable overload; if the module maps
        // VoiceError to specific error codes elsewhere, use the same mapping here.
        promise.reject(e)
    }
}

companion object {
const val NAME = "VoiceKit"
private const val TAG = "VoiceKitModule"
102 changes: 88 additions & 14 deletions android/src/main/java/com/voicekit/VoiceKitService.kt
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ import android.speech.RecognizerIntent
import android.speech.SpeechRecognizer
import android.speech.RecognitionSupport
import android.speech.RecognitionSupportCallback
import android.speech.ModelDownloadListener
import androidx.core.app.ActivityCompat
import androidx.core.content.ContextCompat
import com.facebook.react.bridge.*
@@ -43,6 +44,8 @@ class VoiceKitService(private val context: ReactApplicationContext) {
private var lastResultTimer: Handler? = null
private var lastTranscription: String? = null

private var isDownloadingModel: Boolean = false

fun sendEvent(eventName: String, params: Any?) {
context
.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter::class.java)
@@ -278,34 +281,32 @@ class VoiceKitService(private val context: ReactApplicationContext) {
}
}

fun getSupportedLocales(context: Context, callback: (List<String>) -> Unit) {
fun getSupportedLocales(context: Context, callback: (Map<String, List<String>>) -> Unit) {
Log.d(TAG, "Getting supported locales")

if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
// On Android 13+, we can get the list from the on-device recognizer
// TODO: The on-device supported locales are not necessarily the ones we can use for the standard recognizer
// We need to improve the usage of the default recognizer & on-device recognizer for both Android 13+ and <13
// On Android 13+, we can get a list of locales supported by the on-device recognizer

// On-device speech Recognizer can only be ran on main thread
// On-device speech recognizer can only be run on the main thread
Handler(context.mainLooper).post {
val tempSpeechRecognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context)
val intent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH)

tempSpeechRecognizer?.checkRecognitionSupport(
tempSpeechRecognizer.checkRecognitionSupport(
intent,
Executors.newSingleThreadExecutor(),
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
object : RecognitionSupportCallback {
override fun onSupportResult(recognitionSupport: RecognitionSupport) {
Log.d(TAG, "getSupportedLocales() onSupportResult called with recognitionSupport $recognitionSupport")

// TODO: We need a method to download supported but not installed locales, then we can send mergedLocales
val installedLocales = recognitionSupport.installedOnDeviceLanguages
val supportedLocales = recognitionSupport.supportedOnDeviceLanguages // not necessarily installed for on-device recognition

val mergedLocales = (installedLocales + supportedLocales).distinct().sorted()
val installedLocales = recognitionSupport.installedOnDeviceLanguages.map { it.toString() }
val supportedLocales = recognitionSupport.supportedOnDeviceLanguages.map { it.toString() }

callback(installedLocales?.map { it.toString() } ?: emptyList())
callback(mapOf(
"installed" to installedLocales,
"supported" to supportedLocales
))

tempSpeechRecognizer.destroy()
}
@@ -318,8 +319,81 @@ class VoiceKitService(private val context: ReactApplicationContext) {
)
}
} else {
// TODO: Implement fallback for Android <13
callback(emptyList())
callback(mapOf(
"installed" to emptyList(),
"supported" to emptyList()
))
}
}

/**
 * Triggers a download of the on-device speech recognition model for [locale].
 *
 * [callback] is invoked once, on the main thread, right after the download has been requested.
 * It receives a map with:
 * - `"status"`: always `"started"` in the current implementation
 * - `"progressAvailable"`: `false` below Android 14, `true` on Android 14+
 *
 * On Android 14+ progress values are forwarded to JS through the
 * `RNVoiceKit.model-download-progress` event, and [isDownloadingModel] stays `true` until the
 * listener reports success, scheduling, or an error.
 *
 * @param locale Locale passed to the recognizer as [RecognizerIntent.EXTRA_LANGUAGE].
 * @param callback Receives the initial download status map.
 * @throws VoiceError.InvalidState if a tracked download is already in progress, or if the device
 * runs an Android version below 13.
 */
fun downloadOnDeviceModel(locale: String, callback: (Map<String, Any>) -> Unit) {
    if (isDownloadingModel) {
        // A model download is already in progress
        throw VoiceError.InvalidState
    }

    if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) {
        // Android version must be 13 or higher to download speech models
        throw VoiceError.InvalidState
    }

    val intent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
        putExtra(RecognizerIntent.EXTRA_LANGUAGE, locale)
    }

    // Android 13 does not support progress tracking, simply download the model
    if (Build.VERSION.SDK_INT < Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
        Handler(context.mainLooper).post {
            val recognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context)
            recognizer.triggerModelDownload(intent)
            recognizer.destroy()
            callback(
                mapOf(
                    "status" to "started",
                    "progressAvailable" to false,
                )
            )
        }
        return
    }

    // Android 14+ supports progress tracking, track download progress
    isDownloadingModel = true
    Handler(context.mainLooper).post {
        val recognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context)
        // Keep a handle on the executor so it can be shut down once the download reaches a
        // terminal state; otherwise every download would leak an idle worker thread.
        val listenerExecutor = Executors.newSingleThreadExecutor()

        // Shared cleanup for every terminal listener callback.
        fun finishDownload() {
            isDownloadingModel = false
            recognizer.destroy()
            listenerExecutor.shutdown()
        }

        recognizer.triggerModelDownload(
            intent,
            listenerExecutor,
            @RequiresApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
            object : ModelDownloadListener {
                override fun onProgress(progress: Int) {
                    sendEvent("RNVoiceKit.model-download-progress", progress)
                }

                override fun onSuccess() {
                    finishDownload()
                }

                override fun onScheduled() {
                    // The "started" status has already been (or is about to be) delivered via
                    // `callback`, so a separate "scheduled" status is not reported here.
                    finishDownload()
                }

                override fun onError(error: Int) {
                    finishDownload()
                    // Do not throw here: this runs on the listener executor thread, where an
                    // exception cannot reach `callback` or any caller and would surface as an
                    // uncaught exception on a background thread. Log the failure instead.
                    // TODO: surface the failure to JS once an error channel for downloads exists.
                    Log.e(TAG, "Model download failed with error code: $error")
                }
            }
        )
        callback(
            mapOf(
                "status" to "started",
                "progressAvailable" to true,
            )
        )
    }
}

47 changes: 41 additions & 6 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { StyleSheet, View, Text, TouchableOpacity, TextInput } from 'react-native';
import { VoiceError, VoiceKit, VoiceMode, useVoice } from 'react-native-voicekit';
import { StyleSheet, View, Text, TouchableOpacity, TextInput, Platform } from 'react-native';
import { VoiceError, VoiceEvent, VoiceKit, VoiceMode, useVoice } from 'react-native-voicekit';
import Dropdown from './components/Dropdown';
import { useEffect, useState } from 'react';
import { useCallback, useEffect, useState } from 'react';

export default function App() {
const [locale, setLocale] = useState('en-US');
const [isLocaleInstalled, setIsLocaleInstalled] = useState(Platform.OS !== 'android');
const [supportedLocales, setSupportedLocales] = useState<string[]>([]);

const { available, listening, transcript, startListening, stopListening, resetTranscript } = useVoice({
@@ -22,27 +23,61 @@ export default function App() {
});
}, [locale]);

useEffect(() => {
VoiceKit.isOnDeviceModelInstalled(locale).then((isInstalled) => {
setIsLocaleInstalled(isInstalled);
});
}, [locale]);

const onModelDownloadProgress = useCallback((progress: number) => {
console.log('Model download progress:', progress);
if (progress >= 100) {
setIsLocaleInstalled(true);
VoiceKit.removeListener(VoiceEvent.ModelDownloadProgress, onModelDownloadProgress);
}
}, []);

return (
<View style={styles.container}>
<Text>Is available: {available ? 'Yes' : 'No'}</Text>
<Text style={{ marginBottom: 30 }}>Is listening: {listening ? 'Yes' : 'No'}</Text>
<Dropdown
label="Locale"
label={`Locale${Platform.OS === 'android' ? ` (is installed: ${isLocaleInstalled ? 'yes' : 'no'})` : ''}`}
data={supportedLocales.map((l) => ({ label: l, value: l }))}
maxHeight={300}
value={locale}
onChange={(item) => setLocale(item.value)}
containerStyle={styles.dropdown}
style={styles.dropdown}
/>
{Platform.OS === 'android' && (
<TouchableOpacity
onPress={() => {
VoiceKit.downloadOnDeviceModel(locale)
.then((result) => {
if (result.progressAvailable) {
VoiceKit.addListener(VoiceEvent.ModelDownloadProgress, onModelDownloadProgress);
} else {
console.log('Model download status:', result.status);
}
})
.catch((error) => {
console.error('Error downloading model', error, error instanceof VoiceError ? error.details : null);
});
}}
disabled={isLocaleInstalled}
style={[styles.button, isLocaleInstalled && styles.disabledButton]}>
<Text style={styles.buttonText}>Download "{locale}" Model</Text>
</TouchableOpacity>
)}
<TouchableOpacity
onPress={async () => {
await startListening().catch((error) => {
console.error('Error starting listening', error, error instanceof VoiceError ? error.details : null);
});
}}
disabled={!available || listening}
style={[styles.button, (!available || listening) && styles.disabledButton]}>
disabled={!available || !isLocaleInstalled || listening}
style={[styles.button, (!available || !isLocaleInstalled || listening) && styles.disabledButton]}>
<Text style={styles.buttonText}>Start Listening</Text>
</TouchableOpacity>
<TouchableOpacity
47 changes: 44 additions & 3 deletions src/RNVoiceKit.ts
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ import { NativeEventEmitter, NativeModules, Platform } from 'react-native';
import RNVoiceError from './utils/VoiceError';
import { VoiceErrorCode } from './types/native';
import { VoiceEvent, VoiceMode } from './types';
import type { VoiceEventMap, VoiceStartListeningOptions } from './types';
import { VoiceModelDownloadStatus, type VoiceEventMap, type VoiceStartListeningOptions } from './types';

const LINKING_ERROR =
`The package 'react-native-voicekit' doesn't seem to be linked. Make sure: \n\n` +
@@ -94,8 +94,9 @@ class RNVoiceKit {
}

/**
* Gets the list of supported locales for speech recognition. On Android, this gets the list of
* supported locales for the on-device speech recognizer.
* Gets the list of supported locales for speech recognition. On Android, this gets the list of supported locales for
* the on-device speech recognizer. Note that this does not check if the model is installed already. Use
* `isOnDeviceModelInstalled()` to check if the model for a given locale is installed before using it.
* Does not work on Android versions below 13 and will return an empty array for those versions.
*
* @returns The list of supported locales.
@@ -104,6 +105,46 @@ class RNVoiceKit {
return await nativeInstance.getSupportedLocales();
}

/**
 * Reports whether the on-device speech recognizer model for a locale is already installed.
 * When it is not, `downloadOnDeviceModel()` can be used to fetch it.
 *
 * Android: asks the native module; only meaningful on Android 13+.
 * iOS: there is nothing to install, so this only checks that the locale is among the
 * supported locales.
 *
 * @param locale - The locale to check.
 * @returns Whether the model is installed.
 */
async isOnDeviceModelInstalled(locale: string): Promise<boolean> {
  if (Platform.OS !== 'ios') {
    return await nativeInstance.isOnDeviceModelInstalled(locale);
  }

  const supportedLocales = await this.getSupportedLocales();
  return supportedLocales.includes(locale);
}

/**
 * Downloads the on-device speech recognizer model for the given locale. Only works on Android 13+.
 * Once the download has been started successfully, the promise resolves with a `started` status.
 * On Android 14+, you can listen to the `VoiceEvent.ModelDownloadProgress` event to track the
 * download progress.
 *
 * On iOS there is nothing to download: the promise resolves with a `started` status when the
 * locale is supported and rejects with an error when it is not.
 *
 * @param locale - The locale whose model should be downloaded.
 * @returns The status of the model download and whether download progress is available via the
 * `VoiceEvent.ModelDownloadProgress` event.
 */
async downloadOnDeviceModel(
  locale: string
): Promise<{ status: VoiceModelDownloadStatus; progressAvailable: boolean }> {
  if (Platform.OS !== 'ios') {
    return await nativeInstance.downloadOnDeviceModel(locale);
  }

  const supportedLocales = await this.getSupportedLocales();
  if (!supportedLocales.includes(locale)) {
    throw new RNVoiceError('Locale is not supported', VoiceErrorCode.INVALID_STATE); // TODO: better code
  }

  return { status: VoiceModelDownloadStatus.Started, progressAvailable: false };
}

addListener<T extends VoiceEvent>(event: T, listener: (...args: VoiceEventMap[T]) => void) {
if (!this.listeners[event]) {
this.listeners[event] = [];
10 changes: 8 additions & 2 deletions src/types/index.ts
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ export enum VoiceEvent {
PartialResult = 'partial-result',
AvailabilityChange = 'availability-change',
ListeningStateChange = 'listening-state-change',
ModelDownloadProgress = 'model-download-progress',
Error = 'error',
}

@@ -11,6 +12,7 @@ export interface VoiceEventMap extends Record<VoiceEvent, any[]> {
[VoiceEvent.PartialResult]: [string];
[VoiceEvent.AvailabilityChange]: [boolean];
[VoiceEvent.ListeningStateChange]: [boolean];
[VoiceEvent.ModelDownloadProgress]: [number];
[VoiceEvent.Error]: any[];
}

@@ -20,6 +22,11 @@ export enum VoiceMode {
ContinuousAndStop = 'continuous-and-stop',
}

/**
 * Status values carried in the result of `downloadOnDeviceModel()`.
 */
export enum VoiceModelDownloadStatus {
// The download request was issued; this is the status returned when the download was started.
Started = 'started',
// NOTE(review): intended for downloads that were scheduled rather than started immediately;
// confirm that the native side actually reports this value.
Scheduled = 'scheduled',
}

export interface VoiceStartListeningOptions {
/**
* The locale to use for speech recognition. Defaults to `en-US`.
@@ -57,8 +64,7 @@ export interface VoiceStartListeningOptions {
* Whether to force usage of the on-device speech recognizer. Does not have any effect on iOS. Only works on Android
* 13 and above. Defaults to `false`.
* Note: When using the on-device recognizer, some locales returned by `getSupportedLocales()` may not be installed
* on the device yet and need to be installed first.
* TODO: Add a method to install locales for the Android on-device recognizer
* on the device yet and need to be installed using `downloadOnDeviceModel()` first.
*/
useOnDeviceRecognizer?: boolean;
}
4 changes: 3 additions & 1 deletion src/types/native.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { VoiceStartListeningOptions } from '.';
import type { VoiceModelDownloadStatus, VoiceStartListeningOptions } from '.';

export enum VoiceErrorCode {
SPEECH_RECOGNIZER_NOT_AVAILABLE = 'ERR_SPEECH_RECOGNIZER_NOT_AVAILABLE',
@@ -15,5 +15,7 @@ export default interface NativeRNVoiceKit {
startListening: (options: Required<VoiceStartListeningOptions>) => Promise<void>;
stopListening: () => Promise<void>;
isSpeechRecognitionAvailable: () => Promise<boolean>;
isOnDeviceModelInstalled: (locale: string) => Promise<boolean>;
getSupportedLocales: () => Promise<string[]>;
downloadOnDeviceModel: (locale: string) => Promise<{ status: VoiceModelDownloadStatus; progressAvailable: boolean }>;
}

0 comments on commit 91b89a5

Please sign in to comment.