Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

from smithy_aws_core.identity import EnvironmentCredentialsResolver

from aws_sdk_transcribe_streaming.client import TranscribeStreamingClient
from aws_sdk_transcribe_streaming.config import Config

AUDIO_FILE = Path(__file__).parent / "assets" / "test.wav"


def create_transcribe_client(region: str) -> TranscribeStreamingClient:
"""Helper to create a TranscribeStreamingClient for a given region."""
return TranscribeStreamingClient(
config=Config(
endpoint_uri=f"https://transcribestreaming.{region}.amazonaws.com",
region=region,
aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
)
)
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Test bidirectional event stream handling."""

import asyncio
import time

from smithy_core.aio.eventstream import DuplexEventStream

from aws_sdk_transcribe_streaming.models import (
AudioEvent,
AudioStream,
AudioStreamAudioEvent,
LanguageCode,
MediaEncoding,
StartStreamTranscriptionInput,
StartStreamTranscriptionOutput,
TranscriptResultStream,
TranscriptResultStreamTranscriptEvent,
)

from . import AUDIO_FILE, create_transcribe_client


SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2
CHANNEL_NUMS = 1
CHUNK_SIZE = 1024 * 8


async def _send_audio_chunks(
stream: DuplexEventStream[
AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput
],
) -> None:
"""Send audio chunks from file simulating real-time delay."""
start_time = time.time()
elapsed_audio_time = 0.0

with AUDIO_FILE.open("rb") as f:
while chunk := f.read(CHUNK_SIZE):
await stream.input_stream.send(
AudioStreamAudioEvent(value=AudioEvent(audio_chunk=chunk))
)
elapsed_audio_time += len(chunk) / (
BYTES_PER_SAMPLE * SAMPLE_RATE * CHANNEL_NUMS
)
wait_time = start_time + elapsed_audio_time - time.time()
await asyncio.sleep(wait_time)

# Send an empty audio event to signal end of input
await stream.input_stream.send(
AudioStreamAudioEvent(value=AudioEvent(audio_chunk=b""))
)
await asyncio.sleep(0.4)
await stream.input_stream.close()


async def _receive_transcription_output(
stream: DuplexEventStream[
AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput
],
) -> tuple[bool, list[str]]:
"""Receive and collect transcription output from the stream.

Returns:
Tuple of (got_transcript_events, transcripts)
"""
got_transcript_events = False
transcripts: list[str] = []

_, output_stream = await stream.await_output()
if output_stream is None:
return got_transcript_events, transcripts

async for event in output_stream:
if not isinstance(event, TranscriptResultStreamTranscriptEvent):
raise RuntimeError(
f"Received unexpected event type in stream: {type(event).__name__}"
)

got_transcript_events = True
if event.value.transcript and event.value.transcript.results:
for result in event.value.transcript.results:
if result.alternatives:
for alt in result.alternatives:
if alt.transcript:
transcripts.append(alt.transcript)

return got_transcript_events, transcripts


async def test_start_stream_transcription() -> None:
"""Test bidirectional streaming with audio input and transcription output."""
transcribe_client = create_transcribe_client("us-west-2")

stream = await transcribe_client.start_stream_transcription(
input=StartStreamTranscriptionInput(
language_code=LanguageCode.EN_US,
media_sample_rate_hertz=SAMPLE_RATE,
media_encoding=MediaEncoding.PCM,
)
)

results = await asyncio.gather(
_send_audio_chunks(stream), _receive_transcription_output(stream)
)
got_transcript_events, transcripts = results[1]

assert got_transcript_events, "Expected to receive transcript events"
assert len(transcripts) > 0, "Expected to receive at least one transcript"
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Test non-streaming output type handling.

This test requires AWS resources (an IAM role and an S3 bucket).
To set them up locally, run:

uv run scripts/setup_resources.py

Then export the environment variables shown in the output.
"""

import asyncio
import os
import time
import uuid

import pytest

from aws_sdk_transcribe_streaming.models import (
ClinicalNoteGenerationSettings,
GetMedicalScribeStreamInput,
GetMedicalScribeStreamOutput,
LanguageCode,
MedicalScribeAudioEvent,
MedicalScribeConfigurationEvent,
MedicalScribeInputStreamAudioEvent,
MedicalScribeInputStreamConfigurationEvent,
MedicalScribeInputStreamSessionControlEvent,
MedicalScribePostStreamAnalyticsSettings,
MedicalScribeSessionControlEvent,
MedicalScribeSessionControlEventType,
MediaEncoding,
StartMedicalScribeStreamInput,
)

from . import AUDIO_FILE, create_transcribe_client

SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2
CHANNEL_NUMS = 1
CHUNK_SIZE = 1024 * 8


