Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ltorje-8x8 committed Feb 16, 2024
1 parent afc63d3 commit 8979044
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 66 deletions.
14 changes: 14 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
<!-- Match jicoco's jetty version. -->
<jicoco.version>1.1-133-g768ef2e</jicoco.version>
<jetty.version>11.0.14</jetty.version>
<mockito-core.version>5.10.0</mockito-core.version>
<assertj-core.version>3.25.3</assertj-core.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -434,6 +436,18 @@
<artifactId>xmlunit</artifactId>
<version>1.6</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>${mockito-core.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj-core.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jitsi</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public void configureTranslationManager()
*
* @return a {@code String}
*/
private String getDebugName()
String getDebugName()
{
return roomName;
}
Expand Down
116 changes: 51 additions & 65 deletions src/main/java/org/jitsi/jigasi/transcription/WhisperWebsocket.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,23 @@
*/
package org.jitsi.jigasi.transcription;

import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketError;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.jitsi.jigasi.JigasiBundleActivator;
import org.jitsi.utils.logging.Logger;
import org.json.JSONObject;

import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.time.Duration;
import java.time.Instant;
import io.jsonwebtoken.*;
import org.eclipse.jetty.websocket.api.*;
import org.eclipse.jetty.websocket.api.annotations.*;
import org.eclipse.jetty.websocket.client.*;
import org.jitsi.jigasi.*;
import org.jitsi.utils.logging.*;
import org.json.*;

import java.io.*;
import java.net.*;
import java.nio.*;
import java.security.*;
import java.security.spec.*;
import java.time.*;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.*;
import java.util.function.Supplier;


@WebSocket
Expand All @@ -60,7 +50,7 @@ public class WhisperWebsocket


/* Transcription language requested by the user who requested the transcription */
private String transcriptionTag = "en-US";
public String transcriptionTag = "en-US";

private static final Logger logger = Logger.getLogger(WhisperWebsocket.class);

Expand Down Expand Up @@ -232,48 +222,44 @@ public void onClose(int statusCode, String reason)
@OnWebSocketMessage
public void onMessage(String msg)
{
final JSONObject obj = new JSONObject(msg);
final String msgType = obj.getString("type");
final String participantId = obj.getString("participant_id");
final Instant transcriptionStart = Instant.ofEpochMilli(obj.getLong("ts"));
final Participant participant = participants.get(participantId);
String result;
JSONObject obj = new JSONObject(msg);
String msgType = obj.getString("type");
String participantId = obj.getString("participant_id");
Participant participant = participants.get(participantId);
final boolean isInterim = !"final".equals(msgType);
final UUID messageID = UUID.fromString(obj.getString("id"));
final String result = obj.getString("text");


final float stability = obj.getFloat("variance");

result = obj.getString("text");
UUID messageId = UUID.fromString(obj.getString("id"));
Instant transcriptionStart = Instant.ofEpochMilli(obj.getLong("ts"));
float stability = obj.getFloat("variance");
if (logger.isDebugEnabled())
{
logger.debug("Received final: " + result);
}

final Set<TranscriptionListener> partListeners = participantListeners.getOrDefault(participantId, null);
if (result.isEmpty() || partListeners == null)
Set<TranscriptionListener> partListeners = participantListeners.getOrDefault(participantId, null);
if (!result.isEmpty() && partListeners != null)
{
return;
}
int i=0;

int i = 0;
for (final TranscriptionListener transcriptionListener : partListeners)
{
i++;
if (logger.isDebugEnabled())
for (TranscriptionListener l : partListeners)
{
logger.debug("ParticipantId: " + i + ", " + participantId);
logger.debug("TranscriptionListener: " + transcriptionListener.toString());
i++;
if (logger.isDebugEnabled())
{
logger.debug("ParticipantId: " + i + ", " + participantId);
logger.debug("TranscriptionListener: " + l.toString());
}
TranscriptionResult tsResult = new TranscriptionResult(
participant,
messageId,
transcriptionStart,
isInterim,
getLanguage(participant),
stability,
new TranscriptionAlternative(result));
l.notify(tsResult);
}

final TranscriptionResult transcriptionResult = new TranscriptionResult(
participant,
messageID,
transcriptionStart,
isInterim,
getLanguage(participant),
stability,
new TranscriptionAlternative(result));
transcriptionListener.notify(transcriptionResult);
}
}

Expand Down Expand Up @@ -305,13 +291,13 @@ private String getLanguage(Participant participant)
return lang;
}

