Skip to content

Commit

Permalink
Add /api/speech/token endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
cecheta committed Apr 15, 2024
1 parent 0a7f180 commit 2119904
Show file tree
Hide file tree
Showing 13 changed files with 553 additions and 283 deletions.
9 changes: 9 additions & 0 deletions code/backend/batch/utilities/helpers/EnvHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def __init__(self, **kwargs) -> None:

self.LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()

# Azure
self.AZURE_SUBSCRIPTION_ID = os.getenv("AZURE_SUBSCRIPTION_ID", "")
self.AZURE_RESOURCE_GROUP = os.getenv("AZURE_RESOURCE_GROUP", "")

# Azure Search
self.AZURE_SEARCH_SERVICE = os.getenv("AZURE_SEARCH_SERVICE", "")
self.AZURE_SEARCH_INDEX = os.getenv("AZURE_SEARCH_INDEX", "")
Expand Down Expand Up @@ -159,7 +163,12 @@ def __init__(self, **kwargs) -> None:
"ORCHESTRATION_STRATEGY", "openai_function"
)
# Speech Service
self.AZURE_SPEECH_SERVICE_NAME = os.getenv("AZURE_SPEECH_SERVICE_NAME", "")
self.AZURE_SPEECH_SERVICE_REGION = os.getenv("AZURE_SPEECH_SERVICE_REGION")
self.AZURE_SPEECH_REGION_ENDPOINT = os.environ.get(
"AZURE_SPEECH_REGION_ENDPOINT",
f"https://{self.AZURE_SPEECH_SERVICE_REGION}.api.cognitive.microsoft.com/",
)

self.LOAD_CONFIG_FROM_BLOB_STORAGE = self.get_env_var_bool(
"LOAD_CONFIG_FROM_BLOB_STORAGE"
Expand Down
47 changes: 46 additions & 1 deletion code/create_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from flask import Flask, Response, request, jsonify
from dotenv import load_dotenv
import sys
import functools
from backend.batch.utilities.helpers.EnvHelper import EnvHelper

from azure.mgmt.cognitiveservices import CognitiveServicesManagementClient
from azure.identity import DefaultAzureCredential

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -285,6 +287,24 @@ def conversation_without_data(request, env_helper):
)


@functools.cache
def get_speech_key(env_helper: EnvHelper):
"""
Get the Azure Speech key directly from Azure.
This is required to generate short-lived tokens when using RBAC.
"""
client = CognitiveServicesManagementClient(
credential=DefaultAzureCredential(),
subscription_id=env_helper.AZURE_SUBSCRIPTION_ID,
)
keys = client.accounts.list_keys(
resource_group_name=env_helper.AZURE_RESOURCE_GROUP,
account_name=env_helper.AZURE_SPEECH_SERVICE_NAME,
)

return keys.key1


def create_app():
# Fixing MIME types for static files under Windows
mimetypes.add_type("application/javascript", ".js")
Expand Down Expand Up @@ -381,4 +401,29 @@ def conversation_custom():
500,
)

@app.route("/api/speech/token", methods=["GET"])
def speech_token():
try:
speech_key = env_helper.AZURE_SPEECH_KEY or get_speech_key(env_helper)

response = requests.post(
f"{env_helper.AZURE_SPEECH_REGION_ENDPOINT}sts/v1.0/issueToken",
headers={
"Ocp-Apim-Subscription-Key": speech_key,
},
)

if response.status_code == 200:
return {
"token": response.text,
"region": env_helper.AZURE_SPEECH_SERVICE_REGION,
}

logger.error(f"Failed to get speech token: {response.text}")
return {"error": "Failed to get speech token"}, response.status_code
except Exception as e:
logger.exception(f"Exception in /api/speech/token | {str(e)}")

return {"error": "Failed to get speech token"}, 500

