diff --git a/sdk/python/agentfield/__init__.py b/sdk/python/agentfield/__init__.py index a146b35..839a7eb 100644 --- a/sdk/python/agentfield/__init__.py +++ b/sdk/python/agentfield/__init__.py @@ -30,6 +30,14 @@ FileOutput, detect_multimodal_response, ) +from .media_providers import ( + MediaProvider, + FalProvider, + LiteLLMProvider, + OpenRouterProvider, + get_provider, + register_provider, +) __all__ = [ "Agent", @@ -61,6 +69,13 @@ "ImageOutput", "FileOutput", "detect_multimodal_response", + # Media providers + "MediaProvider", + "FalProvider", + "LiteLLMProvider", + "OpenRouterProvider", + "get_provider", + "register_provider", ] __version__ = "0.1.31-rc.1" diff --git a/sdk/python/agentfield/agent.py b/sdk/python/agentfield/agent.py index dc6690c..2357bd1 100644 --- a/sdk/python/agentfield/agent.py +++ b/sdk/python/agentfield/agent.py @@ -2888,6 +2888,115 @@ async def ai_with_multimodal( # pragma: no cover - relies on external multimoda **kwargs, ) + async def ai_generate_image( # pragma: no cover - relies on external image services + self, + prompt: str, + model: Optional[str] = None, + size: str = "1024x1024", + quality: str = "standard", + style: Optional[str] = None, + response_format: str = "url", + **kwargs, + ) -> "MultimodalResponse": + """ + Generate an image from a text prompt. + + This is a dedicated method for image generation with a clearer name. + Returns a MultimodalResponse containing the generated image(s). + + Supported Providers: + - LiteLLM: DALL-E models like "dall-e-3", "dall-e-2" + - OpenRouter: Models like "openrouter/google/gemini-2.5-flash-image-preview" + + Args: + prompt (str): Text description of the image to generate. + model (str, optional): Model to use (defaults to AIConfig.vision_model). + size (str): Image dimensions (e.g., "1024x1024", "1792x1024"). + quality (str): Image quality ("standard" or "hd"). + style (str, optional): Image style for DALL-E 3 ("vivid" or "natural"). + response_format (str): Output format ("url" or "b64_json"). + **kwargs: Provider-specific parameters. + + Returns: + MultimodalResponse: Response with .images list containing ImageOutput objects. + + Example: + ```python + # Basic image generation + result = await app.ai_generate_image("A sunset over mountains") + if result.has_images: + result.images[0].save("sunset.png") + + # OpenRouter with Gemini + result = await app.ai_generate_image( + "A futuristic cityscape", + model="openrouter/google/gemini-2.5-flash-image-preview" + ) + ``` + """ + return await self.ai_handler.ai_generate_image( + prompt=prompt, + model=model, + size=size, + quality=quality, + style=style, + response_format=response_format, + **kwargs, + ) + + async def ai_generate_audio( # pragma: no cover - relies on external audio services + self, + text: str, + model: Optional[str] = None, + voice: str = "alloy", + format: str = "wav", + speed: float = 1.0, + **kwargs, + ) -> "MultimodalResponse": + """ + Generate audio/speech from text (Text-to-Speech). + + This is a dedicated method for audio generation with a clearer name. + Returns a MultimodalResponse containing the generated audio. + + Supported Providers: + - OpenAI TTS: Models like "tts-1", "tts-1-hd", "gpt-4o-mini-tts" + + Args: + text (str): Text to convert to speech. + model (str, optional): TTS model to use (defaults to AIConfig.audio_model). + voice (str): Voice to use ("alloy", "echo", "fable", "onyx", "nova", "shimmer"). + format (str): Audio format ("wav", "mp3", "opus", "aac", "flac", "pcm"). 
+ speed (float): Speech speed multiplier (0.25 to 4.0). + **kwargs: Provider-specific parameters. + + Returns: + MultimodalResponse: Response with .audio containing AudioOutput. + + Example: + ```python + # Basic speech generation + result = await app.ai_generate_audio("Hello, how are you today?") + if result.has_audio: + result.audio.save("greeting.wav") + + # High-quality TTS + result = await app.ai_generate_audio( + "Welcome to the presentation.", + model="tts-1-hd", + voice="nova" + ) + ``` + """ + return await self.ai_handler.ai_generate_audio( + text=text, + model=model, + voice=voice, + format=format, + speed=speed, + **kwargs, + ) + async def call(self, target: str, *args, **kwargs) -> dict: """ Initiates a cross-agent call to another reasoner or skill via the AgentField execution gateway. diff --git a/sdk/python/agentfield/agent_ai.py b/sdk/python/agentfield/agent_ai.py index 94e4cbf..181a37e 100644 --- a/sdk/python/agentfield/agent_ai.py +++ b/sdk/python/agentfield/agent_ai.py @@ -1,7 +1,12 @@ +from __future__ import annotations + import json import os import re -from typing import Any, Dict, List, Literal, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union + +if TYPE_CHECKING: + from agentfield.multimodal_response import MultimodalResponse import requests from agentfield.agent_utils import AgentUtils @@ -76,6 +81,23 @@ def __init__(self, agent_instance): self.agent = agent_instance self._initialization_complete = False self._rate_limiter = None + self._fal_provider_instance = None + + @property + def _fal_provider(self): + """ + Lazy-initialized Fal provider for image, audio, and video generation. + + Returns: + FalProvider: Configured Fal.ai provider instance + """ + if self._fal_provider_instance is None: + from agentfield.media_providers import FalProvider + + self._fal_provider_instance = FalProvider( + api_key=self.agent.ai_config.fal_api_key + ) + return self._fal_provider_instance def _get_rate_limiter(self) -> StatelessRateLimiter: """ @@ -819,6 +841,21 @@ async def ai_with_audio( self.agent.ai_config.audio_model ) # Use configured audio model (defaults to tts-1) + # Route based on model prefix - Fal TTS models + if model.startswith("fal-ai/") or model.startswith("fal/"): + # Combine all text inputs + text_input = " ".join(str(arg) for arg in args if isinstance(arg, str)) + if not text_input: + text_input = "Hello, this is a test audio message." + + return await self._fal_provider.generate_audio( + text=text_input, + model=model, + voice=voice, + format=format, + **kwargs, + ) + # Check if mode="openai_direct" is specified if mode == "openai_direct": # Use direct OpenAI client with streaming response @@ -1087,7 +1124,16 @@ async def ai_with_vision( model = "dall-e-3" # Default image model # Route based on model prefix - if model.startswith("openrouter/"): + if model.startswith("fal-ai/") or model.startswith("fal/"): + # Fal: Use FalProvider for Flux, SDXL, Recraft, etc. 
+ return await self._fal_provider.generate_image( + prompt=prompt, + model=model, + size=size, + quality=quality, + **kwargs, + ) + elif model.startswith("openrouter/"): # OpenRouter: Use chat completions API with image modality return await vision.generate_image_openrouter( prompt=prompt, @@ -1155,3 +1201,287 @@ async def ai_with_multimodal( final_kwargs = {**multimodal_params, **kwargs} return await self.ai(*args, model=model, **final_kwargs) + + async def ai_generate_image( + self, + prompt: str, + model: Optional[str] = None, + size: str = "1024x1024", + quality: str = "standard", + style: Optional[str] = None, + response_format: str = "url", + **kwargs, + ) -> "MultimodalResponse": + """ + Generate an image from a text prompt. + + This is a dedicated method for image generation with a clearer name + than ai_with_vision. Returns a MultimodalResponse containing the + generated image(s). + + Supported Providers: + - LiteLLM: DALL-E models like "dall-e-3", "dall-e-2" + - OpenRouter: Models like "openrouter/google/gemini-2.5-flash-image-preview" + - Fal.ai: Models like "fal-ai/flux/dev", "fal-ai/flux/schnell", "fal-ai/recraft-v3" + + Args: + prompt: Text description of the image to generate + model: Model to use (defaults to AIConfig.vision_model, typically "dall-e-3") + size: Image dimensions (e.g., "1024x1024", "1792x1024") or Fal presets + ("square_hd", "landscape_16_9", "portrait_4_3") + quality: Image quality ("standard" or "hd") + style: Image style for DALL-E 3 ("vivid" or "natural") + response_format: Output format ("url" or "b64_json") + **kwargs: Provider-specific parameters (e.g., image_config for OpenRouter) + + Returns: + MultimodalResponse: Response object with .images list containing ImageOutput objects. + - Use response.has_images to check if generation succeeded + - Use response.images[0].save("path.png") to save the image + - Use response.images[0].get_bytes() to get raw image bytes + + Examples: + # Basic image generation + result = await app.ai_generate_image("A sunset over mountains") + if result.has_images: + result.images[0].save("sunset.png") + + # OpenRouter with Gemini + result = await app.ai_generate_image( + "A futuristic cityscape at night", + model="openrouter/google/gemini-2.5-flash-image-preview", + image_config={"aspect_ratio": "16:9"} + ) + + # High quality DALL-E 3 + result = await app.ai_generate_image( + "A photorealistic portrait", + model="dall-e-3", + quality="hd", + style="natural" + ) + + # Fal.ai Flux (fast, high quality) + result = await app.ai_generate_image( + "A cyberpunk cityscape", + model="fal-ai/flux/dev", + size="landscape_16_9", + num_images=2 + ) + + # Fal.ai Flux Schnell (fastest) + result = await app.ai_generate_image( + "A serene Japanese garden", + model="fal-ai/flux/schnell", + size="square_hd" + ) + """ + # Use configured vision/image model as default + if model is None: + model = self.agent.ai_config.vision_model + + return await self.ai_with_vision( + prompt=prompt, + model=model, + size=size, + quality=quality, + style=style, + response_format=response_format, + **kwargs, + ) + + async def ai_generate_audio( + self, + text: str, + model: Optional[str] = None, + voice: str = "alloy", + format: str = "wav", + speed: float = 1.0, + **kwargs, + ) -> "MultimodalResponse": + """ + Generate audio/speech from text (Text-to-Speech). + + This is a dedicated method for audio generation with a clearer name + than ai_with_audio. Returns a MultimodalResponse containing the + generated audio. 
+ + Supported Providers: + - LiteLLM: OpenAI TTS models like "tts-1", "tts-1-hd", "gpt-4o-mini-tts" + - Fal.ai: TTS models like "fal-ai/kokoro/..." (custom deployments) + + Args: + text: Text to convert to speech + model: TTS model to use (defaults to AIConfig.audio_model, typically "tts-1") + voice: Voice to use ("alloy", "echo", "fable", "onyx", "nova", "shimmer") + format: Audio format ("wav", "mp3", "opus", "aac", "flac", "pcm") + speed: Speech speed multiplier (0.25 to 4.0) + **kwargs: Provider-specific parameters + + Returns: + MultimodalResponse: Response object with .audio containing AudioOutput. + - Use response.has_audio to check if generation succeeded + - Use response.audio.save("path.wav") to save the audio + - Use response.audio.get_bytes() to get raw audio bytes + - Use response.audio.play() to play the audio (requires pygame) + + Examples: + # Basic speech generation + result = await app.ai_generate_audio("Hello, how are you today?") + if result.has_audio: + result.audio.save("greeting.wav") + + # High-quality TTS with custom voice + result = await app.ai_generate_audio( + "Welcome to the presentation.", + model="tts-1-hd", + voice="nova", + format="mp3" + ) + + # Adjust speech speed + result = await app.ai_generate_audio( + "This is spoken slowly.", + speed=0.75 + ) + """ + # Use configured audio model as default + if model is None: + model = self.agent.ai_config.audio_model + + return await self.ai_with_audio( + text, + model=model, + voice=voice, + format=format, + speed=speed, + **kwargs, + ) + + async def ai_generate_video( + self, + prompt: str, + model: Optional[str] = None, + image_url: Optional[str] = None, + duration: Optional[float] = None, + **kwargs, + ) -> "MultimodalResponse": + """ + Generate video from text or image. + + This method generates videos using Fal.ai's video generation models. + Supports both text-to-video and image-to-video generation. + + Supported Providers: + - Fal.ai: Models like "fal-ai/minimax-video/image-to-video", + "fal-ai/kling-video/v1/standard", "fal-ai/luma-dream-machine" + + Args: + prompt: Text description for the video + model: Video model to use (defaults to AIConfig.video_model) + image_url: Optional input image URL for image-to-video models + duration: Video duration in seconds (model-dependent) + **kwargs: Provider-specific parameters + + Returns: + MultimodalResponse: Response with .files containing the video. + - Use response.files[0].save("video.mp4") to save + - Use response.files[0].url to get the video URL + + Examples: + # Image to video + result = await app.ai_generate_video( + "Camera slowly pans across the landscape", + model="fal-ai/minimax-video/image-to-video", + image_url="https://example.com/image.jpg" + ) + result.files[0].save("output.mp4") + + # Text to video + result = await app.ai_generate_video( + "A cat playing with yarn", + model="fal-ai/kling-video/v1/standard" + ) + + # Luma Dream Machine + result = await app.ai_generate_video( + "A dreamy underwater scene", + model="fal-ai/luma-dream-machine" + ) + """ + if model is None: + model = self.agent.ai_config.video_model + + # Currently only Fal supports video generation + if not (model.startswith("fal-ai/") or model.startswith("fal/")): + raise ValueError( + f"Video generation currently only supports Fal.ai models. " + f"Use models like 'fal-ai/minimax-video/image-to-video'. 
Got: {model}" + ) + + return await self._fal_provider.generate_video( + prompt=prompt, + model=model, + image_url=image_url, + duration=duration, + **kwargs, + ) + + async def ai_transcribe_audio( + self, + audio_url: str, + model: str = "fal-ai/whisper", + language: Optional[str] = None, + **kwargs, + ) -> "MultimodalResponse": + """ + Transcribe audio to text (Speech-to-Text). + + This method transcribes audio files to text using Fal.ai's Whisper models. + + Supported Providers: + - Fal.ai: Models like "fal-ai/whisper", "fal-ai/wizper" (2x faster) + + Args: + audio_url: URL to audio file to transcribe + model: STT model to use (defaults to "fal-ai/whisper") + language: Optional language hint (e.g., "en", "es", "fr") + **kwargs: Provider-specific parameters + + Returns: + MultimodalResponse: Response with .text containing the transcription. + - Use response.text to get the transcribed text + + Examples: + # Basic transcription + result = await app.ai_transcribe_audio( + "https://example.com/audio.mp3" + ) + print(result.text) + + # With language hint + result = await app.ai_transcribe_audio( + "https://example.com/spanish_audio.mp3", + model="fal-ai/whisper", + language="es" + ) + + # Fast transcription with Wizper + result = await app.ai_transcribe_audio( + "https://example.com/audio.mp3", + model="fal-ai/wizper" + ) + """ + # Currently only Fal supports transcription + if not (model.startswith("fal-ai/") or model.startswith("fal/")): + raise ValueError( + f"Audio transcription currently only supports Fal.ai models. " + f"Use 'fal-ai/whisper' or 'fal-ai/wizper'. Got: {model}" + ) + + return await self._fal_provider.transcribe_audio( + audio_url=audio_url, + model=model, + language=language, + **kwargs, + ) diff --git a/sdk/python/agentfield/media_providers.py b/sdk/python/agentfield/media_providers.py new file mode 100644 index 0000000..0167ab9 --- /dev/null +++ b/sdk/python/agentfield/media_providers.py @@ -0,0 +1,825 @@ +""" +Media Provider Abstraction for AgentField + +Provides a unified interface for different media generation backends: +- Fal.ai (Flux, SDXL, Whisper, TTS, Video models) +- OpenRouter (via LiteLLM) +- OpenAI DALL-E (via LiteLLM) +- Future: ElevenLabs, Replicate, etc. + +Each provider implements the same interface, making it easy to swap +backends or add new ones without changing agent code. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Literal, Optional, Union + +from agentfield.multimodal_response import ( + AudioOutput, + FileOutput, + ImageOutput, + MultimodalResponse, +) + + +# Fal image size presets +FalImageSize = Literal[ + "square_hd", # 1024x1024 + "square", # 512x512 + "portrait_4_3", # 768x1024 + "portrait_16_9", # 576x1024 + "landscape_4_3", # 1024x768 + "landscape_16_9", # 1024x576 +] + + +class MediaProvider(ABC): + """ + Abstract base class for media generation providers. + + Subclass this to add support for new image/audio generation backends. + """ + + @property + @abstractmethod + def name(self) -> str: + """Provider name for identification.""" + pass + + @property + @abstractmethod + def supported_modalities(self) -> List[str]: + """List of supported modalities: 'image', 'audio', 'video'.""" + pass + + @abstractmethod + async def generate_image( + self, + prompt: str, + model: Optional[str] = None, + size: str = "1024x1024", + quality: str = "standard", + **kwargs, + ) -> MultimodalResponse: + """ + Generate an image from a text prompt. 
+ + Args: + prompt: Text description of the image + model: Model to use (provider-specific) + size: Image dimensions or preset + quality: Quality level + **kwargs: Provider-specific options + + Returns: + MultimodalResponse with generated image(s) + """ + pass + + @abstractmethod + async def generate_audio( + self, + text: str, + model: Optional[str] = None, + voice: str = "alloy", + format: str = "wav", + **kwargs, + ) -> MultimodalResponse: + """ + Generate audio/speech from text. + + Args: + text: Text to convert to speech + model: TTS model to use + voice: Voice identifier + format: Audio format + **kwargs: Provider-specific options + + Returns: + MultimodalResponse with generated audio + """ + pass + + async def generate_video( + self, + prompt: str, + model: Optional[str] = None, + image_url: Optional[str] = None, + **kwargs, + ) -> MultimodalResponse: + """ + Generate video from text or image. + + Args: + prompt: Text description for video + model: Video model to use + image_url: Optional input image for image-to-video + **kwargs: Provider-specific options + + Returns: + MultimodalResponse with generated video + """ + raise NotImplementedError(f"{self.name} does not support video generation") + + +class FalProvider(MediaProvider): + """ + Fal.ai provider for image, audio, and video generation. + + Image Models: + - fal-ai/flux/dev - FLUX.1 [dev], 12B params, high quality (default) + - fal-ai/flux/schnell - FLUX.1 [schnell], fast 1-4 step generation + - fal-ai/flux-pro/v1.1-ultra - FLUX Pro Ultra, up to 2K resolution + - fal-ai/fast-sdxl - Fast SDXL + - fal-ai/recraft-v3 - SOTA text-to-image + - fal-ai/stable-diffusion-v35-large - SD 3.5 Large + + Video Models: + - fal-ai/minimax-video/image-to-video - Image to video + - fal-ai/luma-dream-machine - Luma Dream Machine + - fal-ai/kling-video/v1/standard - Kling 1.0 + + Audio Models: + - fal-ai/whisper - Speech to text + - Custom TTS deployments + + Requires FAL_KEY environment variable or explicit api_key. + + Example: + provider = FalProvider(api_key="...") + + # Generate image + result = await provider.generate_image( + "A sunset over mountains", + model="fal-ai/flux/dev", + image_size="landscape_16_9", + num_images=2 + ) + result.images[0].save("sunset.png") + + # Generate video from image + result = await provider.generate_video( + "Camera slowly pans across the scene", + model="fal-ai/minimax-video/image-to-video", + image_url="https://example.com/image.jpg" + ) + """ + + def __init__(self, api_key: Optional[str] = None): + """ + Initialize Fal provider. + + Args: + api_key: Fal.ai API key. If not provided, uses FAL_KEY env var. + """ + self._api_key = api_key + self._client = None + + @property + def name(self) -> str: + return "fal" + + @property + def supported_modalities(self) -> List[str]: + return ["image", "audio", "video"] + + def _get_client(self): + """Lazy initialization of fal client.""" + if self._client is None: + try: + import fal_client + + if self._api_key: + import os + os.environ["FAL_KEY"] = self._api_key + + self._client = fal_client + except ImportError: + raise ImportError( + "fal-client is not installed. Install it with: pip install fal-client" + ) + return self._client + + def _parse_image_size( + self, size: str + ) -> Union[str, Dict[str, int]]: + """ + Parse image size into fal format. 
+ + Args: + size: Either a preset like "landscape_16_9" or dimensions like "1024x768" + + Returns: + Fal-compatible image_size (string preset or dict with width/height) + """ + # Check if it's a fal preset + fal_presets = { + "square_hd", "square", "portrait_4_3", "portrait_16_9", + "landscape_4_3", "landscape_16_9" + } + if size in fal_presets: + return size + + # Parse WxH format + if "x" in size.lower(): + parts = size.lower().split("x") + try: + width, height = int(parts[0]), int(parts[1]) + return {"width": width, "height": height} + except ValueError: + pass + + # Default to square_hd + return "square_hd" + + async def generate_image( + self, + prompt: str, + model: Optional[str] = None, + size: str = "square_hd", + quality: str = "standard", + num_images: int = 1, + seed: Optional[int] = None, + guidance_scale: Optional[float] = None, + num_inference_steps: Optional[int] = None, + **kwargs, + ) -> MultimodalResponse: + """ + Generate image using Fal.ai. + + Args: + prompt: Text prompt for image generation + model: Fal model ID (defaults to "fal-ai/flux/dev") + size: Image size - preset ("square_hd", "landscape_16_9") or "WxH" + quality: "standard" (25 steps) or "hd" (50 steps) + num_images: Number of images to generate (1-4) + seed: Random seed for reproducibility + guidance_scale: Guidance scale for generation + num_inference_steps: Override inference steps + **kwargs: Additional fal-specific parameters + + Returns: + MultimodalResponse with generated images + + Example: + result = await provider.generate_image( + "A cyberpunk cityscape at night", + model="fal-ai/flux/dev", + size="landscape_16_9", + num_images=2, + seed=42 + ) + """ + client = self._get_client() + + # Default model + if model is None: + model = "fal-ai/flux/dev" + + # Parse image size + image_size = self._parse_image_size(size) + + # Determine inference steps based on quality + if num_inference_steps is None: + num_inference_steps = 25 if quality == "standard" else 50 + + # Build request arguments + fal_args: Dict[str, Any] = { + "prompt": prompt, + "image_size": image_size, + "num_images": num_images, + "num_inference_steps": num_inference_steps, + } + + # Add optional parameters + if seed is not None: + fal_args["seed"] = seed + if guidance_scale is not None: + fal_args["guidance_scale"] = guidance_scale + + # Merge any additional kwargs + fal_args.update(kwargs) + + try: + # Use subscribe_async for queue-based reliable execution + result = await client.subscribe_async( + model, + arguments=fal_args, + with_logs=False, + ) + + # Extract images from result + images = [] + if "images" in result: + for img_data in result["images"]: + url = img_data.get("url") + # width, height, content_type available but not used currently + # _width = img_data.get("width") + # _height = img_data.get("height") + # _content_type = img_data.get("content_type", "image/png") + + if url: + images.append( + ImageOutput( + url=url, + b64_json=None, + revised_prompt=prompt, + ) + ) + + # Also check for single image response + if "image" in result and not images: + img_data = result["image"] + url = img_data.get("url") if isinstance(img_data, dict) else img_data + if url: + images.append( + ImageOutput(url=url, b64_json=None, revised_prompt=prompt) + ) + + return MultimodalResponse( + text=prompt, + audio=None, + images=images, + files=[], + raw_response=result, + ) + + except Exception as e: + from agentfield.logger import log_error + log_error(f"Fal image generation failed: {e}") + raise + + async def generate_audio( + self, + text: str, 
+ model: Optional[str] = None, + voice: Optional[str] = None, + format: str = "wav", + ref_audio_url: Optional[str] = None, + speed: float = 1.0, + **kwargs, + ) -> MultimodalResponse: + """ + Generate audio using Fal.ai TTS models. + + For voice cloning, provide a ref_audio_url with a sample of the voice. + + Args: + text: Text to convert to speech + model: Fal TTS model (provider-specific) + voice: Voice identifier or preset + format: Audio format (wav, mp3) + ref_audio_url: URL to reference audio for voice cloning + speed: Speech speed multiplier + **kwargs: Additional fal-specific parameters (gen_text, ref_text, etc.) + + Returns: + MultimodalResponse with generated audio + + Note: + Fal has various TTS models with different APIs. Check the specific + model documentation for available parameters. + """ + client = self._get_client() + + # Build request arguments based on model + fal_args: Dict[str, Any] = {} + + # Common patterns for fal TTS models + if "gen_text" not in kwargs: + fal_args["gen_text"] = text + if ref_audio_url: + fal_args["ref_audio_url"] = ref_audio_url + if voice and voice.startswith("http"): + fal_args["ref_audio_url"] = voice + + # Merge additional kwargs + fal_args.update(kwargs) + + try: + result = await client.subscribe_async( + model, + arguments=fal_args, + with_logs=False, + ) + + # Extract audio from result - fal returns audio in various formats + audio = None + audio_url = None + + # Check common response patterns + if "audio_url" in result: + audio_url = result["audio_url"] + elif "audio" in result: + audio_data = result["audio"] + if isinstance(audio_data, dict): + audio_url = audio_data.get("url") + elif isinstance(audio_data, str): + audio_url = audio_data + + if audio_url: + audio = AudioOutput( + url=audio_url, + data=None, + format=format, + ) + + return MultimodalResponse( + text=text, + audio=audio, + images=[], + files=[], + raw_response=result, + ) + + except Exception as e: + from agentfield.logger import log_error + log_error(f"Fal audio generation failed: {e}") + raise + + async def generate_video( + self, + prompt: str, + model: Optional[str] = None, + image_url: Optional[str] = None, + duration: Optional[float] = None, + **kwargs, + ) -> MultimodalResponse: + """ + Generate video using Fal.ai video models. 
+ + Args: + prompt: Text description for the video + model: Fal video model (defaults to "fal-ai/minimax-video/image-to-video") + image_url: Input image URL for image-to-video models + duration: Video duration in seconds (model-dependent) + **kwargs: Additional fal-specific parameters + + Returns: + MultimodalResponse with video in files list + + Example: + # Image to video + result = await provider.generate_video( + "Camera slowly pans across the mountain landscape", + model="fal-ai/minimax-video/image-to-video", + image_url="https://example.com/mountain.jpg" + ) + + # Text to video + result = await provider.generate_video( + "A cat playing with yarn", + model="fal-ai/kling-video/v1/standard" + ) + """ + client = self._get_client() + + # Default model + if model is None: + model = "fal-ai/minimax-video/image-to-video" + + # Build request arguments + fal_args: Dict[str, Any] = { + "prompt": prompt, + } + + if image_url: + fal_args["image_url"] = image_url + if duration: + fal_args["duration"] = duration + + # Merge additional kwargs + fal_args.update(kwargs) + + try: + result = await client.subscribe_async( + model, + arguments=fal_args, + with_logs=False, + ) + + # Extract video from result + files = [] + video_url = None + + # Check common response patterns + if "video_url" in result: + video_url = result["video_url"] + elif "video" in result: + video_data = result["video"] + if isinstance(video_data, dict): + video_url = video_data.get("url") + elif isinstance(video_data, str): + video_url = video_data + + if video_url: + files.append( + FileOutput( + url=video_url, + data=None, + mime_type="video/mp4", + filename="generated_video.mp4", + ) + ) + + return MultimodalResponse( + text=prompt, + audio=None, + images=[], + files=files, + raw_response=result, + ) + + except Exception as e: + from agentfield.logger import log_error + log_error(f"Fal video generation failed: {e}") + raise + + async def transcribe_audio( + self, + audio_url: str, + model: str = "fal-ai/whisper", + language: Optional[str] = None, + **kwargs, + ) -> MultimodalResponse: + """ + Transcribe audio to text using Fal's Whisper model. + + Args: + audio_url: URL to audio file to transcribe + model: Whisper model (defaults to "fal-ai/whisper") + language: Optional language hint + **kwargs: Additional parameters + + Returns: + MultimodalResponse with transcribed text + """ + client = self._get_client() + + fal_args: Dict[str, Any] = { + "audio_url": audio_url, + } + if language: + fal_args["language"] = language + fal_args.update(kwargs) + + try: + result = await client.subscribe_async( + model, + arguments=fal_args, + with_logs=False, + ) + + # Extract text from result + text = "" + if "text" in result: + text = result["text"] + elif "transcription" in result: + text = result["transcription"] + + return MultimodalResponse( + text=text, + audio=None, + images=[], + files=[], + raw_response=result, + ) + + except Exception as e: + from agentfield.logger import log_error + log_error(f"Fal transcription failed: {e}") + raise + + +class LiteLLMProvider(MediaProvider): + """ + LiteLLM-based provider for OpenAI, Azure, and other LiteLLM-supported backends. + + Uses LiteLLM's image_generation and speech APIs. 
+ + Image Models: + - dall-e-3 - OpenAI DALL-E 3 + - dall-e-2 - OpenAI DALL-E 2 + - azure/dall-e-3 - Azure DALL-E + + Audio Models: + - tts-1 - OpenAI TTS + - tts-1-hd - OpenAI TTS HD + - gpt-4o-mini-tts - GPT-4o Mini TTS + """ + + def __init__(self, api_key: Optional[str] = None): + self._api_key = api_key + + @property + def name(self) -> str: + return "litellm" + + @property + def supported_modalities(self) -> List[str]: + return ["image", "audio"] + + async def generate_image( + self, + prompt: str, + model: Optional[str] = None, + size: str = "1024x1024", + quality: str = "standard", + style: Optional[str] = None, + response_format: str = "url", + **kwargs, + ) -> MultimodalResponse: + """Generate image using LiteLLM (DALL-E, Azure DALL-E, etc.).""" + from agentfield import vision + + model = model or "dall-e-3" + + return await vision.generate_image_litellm( + prompt=prompt, + model=model, + size=size, + quality=quality, + style=style, + response_format=response_format, + **kwargs, + ) + + async def generate_audio( + self, + text: str, + model: Optional[str] = None, + voice: str = "alloy", + format: str = "wav", + speed: float = 1.0, + **kwargs, + ) -> MultimodalResponse: + """Generate audio using LiteLLM TTS.""" + try: + import litellm + + litellm.suppress_debug_info = True + except ImportError: + raise ImportError( + "litellm is not installed. Install it with: pip install litellm" + ) + + model = model or "tts-1" + + try: + response = await litellm.aspeech( + model=model, + input=text, + voice=voice, + speed=speed, + **kwargs, + ) + + # Extract audio data + audio_data = None + if hasattr(response, "content"): + import base64 + + audio_data = base64.b64encode(response.content).decode("utf-8") + + audio = AudioOutput( + data=audio_data, + format=format, + url=None, + ) + + return MultimodalResponse( + text=text, + audio=audio, + images=[], + files=[], + raw_response=response, + ) + + except Exception as e: + from agentfield.logger import log_error + + log_error(f"LiteLLM audio generation failed: {e}") + raise + + +class OpenRouterProvider(MediaProvider): + """ + OpenRouter provider for image generation via chat completions. + + Uses the modalities parameter with chat completions API for image generation. 
+ + Supports models like: + - google/gemini-2.5-flash-image-preview + - Other OpenRouter models with image generation capabilities + """ + + def __init__(self, api_key: Optional[str] = None): + self._api_key = api_key + + @property + def name(self) -> str: + return "openrouter" + + @property + def supported_modalities(self) -> List[str]: + return ["image"] # OpenRouter primarily supports image generation + + async def generate_image( + self, + prompt: str, + model: Optional[str] = None, + size: str = "1024x1024", + quality: str = "standard", + **kwargs, + ) -> MultimodalResponse: + """Generate image using OpenRouter's chat completions API.""" + from agentfield import vision + + model = model or "openrouter/google/gemini-2.5-flash-image-preview" + + # Ensure model has openrouter prefix + if not model.startswith("openrouter/"): + model = f"openrouter/{model}" + + return await vision.generate_image_openrouter( + prompt=prompt, + model=model, + size=size, + quality=quality, + style=None, + response_format="url", + **kwargs, + ) + + async def generate_audio( + self, + text: str, + model: Optional[str] = None, + voice: str = "alloy", + format: str = "wav", + **kwargs, + ) -> MultimodalResponse: + """OpenRouter doesn't support TTS directly.""" + raise NotImplementedError( + "OpenRouter doesn't support audio generation. Use LiteLLMProvider or FalProvider." + ) + + +# Provider registry for easy access +_PROVIDERS: Dict[str, type] = { + "fal": FalProvider, + "litellm": LiteLLMProvider, + "openrouter": OpenRouterProvider, +} + + +def get_provider(name: str, **kwargs) -> MediaProvider: + """ + Get a media provider instance by name. + + Args: + name: Provider name ('fal', 'litellm', 'openrouter') + **kwargs: Provider-specific initialization arguments + + Returns: + MediaProvider instance + + Example: + # Fal provider for Flux + provider = get_provider("fal", api_key="...") + result = await provider.generate_image( + "A sunset over mountains", + model="fal-ai/flux/dev" + ) + + # LiteLLM provider for DALL-E + provider = get_provider("litellm") + result = await provider.generate_image( + "A sunset over mountains", + model="dall-e-3" + ) + """ + if name not in _PROVIDERS: + raise ValueError( + f"Unknown provider: {name}. Available: {list(_PROVIDERS.keys())}" + ) + return _PROVIDERS[name](**kwargs) + + +def register_provider(name: str, provider_class: type): + """ + Register a custom media provider. + + Args: + name: Provider name for lookup + provider_class: MediaProvider subclass + + Example: + class ReplicateProvider(MediaProvider): + ... + + register_provider("replicate", ReplicateProvider) + """ + if not issubclass(provider_class, MediaProvider): + raise TypeError("provider_class must be a MediaProvider subclass") + _PROVIDERS[name] = provider_class diff --git a/sdk/python/agentfield/multimodal_response.py b/sdk/python/agentfield/multimodal_response.py index 26ddc6d..c1e4ae3 100644 --- a/sdk/python/agentfield/multimodal_response.py +++ b/sdk/python/agentfield/multimodal_response.py @@ -323,6 +323,112 @@ def save_all( return saved_files +def _extract_image_from_data(data: Any) -> Optional[ImageOutput]: + """ + Extract an ImageOutput from various data structures. + Handles multiple formats: OpenRouter, OpenAI, and generic patterns. 
+ """ + if data is None: + return None + + # Direct url/b64_json attributes (standard image generation) + if hasattr(data, "url") or hasattr(data, "b64_json"): + url = getattr(data, "url", None) + b64 = getattr(data, "b64_json", None) + if url or b64: + return ImageOutput( + url=url, + b64_json=b64, + revised_prompt=getattr(data, "revised_prompt", None), + ) + + # OpenRouter/Gemini pattern: {"type": "image_url", "image_url": {"url": "..."}} + if hasattr(data, "image_url"): + image_url_obj = data.image_url + url = getattr(image_url_obj, "url", None) if hasattr(image_url_obj, "url") else None + if url: + # Handle data URLs (base64 encoded) + if url.startswith("data:image"): + # Extract base64 from data URL + try: + b64_data = url.split(",", 1)[1] if "," in url else None + return ImageOutput(url=url, b64_json=b64_data, revised_prompt=None) + except Exception: + return ImageOutput(url=url, b64_json=None, revised_prompt=None) + return ImageOutput(url=url, b64_json=None, revised_prompt=None) + + # Dict-based patterns + if isinstance(data, dict): + # Direct url/b64_json keys + if "url" in data or "b64_json" in data: + url = data.get("url") + b64 = data.get("b64_json") + if url or b64: + return ImageOutput( + url=url, + b64_json=b64, + revised_prompt=data.get("revised_prompt"), + ) + + # OpenRouter dict pattern: {"image_url": {"url": "..."}} + if "image_url" in data: + image_url_data = data["image_url"] + if isinstance(image_url_data, dict): + url = image_url_data.get("url") + if url: + # Handle data URLs + if url.startswith("data:image"): + try: + b64_data = url.split(",", 1)[1] if "," in url else None + return ImageOutput(url=url, b64_json=b64_data, revised_prompt=None) + except Exception: + return ImageOutput(url=url, b64_json=None, revised_prompt=None) + return ImageOutput(url=url, b64_json=None, revised_prompt=None) + + return None + + +def _find_images_recursive(obj: Any, max_depth: int = 10) -> List[ImageOutput]: + """ + Recursively search any structure for image data. + This is a generalized fallback that handles unknown response formats. + """ + if max_depth <= 0: + return [] + + images = [] + + # Try direct extraction first + img = _extract_image_from_data(obj) + if img: + images.append(img) + return images # Found at this level, don't recurse deeper + + # Handle lists/tuples + if isinstance(obj, (list, tuple)): + for item in obj: + images.extend(_find_images_recursive(item, max_depth - 1)) + + # Handle dicts + elif isinstance(obj, dict): + for value in obj.values(): + images.extend(_find_images_recursive(value, max_depth - 1)) + + # Handle objects with attributes + elif hasattr(obj, "__dict__"): + for attr_name in dir(obj): + if attr_name.startswith("_"): + continue + try: + attr_val = getattr(obj, attr_name, None) + if attr_val is not None and not callable(attr_val): + images.extend(_find_images_recursive(attr_val, max_depth - 1)) + except Exception: + continue + + return images + + def detect_multimodal_response(response: Any) -> MultimodalResponse: """ Automatically detect and wrap multimodal content from LiteLLM responses. 
@@ -338,7 +444,7 @@ def detect_multimodal_response(response: Any) -> MultimodalResponse: images = [] files = [] - # Handle completion responses (text + potential audio) + # Handle completion responses (text + potential audio + potential images) if hasattr(response, "choices") and response.choices: choice = response.choices[0] message = choice.message @@ -357,6 +463,13 @@ def detect_multimodal_response(response: Any) -> MultimodalResponse: url=None, ) + # Extract images from completion responses (OpenRouter/Gemini pattern) + if hasattr(message, "images") and message.images: + for img_data in message.images: + img = _extract_image_from_data(img_data) + if img: + images.append(img) + # Handle image generation responses elif hasattr(response, "data") and response.data: # This is likely an image generation response @@ -398,6 +511,11 @@ def detect_multimodal_response(response: Any) -> MultimodalResponse: else: text = str(response) + # Fallback: if no images found yet, try recursive search + # This catches edge cases where images are in unexpected locations + if not images: + images = _find_images_recursive(response, max_depth=5) + return MultimodalResponse( text=text, audio=audio, images=images, files=files, raw_response=response ) diff --git a/sdk/python/agentfield/types.py b/sdk/python/agentfield/types.py index e84a6eb..385de85 100644 --- a/sdk/python/agentfield/types.py +++ b/sdk/python/agentfield/types.py @@ -1,6 +1,6 @@ from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, computed_field from enum import Enum @@ -338,10 +338,27 @@ class AIConfig(BaseModel): image_quality: Literal["low", "high"] = Field( default="high", description="Quality for image generation/processing." ) + audio_format: str = Field( default="wav", description="Default format for audio output (wav, mp3)." ) + # Fal.ai settings + fal_api_key: Optional[str] = Field( + default=None, + description="Fal.ai API key. If not set, uses FAL_KEY environment variable.", + ) + video_model: str = Field( + default="fal-ai/minimax-video/image-to-video", + description="Default model for video generation.", + ) + + @computed_field + @property + def image_model(self) -> str: + """Alias for vision_model - clearer name for image generation model.""" + return self.vision_model + # Behavior settings timeout: Optional[int] = Field( default=None, @@ -490,6 +507,7 @@ async def get_model_limits(self, model: Optional[str] = None) -> Dict[str, Any]: try: import litellm + litellm.suppress_debug_info = True # Fetch model info once and cache it info = litellm.get_model_info(target_model) diff --git a/sdk/python/tests/test_media_providers.py b/sdk/python/tests/test_media_providers.py new file mode 100644 index 0000000..8d96b6f --- /dev/null +++ b/sdk/python/tests/test_media_providers.py @@ -0,0 +1,593 @@ +""" +Tests for Media Providers and Unified Multimodal UX. 
+ +This module tests: +- FalProvider, LiteLLMProvider, OpenRouterProvider +- Provider routing in AgentAI (fal-ai/, openrouter/, default) +- New methods: ai_generate_video, ai_transcribe_audio +- AIConfig extensions: fal_api_key, video_model +""" + +import copy +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentfield.agent_ai import AgentAI +from agentfield.media_providers import ( + FalProvider, + LiteLLMProvider, + OpenRouterProvider, + MediaProvider, + get_provider, + register_provider, +) +from agentfield.multimodal_response import ( + AudioOutput, + ImageOutput, + FileOutput, + MultimodalResponse, +) +from agentfield.types import AIConfig + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + + +class DummyAIConfig: + """Dummy AIConfig for testing.""" + + def __init__(self): + self.model = "openai/gpt-4" + self.temperature = 0.1 + self.max_tokens = 100 + self.top_p = 1.0 + self.stream = False + self.response_format = "auto" + self.fallback_models = [] + self.final_fallback_model = None + self.enable_rate_limit_retry = True + self.rate_limit_max_retries = 2 + self.rate_limit_base_delay = 0.1 + self.rate_limit_max_delay = 1.0 + self.rate_limit_jitter_factor = 0.1 + self.rate_limit_circuit_breaker_threshold = 3 + self.rate_limit_circuit_breaker_timeout = 1 + self.auto_inject_memory = [] + self.model_limits_cache = {} + self.audio_model = "tts-1" + self.vision_model = "dall-e-3" + # New fields for Fal support + self.fal_api_key = None + self.video_model = "fal-ai/minimax-video/image-to-video" + + def copy(self, deep=False): + return copy.deepcopy(self) + + async def get_model_limits(self, model=None): + return {"context_length": 1000, "max_output_tokens": 100} + + def get_litellm_params(self, **overrides): + params = { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "stream": self.stream, + } + params.update(overrides) + return params + + +class StubAgent: + """Stub Agent for testing.""" + + def __init__(self): + self.node_id = "test-agent" + self.ai_config = DummyAIConfig() + self.memory = SimpleNamespace() + + +@pytest.fixture +def agent_with_ai(): + """Create a stub agent with AI config.""" + agent = StubAgent() + return agent + + +@pytest.fixture +def fal_provider(): + """Create a FalProvider instance.""" + return FalProvider(api_key="test-fal-key") + + +@pytest.fixture +def litellm_provider(): + """Create a LiteLLMProvider instance.""" + return LiteLLMProvider() + + +@pytest.fixture +def openrouter_provider(): + """Create an OpenRouterProvider instance.""" + return OpenRouterProvider() + + +# ============================================================================= +# AIConfig Tests - New Fields +# ============================================================================= + + +class TestAIConfigFalFields: + """Test new Fal-related fields in AIConfig.""" + + def test_aiconfig_has_fal_api_key(self): + """AIConfig should have fal_api_key field.""" + config = AIConfig() + assert hasattr(config, "fal_api_key") + assert config.fal_api_key is None # Default is None + + def test_aiconfig_has_video_model(self): + """AIConfig should have video_model field.""" + config = AIConfig() + assert hasattr(config, "video_model") + assert config.video_model == "fal-ai/minimax-video/image-to-video" + + def 
test_aiconfig_fal_api_key_can_be_set(self): + """fal_api_key should be settable.""" + config = AIConfig(fal_api_key="my-fal-key") + assert config.fal_api_key == "my-fal-key" + + def test_aiconfig_video_model_can_be_overridden(self): + """video_model should be overridable.""" + config = AIConfig(video_model="fal-ai/kling-video/v1/standard") + assert config.video_model == "fal-ai/kling-video/v1/standard" + + +# ============================================================================= +# FalProvider Tests +# ============================================================================= + + +class TestFalProvider: + """Tests for FalProvider.""" + + def test_fal_provider_name(self, fal_provider): + """FalProvider should have correct name.""" + assert fal_provider.name == "fal" + + def test_fal_provider_supported_modalities(self, fal_provider): + """FalProvider should support image, audio, and video.""" + assert "image" in fal_provider.supported_modalities + assert "audio" in fal_provider.supported_modalities + assert "video" in fal_provider.supported_modalities + + def test_fal_provider_parse_image_size_preset(self, fal_provider): + """FalProvider should parse Fal presets correctly.""" + assert fal_provider._parse_image_size("square_hd") == "square_hd" + assert fal_provider._parse_image_size("landscape_16_9") == "landscape_16_9" + assert fal_provider._parse_image_size("portrait_4_3") == "portrait_4_3" + + def test_fal_provider_parse_image_size_dimensions(self, fal_provider): + """FalProvider should parse WxH dimensions correctly.""" + result = fal_provider._parse_image_size("1024x768") + assert result == {"width": 1024, "height": 768} + + result = fal_provider._parse_image_size("512x512") + assert result == {"width": 512, "height": 512} + + def test_fal_provider_parse_image_size_invalid_fallback(self, fal_provider): + """FalProvider should fallback to square_hd for invalid sizes.""" + assert fal_provider._parse_image_size("invalid") == "square_hd" + + @pytest.mark.asyncio + async def test_fal_provider_generate_image(self, fal_provider, monkeypatch): + """FalProvider.generate_image should call fal_client correctly.""" + mock_result = { + "images": [ + {"url": "https://fal.media/test.png", "width": 1024, "height": 1024} + ] + } + + mock_client = MagicMock() + mock_client.subscribe_async = AsyncMock(return_value=mock_result) + monkeypatch.setattr(fal_provider, "_client", mock_client) + + result = await fal_provider.generate_image( + prompt="A sunset", + model="fal-ai/flux/dev", + size="square_hd", + ) + + assert result.has_images + assert len(result.images) == 1 + assert result.images[0].url == "https://fal.media/test.png" + mock_client.subscribe_async.assert_called_once() + + @pytest.mark.asyncio + async def test_fal_provider_generate_video(self, fal_provider, monkeypatch): + """FalProvider.generate_video should call fal_client correctly.""" + mock_result = {"video_url": "https://fal.media/video.mp4"} + + mock_client = MagicMock() + mock_client.subscribe_async = AsyncMock(return_value=mock_result) + monkeypatch.setattr(fal_provider, "_client", mock_client) + + result = await fal_provider.generate_video( + prompt="Camera pans", + model="fal-ai/minimax-video/image-to-video", + image_url="https://example.com/image.jpg", + ) + + assert len(result.files) == 1 + assert result.files[0].url == "https://fal.media/video.mp4" + + @pytest.mark.asyncio + async def test_fal_provider_transcribe_audio(self, fal_provider, monkeypatch): + """FalProvider.transcribe_audio should return transcription.""" + 
mock_result = {"text": "Hello world, this is a test."} + + mock_client = MagicMock() + mock_client.subscribe_async = AsyncMock(return_value=mock_result) + monkeypatch.setattr(fal_provider, "_client", mock_client) + + result = await fal_provider.transcribe_audio( + audio_url="https://example.com/audio.mp3", + model="fal-ai/whisper", + ) + + assert result.text == "Hello world, this is a test." + + +# ============================================================================= +# Provider Registry Tests +# ============================================================================= + + +class TestProviderRegistry: + """Tests for provider registry functions.""" + + def test_get_provider_fal(self): + """get_provider should return FalProvider for 'fal'.""" + provider = get_provider("fal") + assert isinstance(provider, FalProvider) + + def test_get_provider_litellm(self): + """get_provider should return LiteLLMProvider for 'litellm'.""" + provider = get_provider("litellm") + assert isinstance(provider, LiteLLMProvider) + + def test_get_provider_openrouter(self): + """get_provider should return OpenRouterProvider for 'openrouter'.""" + provider = get_provider("openrouter") + assert isinstance(provider, OpenRouterProvider) + + def test_get_provider_unknown_raises(self): + """get_provider should raise for unknown provider.""" + with pytest.raises(ValueError, match="Unknown provider"): + get_provider("unknown_provider") + + def test_register_custom_provider(self): + """register_provider should add custom providers.""" + + class CustomProvider(MediaProvider): + @property + def name(self): + return "custom" + + @property + def supported_modalities(self): + return ["image"] + + async def generate_image(self, prompt, **kwargs): + pass + + async def generate_audio(self, text, **kwargs): + pass + + register_provider("custom", CustomProvider) + provider = get_provider("custom") + assert isinstance(provider, CustomProvider) + + +# ============================================================================= +# AgentAI Provider Routing Tests +# ============================================================================= + + +class TestAgentAIProviderRouting: + """Tests for provider routing in AgentAI.""" + + def test_fal_provider_lazy_initialization(self, agent_with_ai): + """_fal_provider should be lazily initialized.""" + ai = AgentAI(agent_with_ai) + assert ai._fal_provider_instance is None + + # Access the property to trigger initialization + with patch("agentfield.media_providers.FalProvider") as mock_fal: + mock_fal.return_value = MagicMock() + provider = ai._fal_provider + assert provider is not None + mock_fal.assert_called_once_with(api_key=None) + + def test_fal_provider_cached(self, agent_with_ai): + """_fal_provider should be cached after first access.""" + ai = AgentAI(agent_with_ai) + + with patch("agentfield.media_providers.FalProvider") as mock_fal: + mock_provider = MagicMock() + mock_fal.return_value = mock_provider + + provider1 = ai._fal_provider + provider2 = ai._fal_provider + + # Should only be created once + assert mock_fal.call_count == 1 + assert provider1 is provider2 + + @pytest.mark.asyncio + async def test_ai_with_vision_routes_fal_ai_prefix(self, agent_with_ai, monkeypatch): + """ai_with_vision should route fal-ai/ models to FalProvider.""" + ai = AgentAI(agent_with_ai) + + mock_response = MultimodalResponse( + text="test", + audio=None, + images=[ImageOutput(url="https://fal.media/test.png")], + files=[], + ) + mock_generate = AsyncMock(return_value=mock_response) + + # 
Patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.generate_image = mock_generate + ai._fal_provider_instance = mock_provider + + result = await ai.ai_with_vision( + prompt="A sunset", + model="fal-ai/flux/dev", + ) + + mock_generate.assert_called_once() + assert result.has_images + + @pytest.mark.asyncio + async def test_ai_with_vision_routes_fal_prefix(self, agent_with_ai, monkeypatch): + """ai_with_vision should route fal/ models to FalProvider.""" + ai = AgentAI(agent_with_ai) + + mock_response = MultimodalResponse( + text="test", + audio=None, + images=[ImageOutput(url="https://fal.media/test.png")], + files=[], + ) + mock_generate = AsyncMock(return_value=mock_response) + + # Patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.generate_image = mock_generate + ai._fal_provider_instance = mock_provider + + await ai.ai_with_vision( + prompt="A sunset", + model="fal/flux-dev", + ) + + mock_generate.assert_called_once() + + @pytest.mark.asyncio + async def test_ai_with_audio_routes_fal_models(self, agent_with_ai): + """ai_with_audio should route fal-ai/ models to FalProvider.""" + ai = AgentAI(agent_with_ai) + + mock_response = MultimodalResponse( + text="Hello", + audio=AudioOutput(url="https://fal.media/audio.wav", data=None, format="wav"), + images=[], + files=[], + ) + mock_generate = AsyncMock(return_value=mock_response) + + # Patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.generate_audio = mock_generate + ai._fal_provider_instance = mock_provider + + result = await ai.ai_with_audio( + "Hello world", + model="fal-ai/kokoro-tts", + ) + + mock_generate.assert_called_once() + assert result.has_audio + + +# ============================================================================= +# New Methods: ai_generate_video, ai_transcribe_audio +# ============================================================================= + + +class TestAIGenerateVideo: + """Tests for ai_generate_video method.""" + + @pytest.mark.asyncio + async def test_ai_generate_video_uses_default_model(self, agent_with_ai): + """ai_generate_video should use AIConfig.video_model as default.""" + ai = AgentAI(agent_with_ai) + + mock_response = MultimodalResponse( + text="", + audio=None, + images=[], + files=[FileOutput(url="https://fal.media/video.mp4", data=None, mime_type="video/mp4")], + ) + mock_generate = AsyncMock(return_value=mock_response) + + # Patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.generate_video = mock_generate + ai._fal_provider_instance = mock_provider + + await ai.ai_generate_video(prompt="A cat playing") + + # Should use default video_model + mock_generate.assert_called_once() + call_kwargs = mock_generate.call_args[1] + assert call_kwargs["model"] == "fal-ai/minimax-video/image-to-video" + + @pytest.mark.asyncio + async def test_ai_generate_video_with_image_url(self, agent_with_ai): + """ai_generate_video should pass image_url for image-to-video.""" + ai = AgentAI(agent_with_ai) + + mock_response = MultimodalResponse( + text="", + audio=None, + images=[], + files=[FileOutput(url="https://fal.media/video.mp4", data=None, mime_type="video/mp4")], + ) + mock_generate = AsyncMock(return_value=mock_response) + + # Patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.generate_video = mock_generate + ai._fal_provider_instance = mock_provider + + await ai.ai_generate_video( + prompt="Camera pans", + 
model="fal-ai/minimax-video/image-to-video", + image_url="https://example.com/image.jpg", + ) + + call_kwargs = mock_generate.call_args[1] + assert call_kwargs["image_url"] == "https://example.com/image.jpg" + + @pytest.mark.asyncio + async def test_ai_generate_video_rejects_non_fal_models(self, agent_with_ai): + """ai_generate_video should reject non-Fal models.""" + ai = AgentAI(agent_with_ai) + + with pytest.raises(ValueError, match="only supports Fal.ai models"): + await ai.ai_generate_video( + prompt="A cat", + model="openai/video-model", + ) + + +class TestAITranscribeAudio: + """Tests for ai_transcribe_audio method.""" + + @pytest.mark.asyncio + async def test_ai_transcribe_audio_default_model(self, agent_with_ai): + """ai_transcribe_audio should default to fal-ai/whisper.""" + ai = AgentAI(agent_with_ai) + + mock_response = MultimodalResponse( + text="Hello world", + audio=None, + images=[], + files=[], + ) + mock_transcribe = AsyncMock(return_value=mock_response) + + # Patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.transcribe_audio = mock_transcribe + ai._fal_provider_instance = mock_provider + + result = await ai.ai_transcribe_audio( + audio_url="https://example.com/audio.mp3" + ) + + call_kwargs = mock_transcribe.call_args[1] + assert call_kwargs["model"] == "fal-ai/whisper" + assert result.text == "Hello world" + + @pytest.mark.asyncio + async def test_ai_transcribe_audio_with_language(self, agent_with_ai): + """ai_transcribe_audio should pass language hint.""" + ai = AgentAI(agent_with_ai) + + mock_response = MultimodalResponse(text="Hola mundo", audio=None, images=[], files=[]) + mock_transcribe = AsyncMock(return_value=mock_response) + + # Patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.transcribe_audio = mock_transcribe + ai._fal_provider_instance = mock_provider + + await ai.ai_transcribe_audio( + audio_url="https://example.com/spanish.mp3", + model="fal-ai/whisper", + language="es", + ) + + call_kwargs = mock_transcribe.call_args[1] + assert call_kwargs["language"] == "es" + + @pytest.mark.asyncio + async def test_ai_transcribe_audio_rejects_non_fal_models(self, agent_with_ai): + """ai_transcribe_audio should reject non-Fal models.""" + ai = AgentAI(agent_with_ai) + + with pytest.raises(ValueError, match="only supports Fal.ai models"): + await ai.ai_transcribe_audio( + audio_url="https://example.com/audio.mp3", + model="openai/whisper", + ) + + +# ============================================================================= +# Integration-style Tests +# ============================================================================= + + +class TestUnifiedMultimodalUX: + """Integration tests for unified multimodal UX pattern.""" + + @pytest.mark.asyncio + async def test_image_generation_routes_correctly(self, agent_with_ai, monkeypatch): + """Different model prefixes should route to correct providers.""" + ai = AgentAI(agent_with_ai) + + # Track which methods are called + calls = [] + + async def mock_fal_generate(*args, **kwargs): + calls.append(("fal", kwargs.get("model"))) + return MultimodalResponse( + text="", audio=None, + images=[ImageOutput(url="https://fal.media/img.png")], + files=[], + ) + + # Setup mocks - patch the instance attribute directly + mock_provider = MagicMock() + mock_provider.generate_image = mock_fal_generate + ai._fal_provider_instance = mock_provider + + # Test fal-ai/ prefix + await ai.ai_with_vision(prompt="test", model="fal-ai/flux/dev") + assert ("fal", "fal-ai/flux/dev") 
in calls + + # Test fal/ prefix + calls.clear() + await ai.ai_with_vision(prompt="test", model="fal/recraft-v3") + assert ("fal", "fal/recraft-v3") in calls + + def test_all_new_methods_exist(self, agent_with_ai): + """Agent should have all new multimodal methods.""" + ai = AgentAI(agent_with_ai) + + # Check methods exist + assert hasattr(ai, "ai_generate_video") + assert hasattr(ai, "ai_transcribe_audio") + assert hasattr(ai, "_fal_provider") + + # Check they're callable + assert callable(ai.ai_generate_video) + assert callable(ai.ai_transcribe_audio)
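
Below is a minimal end-to-end usage sketch of the surface this diff adds. It is illustrative only, not part of the patch: it assumes an already-configured `Agent` instance named `app` (as in the docstring examples above) and that `FAL_KEY` and the relevant provider API keys are set in the environment; the model IDs, URLs, and file paths are placeholders.

```python
from agentfield.media_providers import get_provider


async def demo(app):
    # Image generation: routed by model prefix (DALL-E models via LiteLLM,
    # "openrouter/..." via OpenRouter, "fal-ai/..." / "fal/" via FalProvider).
    image = await app.ai_generate_image(
        "A sunset over mountains",
        model="fal-ai/flux/dev",
        size="landscape_16_9",
    )
    if image.has_images:
        image.images[0].save("sunset.png")

    # Text-to-speech via the configured audio model (defaults to tts-1).
    speech = await app.ai_generate_audio(
        "Hello, how are you today?",
        voice="nova",
    )
    if speech.has_audio:
        speech.audio.save("greeting.wav")

    # Providers can also be used directly, outside an Agent.
    fal = get_provider("fal")  # reads FAL_KEY from the environment
    transcript = await fal.transcribe_audio("https://example.com/audio.mp3")
    print(transcript.text)
```

The sketch mirrors the routing design in `agent_ai.py`: callers keep using the same `ai_generate_*` methods, and the model-ID prefix alone decides whether the request goes to LiteLLM, OpenRouter, or the lazily-initialized `FalProvider`.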