Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add /api/speech endpoint #680

Merged
merged 8 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions code/backend/batch/utilities/helpers/EnvHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def __load_config(self, **kwargs) -> None:

self.LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()

# Azure
self.AZURE_SUBSCRIPTION_ID = os.getenv("AZURE_SUBSCRIPTION_ID", "")
self.AZURE_RESOURCE_GROUP = os.getenv("AZURE_RESOURCE_GROUP", "")

# Azure Search
self.AZURE_SEARCH_SERVICE = os.getenv("AZURE_SEARCH_SERVICE", "")
self.AZURE_SEARCH_INDEX = os.getenv("AZURE_SEARCH_INDEX", "")
Expand Down Expand Up @@ -169,7 +173,12 @@ def __load_config(self, **kwargs) -> None:
"ORCHESTRATION_STRATEGY", "openai_function"
)
# Speech Service
self.AZURE_SPEECH_SERVICE_NAME = os.getenv("AZURE_SPEECH_SERVICE_NAME", "")
self.AZURE_SPEECH_SERVICE_REGION = os.getenv("AZURE_SPEECH_SERVICE_REGION")
self.AZURE_SPEECH_REGION_ENDPOINT = os.environ.get(
"AZURE_SPEECH_REGION_ENDPOINT",
f"https://{self.AZURE_SPEECH_SERVICE_REGION}.api.cognitive.microsoft.com/",
)

self.LOAD_CONFIG_FROM_BLOB_STORAGE = self.get_env_var_bool(
"LOAD_CONFIG_FROM_BLOB_STORAGE"
Expand Down
47 changes: 46 additions & 1 deletion code/create_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from flask import Flask, Response, request, jsonify
from dotenv import load_dotenv
import sys
import functools
from backend.batch.utilities.helpers.EnvHelper import EnvHelper

from azure.mgmt.cognitiveservices import CognitiveServicesManagementClient
from azure.identity import DefaultAzureCredential

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -285,6 +287,24 @@ def conversation_without_data(request, env_helper):
)


@functools.cache
def get_speech_key(env_helper: EnvHelper):
    """Fetch the primary Azure Speech key via the management plane.

    Needed to mint short-lived speech tokens when the app authenticates
    with RBAC rather than a pre-configured key. The result is memoised
    for the lifetime of the process by functools.cache, so the
    management API is only hit once per EnvHelper instance.
    """
    management_client = CognitiveServicesManagementClient(
        credential=DefaultAzureCredential(),
        subscription_id=env_helper.AZURE_SUBSCRIPTION_ID,
    )
    account_keys = management_client.accounts.list_keys(
        resource_group_name=env_helper.AZURE_RESOURCE_GROUP,
        account_name=env_helper.AZURE_SPEECH_SERVICE_NAME,
    )
    return account_keys.key1


def create_app():
# Fixing MIME types for static files under Windows
mimetypes.add_type("application/javascript", ".js")
Expand Down Expand Up @@ -381,4 +401,29 @@ def conversation_custom():
500,
)

@app.route("/api/speech", methods=["GET"])
def speech_token():
try:
speech_key = env_helper.AZURE_SPEECH_KEY or get_speech_key(env_helper)

response = requests.post(
f"{env_helper.AZURE_SPEECH_REGION_ENDPOINT}sts/v1.0/issueToken",
headers={
"Ocp-Apim-Subscription-Key": speech_key,
},
)

if response.status_code == 200:
return {
"token": response.text,
"region": env_helper.AZURE_SPEECH_SERVICE_REGION,
}

logger.error(f"Failed to get speech token: {response.text}")
return {"error": "Failed to get speech token"}, response.status_code
except Exception as e:
logger.exception(f"Exception in /api/speech | {str(e)}")

return {"error": "Failed to get speech token"}, 500