private ByteBuffer buildPayload(String participantId, Participant participant, ByteBuffer audio)
ByteBuffer buildPayload(String participantId, Participant participant, ByteBuffer audio, Supplier<Long> timestampSupplier)
{
final ByteBuffer header = ByteBuffer.allocate(60);
final int lenAudio = audio.remaining();
final ByteBuffer fullPayload = ByteBuffer.allocate(lenAudio + 60);
ByteBuffer header = ByteBuffer.allocate(60);
int lenAudio = audio.remaining();
ByteBuffer fullPayload = ByteBuffer.allocate(lenAudio + 60);
final String headerStr = participantId + VALUES_DELIMITER +
Instant.now().toEpochMilli() + VALUES_DELIMITER +
timestampSupplier.get() + VALUES_DELIMITER +
this.getLanguage(participant);

if (logger.isDebugEnabled())
Expand Down Expand Up @@ -375,7 +361,7 @@ public void sendAudio(String participantId, Participant participant, ByteBuffer
{
try
{
remoteEndpoint.sendBytes(buildPayload(participantId, participant, audio));
remoteEndpoint.sendBytes(buildPayload(participantId, participant, audio, () -> Instant.now().toEpochMilli()));
}
catch (IOException e)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package org.jitsi.jigasi.transcription;

import org.jitsi.jigasi.JigasiBundleActivator;
import org.junit.jupiter.api.Test;

import java.nio.ByteBuffer;

import static org.assertj.core.api.Assertions.assertThat;
import static org.jitsi.jigasi.transcription.Transcriber.FILTER_SILENCE_DEFAULT_VALUE;
import static org.jitsi.jigasi.transcription.Transcriber.P_NAME_FILTER_SILENCE;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

class WhisperWebsocketTest {

private static class TranscriptionListenerForTest implements TranscriptionListener {

private TranscriptionResult result;

public TranscriptionResult getResult() {
return result;
}

@Override
public void notify(TranscriptionResult result) {
this.result = result;
}

@Override
public void completed() {

}

@Override
public void failed(FailureReason reason) {

}
}

@Test
void onMessage() {
final WhisperWebsocket whisperWebsocket = new WhisperWebsocket();
final TranscriptionListenerForTest listener = new TranscriptionListenerForTest();
final Participant participant = mock(Participant.class);
when(participant.getDebugName()).thenReturn("room/id1234567890");

whisperWebsocket.addListener(listener, participant);
whisperWebsocket.onMessage("{\"type\":\"final\",\"participant_id\":\"id1234567890\", \"ts\":\"3457658454\", \"id\":\"01870603-f211-7b9a-a7ea-4a98f5320ff8\", \"text\":\"hello world\", \"variance\":0.9}");

assertThat(listener.getResult().getAlternatives().stream().findFirst().orElse(null).getTranscription()).isEqualTo("hello world");
}

@Test
void buildPayload() {
final WhisperWebsocket whisperWebsocket = new WhisperWebsocket();
final Participant participant = mock(Participant.class);
when(participant.getTranslationLanguage()).thenReturn("eng ");

final ByteBuffer buffer =
whisperWebsocket.buildPayload("id1234567".repeat(5), participant, ByteBuffer.wrap("hello world!".getBytes()), () -> 123456789L);
assertThat(buffer).isNotNull();
assertThat(new String(buffer.array())).isEqualTo("id1234567id1234567id1234567id1234567id1234567|123456789|eng hello world!");
}
}

0 comments on commit 8979044

Please sign in to comment.