return app
94 changes: 38 additions & 56 deletions code/frontend/src/pages/chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {
import {
SpeechRecognizer,
ResultReason,
AutoDetectSourceLanguageResult,
} from "microsoft-cognitiveservices-speech-sdk";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
Expand All @@ -22,7 +21,7 @@ import { multiLingualSpeechRecognizer } from "../../util/SpeechToText";
import {
ChatMessage,
ConversationRequest,
customConversationApi,
customConversationApi,
Citation,
ToolMessageContent,
ChatResponse,
Expand Down Expand Up @@ -56,8 +55,6 @@ const Chat = () => {
const [isRecognizing, setIsRecognizing] = useState(false);
const [isListening, setIsListening] = useState(false);
const recognizerRef = useRef<SpeechRecognizer | null>(null);
const [subscriptionKey, setSubscriptionKey] = useState<string>("");
const [serviceRegion, setServiceRegion] = useState<string>("");
const makeApiRequest = async (question: string) => {
lastQuestionRef.current = question;

Expand Down Expand Up @@ -110,7 +107,7 @@ const Chat = () => {
]);
}
runningText = "";
} catch {}
} catch { }
});
}
setAnswers([...answers, userMessage, ...result.choices[0].messages]);
Expand All @@ -134,62 +131,47 @@ const Chat = () => {
return abortController.abort();
};