return app
95 changes: 39 additions & 56 deletions code/frontend/src/pages/chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {
import {
SpeechRecognizer,
ResultReason,
AutoDetectSourceLanguageResult,
} from "microsoft-cognitiveservices-speech-sdk";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
Expand All @@ -22,7 +21,7 @@ import { multiLingualSpeechRecognizer } from "../../util/SpeechToText";
import {
ChatMessage,
ConversationRequest,
customConversationApi,
customConversationApi,
Citation,
ToolMessageContent,
ChatResponse,
Expand Down Expand Up @@ -56,8 +55,6 @@ const Chat = () => {
const [isRecognizing, setIsRecognizing] = useState(false);
const [isListening, setIsListening] = useState(false);
const recognizerRef = useRef<SpeechRecognizer | null>(null);
const [subscriptionKey, setSubscriptionKey] = useState<string>("");
const [serviceRegion, setServiceRegion] = useState<string>("");
const makeApiRequest = async (question: string) => {
lastQuestionRef.current = question;

Expand Down Expand Up @@ -110,7 +107,7 @@ const Chat = () => {
]);
}
runningText = "";
} catch {}
} catch { }
});
}
setAnswers([...answers, userMessage, ...result.choices[0].messages]);
Expand All @@ -134,62 +131,48 @@ const Chat = () => {
return abortController.abort();
};

