# Audio fixture shipped next to this package, under ``assets/``.
AUDIO_FILE = Path(__file__).parent.joinpath("assets", "test.wav")


def create_transcribe_client(region: str) -> TranscribeStreamingClient:
    """Build a ``TranscribeStreamingClient`` configured for ``region``.

    Args:
        region: Region name; it is used both as the client's region and to
            derive the service endpoint URI.

    Returns:
        A client whose config uses ``EnvironmentCredentialsResolver`` as its
        AWS credentials identity resolver.
    """
    endpoint = f"https://transcribestreaming.{region}.amazonaws.com"
    client_config = Config(
        endpoint_uri=endpoint,
        region=region,
        aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
    )
    return TranscribeStreamingClient(config=client_config)
# Audio format constants; together they give the number of audio bytes that
# correspond to one second of playback, which drives the send pacing below.
SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2
CHANNEL_NUMS = 1
CHUNK_SIZE = 1024 * 8


async def _send_audio_chunks(
    stream: DuplexEventStream[
        AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput
    ],
) -> None:
    """Send the audio file over ``stream`` in chunks, paced at real-time speed.

    After each chunk is sent, the coroutine sleeps until the clock has caught
    up with the amount of audio sent so far. Once the file is exhausted it
    sends an empty audio event, pauses briefly, and closes the input stream.

    Args:
        stream: The duplex event stream returned by
            ``start_stream_transcription``.
    """
    bytes_per_second = BYTES_PER_SAMPLE * SAMPLE_RATE * CHANNEL_NUMS

    # A monotonic clock is used for pacing: wall-clock time can jump (e.g. NTP
    # adjustments), which would distort the computed delays.
    start_time = time.monotonic()
    elapsed_audio_time = 0.0

    with AUDIO_FILE.open("rb") as f:
        while chunk := f.read(CHUNK_SIZE):
            await stream.input_stream.send(
                AudioStreamAudioEvent(value=AudioEvent(audio_chunk=chunk))
            )
            elapsed_audio_time += len(chunk) / bytes_per_second
            wait_time = start_time + elapsed_audio_time - time.monotonic()
            # Never pass a negative delay; a zero delay still yields control
            # to the event loop so the receiving coroutine can make progress.
            await asyncio.sleep(max(wait_time, 0.0))

    # Send an empty audio event to signal end of input
    await stream.input_stream.send(
        AudioStreamAudioEvent(value=AudioEvent(audio_chunk=b""))
    )
    # NOTE(review): presumably this pause gives the service time to handle the
    # final event before the stream is closed -- confirm the required delay.
    await asyncio.sleep(0.4)
    await stream.input_stream.close()


async def _receive_transcription_output(
    stream: DuplexEventStream[
        AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput
    ],
) -> tuple[bool, list[str]]:
    """Receive and collect transcription output from the stream.

    Args:
        stream: The duplex event stream returned by
            ``start_stream_transcription``.

    Returns:
        Tuple of (got_transcript_events, transcripts). ``transcripts`` holds
        every non-empty alternative transcript, in arrival order. If the
        operation yields no output stream the result is ``(False, [])``.

    Raises:
        RuntimeError: If an event other than a transcript event is received.
    """
    got_transcript_events = False
    transcripts: list[str] = []

    _, output_stream = await stream.await_output()
    if output_stream is None:
        return got_transcript_events, transcripts

    async for event in output_stream:
        if not isinstance(event, TranscriptResultStreamTranscriptEvent):
            raise RuntimeError(
                f"Received unexpected event type in stream: {type(event).__name__}"
            )

        got_transcript_events = True

        transcript = event.value.transcript
        if not transcript or not transcript.results:
            continue

        for result in transcript.results:
            transcripts.extend(
                alt.transcript
                for alt in (result.alternatives or [])
                if alt.transcript
            )

    return got_transcript_events, transcripts


async def test_start_stream_transcription() -> None:
    """Test bidirectional streaming with audio input and transcription output."""
    transcribe_client = create_transcribe_client("us-west-2")

    stream = await transcribe_client.start_stream_transcription(
        input=StartStreamTranscriptionInput(
            language_code=LanguageCode.EN_US,
            media_sample_rate_hertz=SAMPLE_RATE,
            media_encoding=MediaEncoding.PCM,
        )
    )

    # Run sender and receiver concurrently; only the receiver returns a value.
    _, (got_transcript_events, transcripts) = await asyncio.gather(
        _send_audio_chunks(stream), _receive_transcription_output(stream)
    )

    assert got_transcript_events, "Expected to receive transcript events"
    assert len(transcripts) > 0, "Expected to receive at least one transcript"
# Audio format constants; together they give the number of audio bytes that
# correspond to one second of playback, which drives the send pacing below.
SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2
CHANNEL_NUMS = 1
CHUNK_SIZE = 1024 * 8