async def test_get_medical_scribe_stream() -> None:
role_arn = os.environ.get("HEALTHSCRIBE_ROLE_ARN")
s3_bucket = os.environ.get("HEALTHSCRIBE_S3_BUCKET")

if not role_arn or not s3_bucket:
pytest.fail("HEALTHSCRIBE_ROLE_ARN or HEALTHSCRIBE_S3_BUCKET not set")

transcribe_client = create_transcribe_client("us-east-1")
session_id = str(uuid.uuid4())

stream = await transcribe_client.start_medical_scribe_stream(
input=StartMedicalScribeStreamInput(
language_code=LanguageCode.EN_US,
media_sample_rate_hertz=SAMPLE_RATE,
media_encoding=MediaEncoding.PCM,
session_id=session_id,
)
)

await stream.input_stream.send(
MedicalScribeInputStreamConfigurationEvent(
value=MedicalScribeConfigurationEvent(
resource_access_role_arn=role_arn,
post_stream_analytics_settings=MedicalScribePostStreamAnalyticsSettings(
clinical_note_generation_settings=ClinicalNoteGenerationSettings(
output_bucket_name=s3_bucket
)
),
)
)
)

start_time = time.time()
elapsed_audio_time = 0.0

with AUDIO_FILE.open("rb") as f:
while chunk := f.read(CHUNK_SIZE):
await stream.input_stream.send(
MedicalScribeInputStreamAudioEvent(
value=MedicalScribeAudioEvent(audio_chunk=chunk)
)
)
elapsed_audio_time += len(chunk) / (
BYTES_PER_SAMPLE * SAMPLE_RATE * CHANNEL_NUMS
)
wait_time = start_time + elapsed_audio_time - time.time()
if wait_time > 0:
await asyncio.sleep(wait_time)

await stream.input_stream.send(
MedicalScribeInputStreamSessionControlEvent(
value=MedicalScribeSessionControlEvent(
type=MedicalScribeSessionControlEventType.END_OF_SESSION
)
)
)
await stream.input_stream.close()

await stream.await_output()

# Consume output stream events to properly close the connection
if stream.output_stream:
async for _ in stream.output_stream:
pass

response = await transcribe_client.get_medical_scribe_stream(
input=GetMedicalScribeStreamInput(session_id=session_id)
)

assert isinstance(response, GetMedicalScribeStreamOutput)
assert response.medical_scribe_stream_details is not None

details = response.medical_scribe_stream_details
assert details.session_id == session_id
assert details.stream_status == "COMPLETED"
assert details.language_code == "en-US"
assert details.media_encoding == "pcm"
assert details.media_sample_rate_hertz == SAMPLE_RATE
97 changes: 97 additions & 0 deletions clients/aws-sdk-transcribe-streaming/tests/setup_resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "boto3",
# ]
# ///
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Setup script to create AWS resources for integration tests.

Creates an IAM role and S3 bucket needed for medical scribe integration tests.

Note:
This script is intended for local testing only and should not be used for
production setups.

Usage:
uv run scripts/setup_resources.py
"""

import json
from typing import Any

import boto3


def create_iam_role(iam_client: Any, role_name: str, bucket_name: str) -> None:
trust_policy = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"Service": [
"transcribe.streaming.amazonaws.com"
]
},
"Action": "sts:AssumeRole",
}
]
}

try:
iam_client.create_role(
RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy)
)
except iam_client.exceptions.EntityAlreadyExistsException:
pass

permissions_policy = {
"Version": "2012-10-17",
"Statement": [
{
"Action": [
"s3:PutObject"
],
"Resource": [
f"arn:aws:s3:::{bucket_name}",
f"arn:aws:s3:::{bucket_name}/*",
],
"Effect": "Allow"
}
]
}

iam_client.put_role_policy(
RoleName=role_name,
PolicyName="HealthScribeS3Access",
PolicyDocument=json.dumps(permissions_policy),
)


def setup_healthscribe_resources() -> tuple[str, str]:
region = "us-east-1"
iam = boto3.client("iam")
s3 = boto3.client("s3", region_name=region)
sts = boto3.client("sts")

account_id = sts.get_caller_identity()["Account"]
bucket_name = f"healthscribe-test-{account_id}-{region}"
role_name = "HealthScribeIntegrationTestRole"

s3.create_bucket(Bucket=bucket_name)
create_iam_role(iam, role_name, bucket_name)

role_arn = f"arn:aws:iam::{account_id}:role/{role_name}"
return role_arn, bucket_name


if __name__ == "__main__":
role_arn, bucket_name = setup_healthscribe_resources()

print("Setup complete. Export these environment variables before running tests:")
print(f"export HEALTHSCRIBE_ROLE_ARN={role_arn}")
print(f"export HEALTHSCRIBE_S3_BUCKET={bucket_name}")