useEffect(() => {
async function fetchServerConfig() {
try {
const response = await fetch("/api/config");
if (!response.ok) {
throw new Error("Network response was not ok");
}
const data = await response.json();
const fetchedSubscriptionKey = data.azureSpeechKey;
const fetchedServiceRegion = data.azureSpeechRegion;

setSubscriptionKey(fetchedSubscriptionKey);
setServiceRegion(fetchedServiceRegion);
} catch (error) {
console.error("Error fetching server configuration:", error);
const fetchSpeechToken = async (): Promise<{ token: string, region: string; }> => {
try {
const response = await fetch("/api/speech/token");
if (!response.ok) {
throw new Error("Network response was not ok");
}
return response.json();
} catch (error) {
console.error("Error fetching server configuration:", error);
throw error;
}
};

fetchServerConfig();
}, []);

const startSpeechRecognition = (
subscriptionKey: string,
serviceRegion: string
) => {
const startSpeechRecognition = async () => {
if (!isRecognizing) {
setIsRecognizing(true);

if (!subscriptionKey || !serviceRegion) {
console.error(
"Azure Speech subscription key or region is not defined."
);
} else {
const recognizer = multiLingualSpeechRecognizer(subscriptionKey, serviceRegion, [
"en-US",
"fr-FR",
"de-DE",
"it-IT"
]);
recognizerRef.current = recognizer; // Store the recognizer in the ref
const { token, region } = await fetchSpeechToken();

recognizerRef.current.recognized = (s, e) => {
if (e.result.reason === ResultReason.RecognizedSpeech) {
const recognized = e.result.text;
// console.log("Recognized:", recognized);
setUserMessage(recognized);
setRecognizedText(recognized);
}
};
const recognizer = multiLingualSpeechRecognizer(token, region, [
"en-US",
"fr-FR",
"de-DE",
"it-IT"
]);
recognizerRef.current = recognizer; // Store the recognizer in the ref

recognizerRef.current.startContinuousRecognitionAsync(() => {
setIsRecognizing(true);
// console.log("Speech recognition started.");
setIsListening(true);
});
}
recognizerRef.current.recognized = (s, e) => {
if (e.result.reason === ResultReason.RecognizedSpeech) {
const recognized = e.result.text;
// console.log("Recognized:", recognized);
setUserMessage(recognized);
setRecognizedText(recognized);
}
};

recognizerRef.current.startContinuousRecognitionAsync(() => {
setIsRecognizing(true);
// console.log("Speech recognition started.");
setIsListening(true);
});
}
};

Expand All @@ -208,10 +190,10 @@ const Chat = () => {
}
};

const onMicrophoneClick = () => {
const onMicrophoneClick = async () => {
if (!isRecognizing) {
// console.log("Starting speech recognition...");
startSpeechRecognition(subscriptionKey, serviceRegion);
await startSpeechRecognition();
} else {
// console.log("Stopping speech recognition...");
stopSpeechRecognition();
Expand Down Expand Up @@ -294,7 +276,7 @@ const Chat = () => {
answer.role === "assistant"
? answer.content
: "Sorry, an error occurred. Try refreshing the conversation or waiting a few minutes. If the issue persists, contact your system administrator. Error: " +
answer.content,
answer.content,
citations:
answer.role === "assistant"
? parseCitationFromMessage(answers[index - 1])
Expand Down
6 changes: 3 additions & 3 deletions code/frontend/src/util/SpeechToText.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import {


export const multiLingualSpeechRecognizer = (
subscriptionKey: string,
token: string,
serviceRegion: string,
languages: string[]
) => {
const speechConfig = SpeechConfig.fromSubscription(
subscriptionKey,
const speechConfig = SpeechConfig.fromAuthorizationToken(
token,
serviceRegion
);
const audioConfig = AudioConfig.fromDefaultMicrophoneInput();
Expand Down
6 changes: 3 additions & 3 deletions code/frontend/test/SpeechToText.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import { multiLingualSpeechRecognizer } from "../src/util/SpeechToText.js";

describe("SpeechToText", () => {
it("creates a speech recognizer with multiple languages", async () => {
const key = "key";
const token = "token";
const region = "region";
const languages = ["en-US", "fr-FR", "de-DE", "it-IT"];
const recognizer = multiLingualSpeechRecognizer(key, region, languages);
const recognizer = multiLingualSpeechRecognizer(token, region, languages);

expect(recognizer.properties.getProperty("SpeechServiceConnection_Key")).to.equal(key);
expect(recognizer.authorizationToken).to.equal(token);
expect(recognizer.properties.getProperty("SpeechServiceConnection_Region")).to.equal(region);
expect(recognizer.properties.getProperty("SpeechServiceConnection_AutoDetectSourceLanguages")).to.equal(languages.join(","));
});
Expand Down
5 changes: 5 additions & 0 deletions code/tests/functional/backend_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
}
)

httpserver.expect_request(
"/sts/v1.0/issueToken",
method="POST",
).respond_with_data("speech-token")

yield

httpserver.check()
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def app_config(make_httpserver, ca):
"AZURE_OPENAI_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_SEARCH_SERVICE": f"https://localhost:{make_httpserver.port}/",
"AZURE_CONTENT_SAFETY_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_SPEECH_REGION_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"SSL_CERT_FILE": ca_temp_path,
"CURL_CA_BUNDLE": ca_temp_path,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import requests
from pytest_httpserver import HTTPServer
from tests.functional.backend_api.app_config import AppConfig
from tests.functional.backend_api.request_matching import (
RequestMatcher,
verify_request_made,
)

pytestmark = pytest.mark.functional


def test_speech_token_returned(app_url: str, app_config: AppConfig):
# when
response = requests.get(f"{app_url}/api/speech/token")

# then
assert response.status_code == 200
assert response.json() == {
"token": "speech-token",
"region": app_config.get("AZURE_SPEECH_SERVICE_REGION"),
}
assert response.headers["Content-Type"] == "application/json"


def test_speech_service_called_correctly(
app_url: str, app_config: AppConfig, httpserver: HTTPServer
):
# when
requests.get(f"{app_url}/api/speech/token")

# then
verify_request_made(
mock_httpserver=httpserver,
request_matcher=RequestMatcher(
path="/sts/v1.0/issueToken",
method="POST",
headers={
"Ocp-Apim-Subscription-Key": app_config.get("AZURE_SPEECH_SERVICE_KEY")
},
times=1,
),
)


def test_failure_fetching_speech_token(app_url: str, httpserver: HTTPServer):
# given
httpserver.clear_all_handlers() # Clear default successful responses

httpserver.expect_request(
"/sts/v1.0/issueToken",
method="POST",
).respond_with_json({"error": "Bad request"}, status=400)

# when
response = requests.get(f"{app_url}/api/speech/token")

# then
assert response.status_code == 400
assert response.json() == {"error": "Failed to get speech token"}
assert response.headers["Content-Type"] == "application/json"
Loading

0 comments on commit 2119904

Please sign in to comment.