useEffect(() => {
async function fetchServerConfig() {
try {
const response = await fetch("/api/config");
if (!response.ok) {
throw new Error("Network response was not ok");
}
const data = await response.json();
const fetchedSubscriptionKey = data.azureSpeechKey;
const fetchedServiceRegion = data.azureSpeechRegion;

setSubscriptionKey(fetchedSubscriptionKey);
setServiceRegion(fetchedServiceRegion);
} catch (error) {
console.error("Error fetching server configuration:", error);
const fetchSpeechToken = async (): Promise<{ token: string, region: string; }> => {
try {
const response = await fetch("/api/speech");
if (!response.ok) {
console.error("Error fetching speech token:", response);
throw new Error("Network response was not ok");
adamdougal marked this conversation as resolved.
Show resolved Hide resolved
}
return response.json();
} catch (error) {
console.error("Error fetching server configuration:", error);
throw error;
}
};

fetchServerConfig();
}, []);

const startSpeechRecognition = (
subscriptionKey: string,
serviceRegion: string
) => {
// Start continuous multi-language speech recognition using a short-lived
// token fetched from the backend. No-op if recognition is already running.
const startSpeechRecognition = async () => {
  if (!isRecognizing) {
    setIsRecognizing(true);

    let token: string;
    let region: string;
    try {
      ({ token, region } = await fetchSpeechToken());
    } catch (error) {
      // Bug fix: without this reset, a failed token fetch left
      // isRecognizing stuck at true, permanently disabling the mic.
      setIsRecognizing(false);
      return;
    }

    const recognizer = multiLingualSpeechRecognizer(token, region, [
      "en-US",
      "fr-FR",
      "de-DE",
      "it-IT"
    ]);
    recognizerRef.current = recognizer; // Store the recognizer in the ref

    recognizerRef.current.recognized = (s, e) => {
      if (e.result.reason === ResultReason.RecognizedSpeech) {
        const recognized = e.result.text;
        setUserMessage(recognized);
        setRecognizedText(recognized);
      }
    };

    recognizerRef.current.startContinuousRecognitionAsync(() => {
      setIsRecognizing(true);
      setIsListening(true);
    });
  }
};

Expand All @@ -208,10 +191,10 @@ const Chat = () => {
}
};

const onMicrophoneClick = () => {
const onMicrophoneClick = async () => {
if (!isRecognizing) {
// console.log("Starting speech recognition...");
startSpeechRecognition(subscriptionKey, serviceRegion);
await startSpeechRecognition();
} else {
// console.log("Stopping speech recognition...");
stopSpeechRecognition();
Expand Down Expand Up @@ -294,7 +277,7 @@ const Chat = () => {
answer.role === "assistant"
? answer.content
: "Sorry, an error occurred. Try refreshing the conversation or waiting a few minutes. If the issue persists, contact your system administrator. Error: " +
answer.content,
answer.content,
citations:
answer.role === "assistant"
? parseCitationFromMessage(answers[index - 1])
Expand Down
6 changes: 3 additions & 3 deletions code/frontend/src/util/SpeechToText.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import {


export const multiLingualSpeechRecognizer = (
subscriptionKey: string,
token: string,
serviceRegion: string,
languages: string[]
) => {
const speechConfig = SpeechConfig.fromSubscription(
subscriptionKey,
const speechConfig = SpeechConfig.fromAuthorizationToken(
token,
serviceRegion
);
const audioConfig = AudioConfig.fromDefaultMicrophoneInput();
Expand Down
6 changes: 3 additions & 3 deletions code/frontend/test/SpeechToText.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import { multiLingualSpeechRecognizer } from "../src/util/SpeechToText.js";

describe("SpeechToText", () => {
it("creates a speech recognizer with multiple languages", async () => {
const key = "key";
const token = "token";
const region = "region";
const languages = ["en-US", "fr-FR", "de-DE", "it-IT"];
const recognizer = multiLingualSpeechRecognizer(key, region, languages);
const recognizer = multiLingualSpeechRecognizer(token, region, languages);

expect(recognizer.properties.getProperty("SpeechServiceConnection_Key")).to.equal(key);
expect(recognizer.authorizationToken).to.equal(token);
expect(recognizer.properties.getProperty("SpeechServiceConnection_Region")).to.equal(region);
expect(recognizer.properties.getProperty("SpeechServiceConnection_AutoDetectSourceLanguages")).to.equal(languages.join(","));
});
Expand Down
5 changes: 5 additions & 0 deletions code/tests/functional/backend_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
}
)

httpserver.expect_request(
"/sts/v1.0/issueToken",
method="POST",
).respond_with_data("speech-token")

yield

httpserver.check()
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def app_config(make_httpserver, ca):
"AZURE_OPENAI_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_SEARCH_SERVICE": f"https://localhost:{make_httpserver.port}/",
"AZURE_CONTENT_SAFETY_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_SPEECH_REGION_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"SSL_CERT_FILE": ca_temp_path,
"CURL_CA_BUNDLE": ca_temp_path,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import requests
from pytest_httpserver import HTTPServer
from tests.functional.backend_api.app_config import AppConfig
from tests.functional.backend_api.request_matching import (
RequestMatcher,
verify_request_made,
)

pytestmark = pytest.mark.functional


def test_speech_token_returned(app_url: str, app_config: AppConfig):
    """GET /api/speech returns the mocked token and configured region as JSON."""
    # when
    response = requests.get(f"{app_url}/api/speech")

    # then
    assert response.status_code == 200
    expected_payload = {
        "token": "speech-token",
        "region": app_config.get("AZURE_SPEECH_SERVICE_REGION"),
    }
    assert response.json() == expected_payload
    assert response.headers["Content-Type"] == "application/json"


def test_speech_service_called_correctly(
    app_url: str, app_config: AppConfig, httpserver: HTTPServer
):
    """GET /api/speech issues exactly one token request with the key header."""
    # when
    requests.get(f"{app_url}/api/speech")

    # then
    expected_headers = {
        "Ocp-Apim-Subscription-Key": app_config.get("AZURE_SPEECH_SERVICE_KEY")
    }
    matcher = RequestMatcher(
        path="/sts/v1.0/issueToken",
        method="POST",
        headers=expected_headers,
        times=1,
    )
    verify_request_made(mock_httpserver=httpserver, request_matcher=matcher)


def test_failure_fetching_speech_token(app_url: str, httpserver: HTTPServer):
    """A non-200 from the speech service surfaces as an error payload."""
    # given: replace the default success handler with a 400 response
    httpserver.clear_all_handlers()
    httpserver.expect_request(
        "/sts/v1.0/issueToken",
        method="POST",
    ).respond_with_json({"error": "Bad request"}, status=400)

    # when
    response = requests.get(f"{app_url}/api/speech")

    # then: the upstream status code is propagated with a generic error body
    assert response.status_code == 400
    assert response.json() == {"error": "Failed to get speech token"}
    assert response.headers["Content-Type"] == "application/json"
Loading
Loading