Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable to stop TTS generation #1041

Merged
merged 5 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ project(sherpa-onnx)
# ./nodejs-addon-examples
# ./dart-api-examples/
# ./sherpa-onnx/flutter/CHANGELOG.md
set(SHERPA_ONNX_VERSION "1.10.0")
set(SHERPA_ONNX_VERSION "1.10.1")

# Disable warning about
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class MainActivity : AppCompatActivity() {
private lateinit var speed: EditText
private lateinit var generate: Button
private lateinit var play: Button
private lateinit var stop: Button
private var stopped: Boolean = false
private var mediaPlayer: MediaPlayer? = null

// see
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
Expand All @@ -49,9 +52,11 @@ class MainActivity : AppCompatActivity() {

generate = findViewById(R.id.generate)
play = findViewById(R.id.play)
stop = findViewById(R.id.stop)

generate.setOnClickListener { onClickGenerate() }
play.setOnClickListener { onClickPlay() }
stop.setOnClickListener { onClickStop() }

sid.setText("0")
speed.setText("1.0")
Expand All @@ -70,7 +75,7 @@ class MainActivity : AppCompatActivity() {
AudioFormat.CHANNEL_OUT_MONO,
AudioFormat.ENCODING_PCM_FLOAT
)
Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")
Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")

val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
Expand All @@ -90,8 +95,14 @@ class MainActivity : AppCompatActivity() {
}

// this function is called from C++
private fun callback(samples: FloatArray) {
track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
private fun callback(samples: FloatArray): Int {
if (!stopped) {
track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
return 1
} else {
track.stop()
return 0
}
}

private fun onClickGenerate() {
Expand Down Expand Up @@ -127,6 +138,8 @@ class MainActivity : AppCompatActivity() {
track.play()

play.isEnabled = false
generate.isEnabled = false
stopped = false
Thread {
val audio = tts.generateWithCallback(
text = textStr,
Expand All @@ -140,6 +153,7 @@ class MainActivity : AppCompatActivity() {
if (ok) {
runOnUiThread {
play.isEnabled = true
generate.isEnabled = true
track.stop()
}
}
Expand All @@ -148,11 +162,22 @@ class MainActivity : AppCompatActivity() {

private fun onClickPlay() {
val filename = application.filesDir.absolutePath + "/generated.wav"
val mediaPlayer = MediaPlayer.create(
mediaPlayer?.stop()
mediaPlayer = MediaPlayer.create(
applicationContext,
Uri.fromFile(File(filename))
)
mediaPlayer.start()
mediaPlayer?.start()
}

private fun onClickStop() {
stopped = true
play.isEnabled = true
generate.isEnabled = true
track.pause()
track.flush()
mediaPlayer?.stop()
mediaPlayer = null
}

private fun initTts() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class OfflineTts(
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
callback: (samples: FloatArray) -> Int
): GeneratedAudio {
val objArray = generateWithCallbackImpl(
ptr,
Expand Down Expand Up @@ -146,7 +146,7 @@ class OfflineTts(
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
callback: (samples: FloatArray) -> Int
): Array<Any>

companion object {
Expand Down
12 changes: 12 additions & 0 deletions android/SherpaOnnxTts/app/src/main/res/layout/activity_main.xml
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,16 @@
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toBottomOf="@id/generate" />

<Button
android:id="@+id/stop"
android:textAllCaps="false"
android:layout_width="match_parent"
android:layout_height="50dp"
android:layout_marginTop="4dp"
android:text="@string/stop"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toBottomOf="@id/play" />

</androidx.constraintlayout.widget.ConstraintLayout>
1 change: 1 addition & 0 deletions android/SherpaOnnxTts/app/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
<string name="text_hint">Please input your text here</string>
<string name="generate">Generate</string>
<string name="play">Play</string>
<string name="stop">Stop</string>
</resources>
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class TtsService : TextToSpeechService() {
return
}

val ttsCallback = { floatSamples: FloatArray ->
val ttsCallback: (FloatArray) -> Int = fun(floatSamples): Int {
// convert FloatArray to ByteArray
val samples = floatArrayToByteArray(floatSamples)
val maxBufferSize: Int = callback.maxBufferSize
Expand All @@ -137,6 +137,9 @@ class TtsService : TextToSpeechService() {
offset += bytesToWrite
}

// 1 means to continue
// 0 means to stop
return 1
}

Log.i(TAG, "text: $text")
Expand All @@ -160,4 +163,4 @@ class TtsService : TextToSpeechService() {
}
return byteArray
}
}
}
2 changes: 1 addition & 1 deletion dart-api-examples/non-streaming-asr/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ environment:

# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
2 changes: 1 addition & 1 deletion dart-api-examples/streaming-asr/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ environment:

# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
4 changes: 4 additions & 0 deletions dart-api-examples/tts/bin/piper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ void main(List<String> arguments) async {
callback: (Float32List samples) {
print('${samples.length} samples received');
// You can play samples in a separate thread/isolate

// 1 means to continue
// 0 means to stop
return 1;
});
tts.free();

Expand Down
2 changes: 1 addition & 1 deletion dart-api-examples/tts/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ environment:

# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
2 changes: 1 addition & 1 deletion dart-api-examples/vad/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ environment:
sdk: ^3.4.0

dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
4 changes: 4 additions & 0 deletions dotnet-examples/offline-tts-play/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ private static void Run(Options options)
Marshal.Copy(samples, data, 0, n);

dataItems.Add(data);

// 1 means to keep generating
// 0 means to stop generating
return 1;
};

bool playFinished = false;
Expand Down
42 changes: 41 additions & 1 deletion kotlin-api-examples/test_tts.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,46 @@ fun testTts() {
println("Saved to test-en.wav")
}

fun callback(samples: FloatArray): Unit {
/*
1. Unzip test_tts.jar
2.
javap ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class

3. It prints:
Compiled from "test_tts.kt"
final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
public final java.lang.Integer invoke(float[]);
public java.lang.Object invoke(java.lang.Object);
static {};
}

4.
javap -s ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class

5. It prints
Compiled from "test_tts.kt"
final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
descriptor: Lcom/k2fsa/sherpa/onnx/Test_ttsKt$testTts$audio$1;
com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
descriptor: ()V

public final java.lang.Integer invoke(float[]);
descriptor: ([F)Ljava/lang/Integer;

public java.lang.Object invoke(java.lang.Object);
descriptor: (Ljava/lang/Object;)Ljava/lang/Object;

static {};
descriptor: ()V
}
*/
fun callback(samples: FloatArray): Int {
println("callback got called with ${samples.size} samples");

// 1 means to continue
// 0 means to stop
return 1
}
Binary file modified mfc-examples/NonStreamingTextToSpeech/NonStreamingTextToSpeech.rc
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ static bool g_started = false;
static bool g_stopped = false;
static bool g_killed = false;

static void AudioGeneratedCallback(const float *s, int32_t n) {
static int32_t AudioGeneratedCallback(const float *s, int32_t n) {
if (n > 0) {
Samples samples;
samples.data = std::vector<float>{s, s + n};
Expand All @@ -66,6 +66,10 @@ static void AudioGeneratedCallback(const float *s, int32_t n) {
g_buffer.samples.push(std::move(samples));
g_started = true;
}
if (g_killed) {
return 0;
}
return 1;
}

static int PlayCallback(const void * /*in*/, void *out,
Expand Down Expand Up @@ -324,6 +328,7 @@ BEGIN_MESSAGE_MAP(CNonStreamingTextToSpeechDlg, CDialogEx)
ON_WM_PAINT()
ON_WM_QUERYDRAGICON()
ON_BN_CLICKED(IDOK, &CNonStreamingTextToSpeechDlg::OnBnClickedOk)
ON_BN_CLICKED(IDC_STOP, &CNonStreamingTextToSpeechDlg::OnBnClickedStop)
END_MESSAGE_MAP()


Expand Down Expand Up @@ -492,11 +497,18 @@ void CNonStreamingTextToSpeechDlg::Init() {
if (tts_) {
SherpaOnnxDestroyOfflineTts(tts_);
}
if (generate_thread_ && generate_thread_->joinable()) {
generate_thread_->join();
}

if (play_thread_ && play_thread_->joinable()) {
play_thread_->join();
}
}


static std::string ToString(const CString &s) {
CT2CA pszConvertedAnsiString( s);
CT2CA pszConvertedAnsiString(s);
return std::string(pszConvertedAnsiString);
}

Expand All @@ -510,7 +522,7 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
}

speed_.GetWindowText(s);
float speed = static_cast<float>(_ttof(s));
float speed = static_cast<float>(_ttof(s));
if (speed < 0) {
AfxMessageBox(Utf8ToUtf16("Please input a valid speed").c_str(), MB_OK);
return;
Expand Down Expand Up @@ -541,28 +553,40 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
// for simplicity
play_thread_ = std::make_unique<std::thread>(StartPlayback, SherpaOnnxOfflineTtsSampleRate(tts_));

generate_btn_.EnableWindow(FALSE);

const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerateWithCallback(tts_, ss.c_str(), speaker_id, speed, &AudioGeneratedCallback);

generate_btn_.EnableWindow(TRUE);
if (generate_thread_ && generate_thread_->joinable()) {
generate_thread_->join();
}

output_filename_.GetWindowText(s);
std::string filename = ToString(s);

int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,
filename.c_str());
generate_thread_ = std::make_unique<std::thread>([ss, this,filename, speaker_id, speed]() {
std::string text = ss;

SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
// generate_btn_.EnableWindow(FALSE);

if (ok) {
// AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);
AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");
} else {
// AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);
AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);
}
const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerateWithCallback(tts_, text.c_str(), speaker_id, speed, &AudioGeneratedCallback);
// generate_btn_.EnableWindow(TRUE);
g_stopped = true;

int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,
filename.c_str());

SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);

if (ok) {
// AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);

// AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");
} else {
// AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);

// AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);
}
});

//CDialogEx::OnOK();
}

void CNonStreamingTextToSpeechDlg::OnBnClickedStop() { g_killed = true; }
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,8 @@ class CNonStreamingTextToSpeechDlg : public CDialogEx
private:
Microphone mic_;
std::unique_ptr<std::thread> play_thread_;
std::unique_ptr<std::thread> generate_thread_;

public:
afx_msg void OnBnClickedStop();
};
Loading
Loading