async def test_get_medical_scribe_stream() -> None:
    """Run a medical scribe session, then look it up via GetMedicalScribeStream.

    The test starts a medical scribe stream with a fresh session id, sends a
    configuration event, streams the audio fixture at real-time pace, sends an
    END_OF_SESSION control event, drains the output stream, and finally checks
    the details returned by ``get_medical_scribe_stream`` for that session.

    Requires the ``HEALTHSCRIBE_ROLE_ARN`` and ``HEALTHSCRIBE_S3_BUCKET``
    environment variables; the test fails immediately when either is missing.
    The ``tests/setup_resources.py`` script creates the role and bucket and
    prints the matching ``export`` lines.
    """
    role_arn = os.environ.get("HEALTHSCRIBE_ROLE_ARN")
    s3_bucket = os.environ.get("HEALTHSCRIBE_S3_BUCKET")

    if not role_arn or not s3_bucket:
        pytest.fail("HEALTHSCRIBE_ROLE_ARN or HEALTHSCRIBE_S3_BUCKET not set")

    transcribe_client = create_transcribe_client("us-east-1")
    session_id = str(uuid.uuid4())

    stream = await transcribe_client.start_medical_scribe_stream(
        input=StartMedicalScribeStreamInput(
            language_code=LanguageCode.EN_US,
            media_sample_rate_hertz=SAMPLE_RATE,
            media_encoding=MediaEncoding.PCM,
            session_id=session_id,
        )
    )

    # The configuration event is sent before any audio.
    await stream.input_stream.send(
        MedicalScribeInputStreamConfigurationEvent(
            value=MedicalScribeConfigurationEvent(
                resource_access_role_arn=role_arn,
                post_stream_analytics_settings=MedicalScribePostStreamAnalyticsSettings(
                    clinical_note_generation_settings=ClinicalNoteGenerationSettings(
                        output_bucket_name=s3_bucket
                    )
                ),
            )
        )
    )

    bytes_per_second = BYTES_PER_SAMPLE * SAMPLE_RATE * CHANNEL_NUMS

    # A monotonic clock is used for pacing: wall-clock time can jump (e.g. NTP
    # adjustments), which would distort the computed delays.
    start_time = time.monotonic()
    elapsed_audio_time = 0.0

    with AUDIO_FILE.open("rb") as f:
        while chunk := f.read(CHUNK_SIZE):
            await stream.input_stream.send(
                MedicalScribeInputStreamAudioEvent(
                    value=MedicalScribeAudioEvent(audio_chunk=chunk)
                )
            )
            elapsed_audio_time += len(chunk) / bytes_per_second
            # Sleep only while sending is ahead of the audio's real-time pace.
            wait_time = start_time + elapsed_audio_time - time.monotonic()
            if wait_time > 0:
                await asyncio.sleep(wait_time)

    await stream.input_stream.send(
        MedicalScribeInputStreamSessionControlEvent(
            value=MedicalScribeSessionControlEvent(
                type=MedicalScribeSessionControlEventType.END_OF_SESSION
            )
        )
    )
    await stream.input_stream.close()

    await stream.await_output()

    # Consume output stream events to properly close the connection
    if stream.output_stream:
        async for _ in stream.output_stream:
            pass

    response = await transcribe_client.get_medical_scribe_stream(
        input=GetMedicalScribeStreamInput(session_id=session_id)
    )

    assert isinstance(response, GetMedicalScribeStreamOutput)
    assert response.medical_scribe_stream_details is not None

    details = response.medical_scribe_stream_details
    assert details.session_id == session_id
    assert details.stream_status == "COMPLETED"
    assert details.language_code == "en-US"
    assert details.media_encoding == "pcm"
    assert details.media_sample_rate_hertz == SAMPLE_RATE
def create_iam_role(iam_client: Any, role_name: str, bucket_name: str) -> None:
    """Create the test IAM role (if absent) and attach its S3 inline policy.

    Args:
        iam_client: IAM client object; it must provide ``create_role``,
            ``put_role_policy`` and ``exceptions.EntityAlreadyExistsException``.
        role_name: Name of the role to create.
        bucket_name: Bucket the role is allowed to write objects to.
    """
    # Trust policy: lets the streaming transcription service assume the role.
    trust_statement = {
        "Effect": "Allow",
        "Principal": {"Service": ["transcribe.streaming.amazonaws.com"]},
        "Action": "sts:AssumeRole",
    }
    assume_role_document = json.dumps(
        {"Version": "2012-10-17", "Statement": [trust_statement]}
    )

    try:
        iam_client.create_role(
            RoleName=role_name, AssumeRolePolicyDocument=assume_role_document
        )
    except iam_client.exceptions.EntityAlreadyExistsException:
        # An existing role is reused as-is; the inline policy is still
        # (re)written below.
        pass

    # Inline policy: allows s3:PutObject on the bucket and on its objects.
    s3_statement = {
        "Action": ["s3:PutObject"],
        "Resource": [
            f"arn:aws:s3:::{bucket_name}",
            f"arn:aws:s3:::{bucket_name}/*",
        ],
        "Effect": "Allow",
    }
    inline_policy_document = json.dumps(
        {"Version": "2012-10-17", "Statement": [s3_statement]}
    )

    iam_client.put_role_policy(
        RoleName=role_name,
        PolicyName="HealthScribeS3Access",
        PolicyDocument=inline_policy_document,
    )


def setup_healthscribe_resources() -> tuple[str, str]:
    """Create the S3 bucket and IAM role used by the medical scribe test.

    The bucket name is derived from the caller's account id and the region;
    the role name is fixed.

    Returns:
        Tuple of (role_arn, bucket_name).
    """
    region = "us-east-1"
    role_name = "HealthScribeIntegrationTestRole"

    iam = boto3.client("iam")
    s3 = boto3.client("s3", region_name=region)
    sts = boto3.client("sts")

    identity = sts.get_caller_identity()
    account_id = identity["Account"]
    bucket_name = f"healthscribe-test-{account_id}-{region}"

    s3.create_bucket(Bucket=bucket_name)
    create_iam_role(iam, role_name, bucket_name)

    return f"arn:aws:iam::{account_id}:role/{role_name}", bucket_name


if __name__ == "__main__":
    created_role_arn, created_bucket = setup_healthscribe_resources()

    print("Setup complete. Export these environment variables before running tests:")
    print(f"export HEALTHSCRIBE_ROLE_ARN={created_role_arn}")
    print(f"export HEALTHSCRIBE_S3_BUCKET={created_bucket}")