diff --git a/service/app/mcp/literature.py b/service/app/mcp/literature.py new file mode 100644 index 00000000..57dbb00f --- /dev/null +++ b/service/app/mcp/literature.py @@ -0,0 +1,376 @@ +""" +Literature MCP Server - Multi-source academic literature search + +Provides tools for searching academic literature from multiple data sources +(OpenAlex, Semantic Scholar, PubMed, etc.) with unified interface. +""" + +import json +import logging +from datetime import datetime +from typing import Any + +import httpx +from fastmcp import FastMCP + +from app.utils.literature import SearchRequest, WorkDistributor + +logger = logging.getLogger(__name__) + +# Create FastMCP instance +mcp = FastMCP("literature") + +# Metadata for MCP server +__mcp_metadata__ = { + "name": "Literature Search", + "description": "Search academic literature from multiple sources with advanced filtering", + "version": "1.0.0", +} + + +@mcp.tool() +async def search_literature( + query: str, + mailto: str | None = None, + author: str | None = None, + institution: str | None = None, + source: str | None = None, + year_from: str | None = None, + year_to: str | None = None, + is_oa: str | None = None, + work_type: str | None = None, + language: str | None = None, + is_retracted: str | None = None, + has_abstract: str | None = None, + has_fulltext: str | None = None, + sort_by: str = "relevance", + max_results: str | int = 50, + data_sources: list[str] | None = None, + include_abstract: str | bool = False, +) -> str: + """ + Search academic literature from multiple data sources (OpenAlex, etc.) + + ⚠️ IMPORTANT: A valid email address (mailto parameter) enables the OpenAlex polite pool + (10 req/s). If omitted, the default pool is used (1 req/s, sequential). Production + usage should provide an email. + + Basic usage: Provide query keywords and user's email. Returns a Markdown report + with statistics and JSON list of papers. + + Args: + query: Search keywords (e.g., "machine learning", "CRISPR") + mailto: OPTIONAL - User's email (e.g., "researcher@university.edu") + author: OPTIONAL - Author name (e.g., "Albert Einstein") + institution: OPTIONAL - Institution (e.g., "MIT", "Harvard University") + source: OPTIONAL - Journal (e.g., "Nature", "Science") + year_from: OPTIONAL - Start year (e.g., "2020" or 2020) + year_to: OPTIONAL - End year (e.g., "2024" or 2024) + is_oa: OPTIONAL - Open access only ("true"/"false") + work_type: OPTIONAL - Work type: "article", "review", "preprint", "book", "dissertation", etc. + language: OPTIONAL - Language code (e.g., "en" for English, "zh" for Chinese, "fr" for French) + is_retracted: OPTIONAL - Filter retracted works ("true" to include only retracted, "false" to exclude) + has_abstract: OPTIONAL - Require abstract ("true" to include only works with abstracts) + has_fulltext: OPTIONAL - Require full text ("true" to include only works with full text) + sort_by: Sort: "relevance" (default), "cited_by_count", "publication_date" + max_results: Max papers (default: 50, range: 1-200, accepts string or int) + data_sources: Sources to search (default: ["openalex"]) + include_abstract: Include abstracts (default: False, accepts string or bool) + + Returns: + Markdown report with: + - Warnings if filters fail + - Statistics (citations, open access rate) + - JSON list of papers (title, authors, DOI, etc.) 
+ - Next steps guidance + + Usage tips: + - START SIMPLE: just query + mailto + - Tool will suggest corrections if author/institution not found + - Review "Next Steps Guide" before searching again + + Examples: + # Minimal (recommended) + search_literature("machine learning", mailto="researcher@uni.edu") + + # With filters (accepts both strings and integers) + search_literature( + query="CRISPR", + mailto="researcher@uni.edu", + author="Jennifer Doudna", + year_from="2020", + year_to="2024" + ) + + # Recent reviews (past 5 years, English only) + search_literature( + query="cancer immunotherapy", + mailto="user@example.com", + work_type="review", + language="en", + year_from="2020", + sort_by="cited_by_count" + ) + + # Research articles with abstracts (exclude retracted) + search_literature( + query="CRISPR gene editing", + mailto="user@example.com", + work_type="article", + has_abstract="true", + is_retracted="false" + ) + """ + try: + # Convert string parameters to proper types + year_from_int = int(year_from) if year_from and str(year_from).strip() else None + year_to_int = int(year_to) if year_to and str(year_to).strip() else None + + # Clamp year ranges (warn but don't block search) + max_year = datetime.now().year + 1 + year_warning = "" + if year_from_int is not None and year_from_int > max_year: + year_warning += f"year_from {year_from_int}→{max_year}. " + year_from_int = max_year + if year_to_int is not None and year_to_int < 1700: + year_warning += f"year_to {year_to_int}→1700. " + year_to_int = 1700 + + # Convert is_oa to boolean + is_oa_bool: bool | None = None + if is_oa is not None: + is_oa_bool = str(is_oa).lower() in ("true", "1", "yes") + + # Convert is_retracted to boolean + is_retracted_bool: bool | None = None + if is_retracted is not None: + is_retracted_bool = str(is_retracted).lower() in ("true", "1", "yes") + + # Convert has_abstract to boolean + has_abstract_bool: bool | None = None + if has_abstract is not None: + has_abstract_bool = str(has_abstract).lower() in ("true", "1", "yes") + + # Convert has_fulltext to boolean + has_fulltext_bool: bool | None = None + if has_fulltext is not None: + has_fulltext_bool = str(has_fulltext).lower() in ("true", "1", "yes") + + # Convert max_results to int + max_results_int = int(max_results) if max_results else 50 + + # Convert include_abstract to bool + include_abstract_bool = str(include_abstract).lower() in ("true", "1", "yes") if include_abstract else False + + openalex_email = mailto.strip() if mailto and str(mailto).strip() else None + + logger.info( + f"Literature search requested: query='{query}', mailto={openalex_email}, max_results={max_results_int}" + ) + + # Create search request with converted types + request = SearchRequest( + query=query, + author=author, + institution=institution, + source=source, + year_from=year_from_int, + year_to=year_to_int, + is_oa=is_oa_bool, + work_type=work_type, + language=language, + is_retracted=is_retracted_bool, + has_abstract=has_abstract_bool, + has_fulltext=has_fulltext_bool, + sort_by=sort_by, + max_results=max_results_int, + data_sources=data_sources, + ) + + # Execute search + async with WorkDistributor(openalex_email=openalex_email) as distributor: + result = await distributor.search(request) + + if year_warning: + result.setdefault("warnings", []).append(f"⚠️ Year adjusted: {year_warning.strip()}") + + # Format output + return _format_search_result(request, result, include_abstract_bool) + + except ValueError as e: + logger.warning(f"Literature search validation error: 
{e}") + return f"❌ Invalid input: {str(e)}" + except httpx.HTTPError as e: + logger.error(f"Literature search network error: {e}", exc_info=True) + return "❌ Network error while contacting literature sources. Please try again later." + except Exception as e: + logger.error(f"Literature search failed: {e}", exc_info=True) + return "❌ Unexpected error during search. Please retry or contact support." + + +def _format_search_result(request: SearchRequest, result: dict[str, Any], include_abstract: bool = False) -> str: + """ + Format search results into human-readable report + JSON data + + Args: + request: Original search request + result: Search result from WorkDistributor + include_abstract: Whether to include abstracts in JSON (default: False to save tokens) + + Returns: + Formatted markdown report with embedded JSON + """ + works = result["works"] + total_count = result["total_count"] + unique_count = result["unique_count"] + sources = result["sources"] + warnings = result.get("warnings", []) + + # Build report sections + sections: list[str] = [] + + # Header + sections.append("# Literature Search Report\n") + + # Warnings and resolution status (if any) + if warnings: + sections.append("## ⚠️ Warnings and Resolution Status\n") + for warning in warnings: + sections.append(f"{warning}") + sections.append("") + + # Search conditions + sections.append("## Search Conditions\n") + conditions: list[str] = [] + conditions.append(f"- **Query**: {request.query}") + if request.author: + conditions.append(f"- **Author**: {request.author}") + if request.institution: + conditions.append(f"- **Institution**: {request.institution}") + if request.source: + conditions.append(f"- **Source**: {request.source}") + if request.year_from or request.year_to: + year_range = f"{request.year_from or '...'} - {request.year_to or '...'}" + conditions.append(f"- **Year Range**: {year_range}") + if request.is_oa is not None: + conditions.append(f"- **Open Access Only**: {'Yes' if request.is_oa else 'No'}") + if request.work_type: + conditions.append(f"- **Work Type**: {request.work_type}") + if request.language: + conditions.append(f"- **Language**: {request.language}") + if request.is_retracted is not None: + conditions.append(f"- **Exclude Retracted**: {'No' if request.is_retracted else 'Yes'}") + if request.has_abstract is not None: + conditions.append(f"- **Require Abstract**: {'Yes' if request.has_abstract else 'No'}") + if request.has_fulltext is not None: + conditions.append(f"- **Require Full Text**: {'Yes' if request.has_fulltext else 'No'}") + conditions.append(f"- **Sort By**: {request.sort_by}") + conditions.append(f"- **Max Results**: {request.max_results}") + sections.append("\n".join(conditions)) + sections.append("") + + # Check if no results + if not works: + sections.append("## ❌ No Results Found\n") + sections.append("**Suggestions to improve your search:**\n") + suggestions: list[str] = [] + suggestions.append("1. **Simplify keywords**: Try broader or different terms") + if request.author: + suggestions.append("2. **Remove author filter**: Author name may not be recognized") + if request.institution: + suggestions.append("3. **Remove institution filter**: Try without institution constraint") + if request.source: + suggestions.append("4. **Remove source filter**: Try without journal constraint") + if request.year_from or request.year_to: + suggestions.append("5. **Expand year range**: Current range may be too narrow") + if request.is_oa: + suggestions.append("6. 
**Remove open access filter**: Include non-OA papers") + suggestions.append("7. **Check spelling**: Verify all terms are spelled correctly") + sections.append("\n".join(suggestions)) + sections.append("") + return "\n".join(sections) + + # Statistics and overall insights + sections.append("## Search Statistics\n") + stats: list[str] = [] + stats.append(f"- **Total Found**: {total_count} works") + stats.append(f"- **After Deduplication**: {unique_count} works") + source_info = ", ".join(f"{name}: {count}" for name, count in sources.items()) + stats.append(f"- **Data Sources**: {source_info}") + + # Add insights + if works: + avg_citations = sum(w.cited_by_count for w in works) / len(works) + stats.append(f"- **Average Citations**: {avg_citations:.1f}") + + oa_count = sum(1 for w in works if w.is_oa) + oa_ratio = (oa_count / len(works)) * 100 + stats.append(f"- **Open Access Rate**: {oa_ratio:.1f}% ({oa_count}/{len(works)})") + + years = [w.publication_year for w in works if w.publication_year] + if years: + stats.append(f"- **Year Range**: {min(years)} - {max(years)}") + + sections.append("\n".join(stats)) + sections.append("") + + # Complete JSON list + sections.append("## Complete Works List (JSON)\n") + if include_abstract: + sections.append("The following JSON contains all works with full abstracts:\n") + else: + sections.append("The following JSON contains all works (abstracts excluded to save tokens):\n") + sections.append("```json") + + # Convert works to dict for JSON serialization + works_dict = [] + for work in works: + work_data = { + "id": work.id, + "doi": work.doi, + "title": work.title, + "authors": work.authors[:5], # Limit to first 5 authors + "publication_year": work.publication_year, + "cited_by_count": work.cited_by_count, + "journal": work.journal, + "is_oa": work.is_oa, + "oa_url": work.oa_url, + "source": work.source, + } + # Only include abstract if requested + if include_abstract and work.abstract: + work_data["abstract"] = work.abstract + works_dict.append(work_data) + + sections.append(json.dumps(works_dict, indent=2, ensure_ascii=False)) + sections.append("```") + sections.append("") + + # Next steps guidance - prevent infinite loops + sections.append("---") + sections.append("## 🎯 Next Steps Guide\n") + sections.append("**Before making another search, consider:**\n") + next_steps: list[str] = [] + + if unique_count > 0: + next_steps.append("✓ **Results found** - Review the JSON data above for your analysis") + if unique_count >= request.max_results: + next_steps.append( + f"⚠️ **Result limit reached** ({request.max_results}) - " + "Consider narrowing filters (author, year, journal) for more targeted results" + ) + if unique_count < 10: + next_steps.append("💡 **Few results** - Consider broadening your search by removing some filters") + + next_steps.append("") + next_steps.append("**To refine your search:**") + next_steps.append("- If too many results → Add more specific filters (author, institution, journal, year)") + next_steps.append("- If too few results → Remove filters or use broader keywords") + next_steps.append("- If wrong results → Check filter spelling and try variations") + next_steps.append("") + next_steps.append("⚠️ **Important**: Avoid making multiple similar searches without reviewing results first!") + next_steps.append("Each search consumes API quota and context window. 
Make targeted, deliberate queries.") + + sections.append("\n".join(next_steps)) + + return "\n".join(sections) diff --git a/service/app/utils/literature/__init__.py b/service/app/utils/literature/__init__.py new file mode 100644 index 00000000..c4dd14ba --- /dev/null +++ b/service/app/utils/literature/__init__.py @@ -0,0 +1,17 @@ +""" +Literature search utilities for multi-source academic literature retrieval +""" + +from .base_client import BaseLiteratureClient +from .doi_cleaner import deduplicate_by_doi, normalize_doi +from .models import LiteratureWork, SearchRequest +from .work_distributor import WorkDistributor + +__all__ = [ + "BaseLiteratureClient", + "normalize_doi", + "deduplicate_by_doi", + "SearchRequest", + "LiteratureWork", + "WorkDistributor", +] diff --git a/service/app/utils/literature/base_client.py b/service/app/utils/literature/base_client.py new file mode 100644 index 00000000..ba8a3db6 --- /dev/null +++ b/service/app/utils/literature/base_client.py @@ -0,0 +1,32 @@ +""" +Abstract base class for literature data source clients +""" + +from abc import ABC, abstractmethod + +from .models import LiteratureWork, SearchRequest + + +class BaseLiteratureClient(ABC): + """ + Base class for literature data source clients + + All data source implementations (OpenAlex, Semantic Scholar, PubMed, etc.) + should inherit from this class and implement the required methods. + """ + + @abstractmethod + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """ + Execute search and return results in standard format + + Args: + request: Standardized search request + + Returns: + Tuple of (works, warnings) where warnings is a list of messages for LLM feedback + + Raises: + Exception: If search fails after retries + """ + pass diff --git a/service/app/utils/literature/doi_cleaner.py b/service/app/utils/literature/doi_cleaner.py new file mode 100644 index 00000000..816e20e3 --- /dev/null +++ b/service/app/utils/literature/doi_cleaner.py @@ -0,0 +1,121 @@ +""" +DOI normalization and deduplication utilities +""" + +import re +from typing import Protocol, TypeVar + + +class WorkWithDOI(Protocol): + """Protocol for objects with DOI and citation information""" + + doi: str | None + cited_by_count: int + publication_year: int | None + + +T = TypeVar("T", bound=WorkWithDOI) + + +def normalize_doi(doi: str | None) -> str | None: + """ + Normalize DOI format to standard form + + Removes common prefixes, validates format, and converts to lowercase. + DOI specification (ISO 26324) defines DOI matching as case-insensitive, + so lowercase conversion is safe and improves consistency. + + Args: + doi: DOI string in any common format + + Returns: + Normalized DOI (e.g., "10.1038/nature12345") or None if invalid + + Examples: + >>> normalize_doi("https://doi.org/10.1038/nature12345") + "10.1038/nature12345" + >>> normalize_doi("DOI: 10.1038/nature12345") + "10.1038/nature12345" + >>> normalize_doi("doi:10.1038/nature12345") + "10.1038/nature12345" + """ + if not doi: + return None + + doi = doi.strip().lower() + + # Remove common prefixes + doi = re.sub(r"^(https?://)?(dx\.)?doi\.org/", "", doi) + doi = re.sub(r"^doi:\s*", "", doi) + + # Validate format (10.xxxx/yyyy) + if not re.match(r"^10\.\d+/.+", doi): + return None + + return doi + + +def deduplicate_by_doi(works: list[T]) -> list[T]: + """ + Deduplicate works by DOI, keeping the highest priority version + + Priority rules: + 1. Works with DOI take priority over those without + 2. 
For same DOI, keep the one with higher citation count + 3. If citation count is equal, keep the most recently published + + Args: + works: List of LiteratureWork objects + + Returns: + Deduplicated list of works + + Examples: + >>> works = [ + ... LiteratureWork(doi="10.1038/1", cited_by_count=100, ...), + ... LiteratureWork(doi="10.1038/1", cited_by_count=50, ...), + ... LiteratureWork(doi=None, ...), + ... ] + >>> unique = deduplicate_by_doi(works) + >>> len(unique) + 2 + >>> unique[0].cited_by_count + 100 + """ + # Group by: with DOI vs without DOI + with_doi: dict[str, T] = {} + without_doi: list[T] = [] + + for work in works: + # Check if work has doi attribute + if not work.doi: + without_doi.append(work) + continue + + doi = normalize_doi(work.doi) + if not doi: + without_doi.append(work) + continue + + # If DOI already exists, compare priority + if doi in with_doi: + existing = with_doi[doi] + + # Higher citation count? + if work.cited_by_count > existing.cited_by_count: + with_doi[doi] = work + # Same citation count, more recent publication? + elif ( + work.cited_by_count == existing.cited_by_count + and work.publication_year + and existing.publication_year + and work.publication_year > existing.publication_year + ): + with_doi[doi] = work + else: + with_doi[doi] = work + + # Combine results: DOI works first, then non-DOI works + unique_works = list(with_doi.values()) + without_doi + + return unique_works diff --git a/service/app/utils/literature/models.py b/service/app/utils/literature/models.py new file mode 100644 index 00000000..21e35c35 --- /dev/null +++ b/service/app/utils/literature/models.py @@ -0,0 +1,80 @@ +""" +Shared data models for literature utilities +""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class SearchRequest: + """ + Standardized search request format for all data sources + + Attributes: + query: Search keywords (searches title, abstract, full text) + author: Author name (will be converted to author ID) + institution: Institution name (will be converted to institution ID) + source: Journal or conference name + year_from: Start year (inclusive) + year_to: End year (inclusive) + is_oa: Filter for open access only + work_type: Work type filter ("article", "review", "preprint", etc.) 
+ language: Language code filter (e.g., "en", "zh", "fr") + is_retracted: Filter for retracted works (True to include only retracted, False to exclude) + has_abstract: Filter for works with abstracts + has_fulltext: Filter for works with full text available + sort_by: Sort method - "relevance", "cited_by_count", "publication_date" + max_results: Maximum number of results to return + data_sources: List of data sources to query (default: ["openalex"]) + """ + + query: str + author: str | None = None + institution: str | None = None + source: str | None = None + year_from: int | None = None + year_to: int | None = None + is_oa: bool | None = None + work_type: str | None = None + language: str | None = None + is_retracted: bool | None = None + has_abstract: bool | None = None + has_fulltext: bool | None = None + sort_by: str = "relevance" + max_results: int = 50 + data_sources: list[str] | None = None + + +@dataclass +class LiteratureWork: + """ + Standardized literature work format across all data sources + + Attributes: + id: Internal ID from the data source + doi: Digital Object Identifier (normalized format) + title: Work title + authors: List of author information [{"name": "...", "id": "..."}] + publication_year: Year of publication + cited_by_count: Number of citations + abstract: Abstract text + journal: Journal or venue name + is_oa: Whether open access + oa_url: URL to open access version + source: Data source name ("openalex", "semantic_scholar", etc.) + raw_data: Original data from the source (for debugging) + """ + + id: str + doi: str | None + title: str + authors: list[dict[str, str | None]] + publication_year: int | None + cited_by_count: int + abstract: str | None + journal: str | None + is_oa: bool + oa_url: str | None + source: str + raw_data: dict[str, Any] = field(default_factory=dict) diff --git a/service/app/utils/literature/openalex_client.py b/service/app/utils/literature/openalex_client.py new file mode 100644 index 00000000..089fe5c8 --- /dev/null +++ b/service/app/utils/literature/openalex_client.py @@ -0,0 +1,559 @@ +""" +OpenAlex API client for literature search + +Implements the best practices from OpenAlex API guide: +- Two-step lookup for names (author/institution/source -> ID -> filter) +- Rate limiting with mailto parameter (10 req/s) +- Exponential backoff retry for errors +- Batch queries with pipe separator (up to 50 IDs) +- Maximum page size (200 per page) +- Abstract reconstruction from inverted index +""" + +import asyncio +import logging +from typing import Any + +import httpx + +from .base_client import BaseLiteratureClient +from .models import LiteratureWork, SearchRequest + +logger = logging.getLogger(__name__) + + +class _RateLimiter: + """ + Simple global rate limiter with optional concurrency guard. + + Enforces a minimum interval between request starts across all callers. 
+ """ + + def __init__(self, rate_per_second: float, max_concurrency: int) -> None: + self._min_interval = 1.0 / rate_per_second if rate_per_second > 0 else 0.0 + self._lock = asyncio.Lock() + self._last_request = 0.0 + self._semaphore = asyncio.Semaphore(max_concurrency) + + async def __aenter__(self) -> None: + await self._semaphore.acquire() + await self._throttle() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + self._semaphore.release() + + async def _throttle(self) -> None: + if self._min_interval <= 0: + return + + async with self._lock: + now = asyncio.get_running_loop().time() + wait_time = self._last_request + self._min_interval - now + if wait_time > 0: + await asyncio.sleep(wait_time) + self._last_request = asyncio.get_running_loop().time() + + +class OpenAlexClient(BaseLiteratureClient): + """ + OpenAlex API client + + Implements best practices from official API guide for LLMs: + https://docs.openalex.org/api-guide-for-llms + """ + + BASE_URL = "https://api.openalex.org" + MAX_PER_PAGE = 200 + MAX_RETRIES = 5 + TIMEOUT = 30.0 + + def __init__(self, email: str | None, rate_limit: int | None = None, timeout: float = 30.0) -> None: + """ + Initialize OpenAlex client + + Args: + email: Email for polite pool (10x rate limit increase). If None, use default pool. + rate_limit: Requests per second (default: 10 with email, 1 without email) + timeout: Request timeout in seconds (default: 30.0) + """ + self.email = email + self.rate_limit = rate_limit or (10 if self.email else 1) + max_concurrency = 10 if self.email else 1 + self.rate_limiter = _RateLimiter(rate_per_second=self.rate_limit, max_concurrency=max_concurrency) + self.client = httpx.AsyncClient(timeout=timeout) + pool_type = "polite" if self.email else "default" + logger.info( + f"OpenAlex client initialized with pool={pool_type}, email={self.email}, rate_limit={self.rate_limit}/s" + ) + + @property + def pool_type(self) -> str: + """Return pool type string.""" + return "polite" if self.email else "default" + + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """ + Execute search and return results in standard format + + Implementation steps: + 1. Convert author name -> author ID (if specified) + 2. Convert institution name -> institution ID (if specified) + 3. Convert journal name -> source ID (if specified) + 4. Build filter query + 5. Paginate through results + 6. 
Transform to standard format + + Args: + request: Standardized search request + + Returns: + Tuple of (works, warnings) + - works: List of literature works in standard format + - warnings: List of warning/info messages for LLM feedback + """ + logger.info( + f"OpenAlex search [{self.pool_type} @ {self.rate_limit}/s]: query='{request.query}', max_results={request.max_results}" + ) + + warnings: list[str] = [] + + # Step 1-3: Resolve IDs for names (two-step lookup pattern) + author_id = None + if request.author: + author_id, _success, msg = await self._resolve_author_id(request.author) + warnings.append(msg) + + institution_id = None + if request.institution: + institution_id, _success, msg = await self._resolve_institution_id(request.institution) + warnings.append(msg) + + source_id = None + if request.source: + source_id, _success, msg = await self._resolve_source_id(request.source) + warnings.append(msg) + + # Step 4: Build query parameters + params = self._build_query_params(request, author_id, institution_id, source_id) + + # Step 5: Fetch all pages + works = await self._fetch_all_pages(params, request.max_results) + + # Step 6: Transform to standard format + return [self._transform_work(w) for w in works], warnings + + def _build_query_params( + self, + request: SearchRequest, + author_id: str | None, + institution_id: str | None, + source_id: str | None, + ) -> dict[str, str]: + """ + Build OpenAlex query parameters + + Args: + request: Search request + author_id: Resolved author ID (if any) + institution_id: Resolved institution ID (if any) + source_id: Resolved source ID (if any) + + Returns: + Dictionary of query parameters + """ + params: dict[str, str] = { + "per-page": str(self.MAX_PER_PAGE), + } + + if self.email: + params["mailto"] = self.email + + # Search keywords + if request.query: + params["search"] = request.query + + # Build filters + filters: list[str] = [] + + if author_id: + filters.append(f"authorships.author.id:{author_id}") + + if institution_id: + filters.append(f"authorships.institutions.id:{institution_id}") + + if source_id: + filters.append(f"primary_location.source.id:{source_id}") + + # Year range + if request.year_from or request.year_to: + if request.year_from and request.year_to: + filters.append(f"publication_year:{request.year_from}-{request.year_to}") + elif request.year_from: + filters.append(f"publication_year:>{request.year_from - 1}") + elif request.year_to: + filters.append(f"publication_year:<{request.year_to + 1}") + + # Open access filter + if request.is_oa is not None: + filters.append(f"is_oa:{str(request.is_oa).lower()}") + + # Work type filter + if request.work_type: + filters.append(f"type:{request.work_type}") + + # Language filter + if request.language: + filters.append(f"language:{request.language}") + + # Retracted filter + if request.is_retracted is not None: + filters.append(f"is_retracted:{str(request.is_retracted).lower()}") + + # Abstract filter + if request.has_abstract is not None: + filters.append(f"has_abstract:{str(request.has_abstract).lower()}") + + # Fulltext filter + if request.has_fulltext is not None: + filters.append(f"has_fulltext:{str(request.has_fulltext).lower()}") + + if filters: + params["filter"] = ",".join(filters) + + # Sorting + sort_map = { + "relevance": None, # Default sorting by relevance + "cited_by_count": "cited_by_count:desc", + "publication_date": "publication_date:desc", + } + if sort := sort_map.get(request.sort_by): + params["sort"] = sort + + return params + + async def _resolve_author_id(self, 
author_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: author name -> author ID + + Args: + author_name: Author name to search + + Returns: + Tuple of (author_id, success, message) + - author_id: Author ID (e.g., "A5023888391") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/authors" + params: dict[str, str] = {"search": author_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + # Return first result's ID in short format + author_id = results[0]["id"].split("/")[-1] + author_display = results[0].get("display_name", author_name) + logger.info(f"Resolved author '{author_name}' -> {author_id}") + return author_id, True, f"✓ Author resolved: '{author_name}' -> '{author_display}'" + else: + msg = ( + f"⚠️ Author '{author_name}' not found. " + f"Suggestions: (1) Try full name format like 'Smith, John' or 'John Smith', " + f"(2) Check spelling, (3) Try removing middle name/initial." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve author '{author_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _resolve_institution_id(self, institution_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: institution name -> institution ID + + Args: + institution_name: Institution name to search + + Returns: + Tuple of (institution_id, success, message) + - institution_id: Institution ID (e.g., "I136199984") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/institutions" + params: dict[str, str] = {"search": institution_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + institution_id = results[0]["id"].split("/")[-1] + inst_display = results[0].get("display_name", institution_name) + logger.info(f"Resolved institution '{institution_name}' -> {institution_id}") + return institution_id, True, f"✓ Institution resolved: '{institution_name}' -> '{inst_display}'" + else: + msg = ( + f"⚠️ Institution '{institution_name}' not found. " + f"Suggestions: (1) Use full official name (e.g., 'Harvard University' not 'Harvard'), " + f"(2) Try variations (e.g., 'MIT' vs 'Massachusetts Institute of Technology'), " + f"(3) Check spelling." 
+ ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve institution '{institution_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _resolve_source_id(self, source_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: source name -> source ID + + Args: + source_name: Journal/conference name to search + + Returns: + Tuple of (source_id, success, message) + - source_id: Source ID (e.g., "S137773608") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/sources" + params: dict[str, str] = {"search": source_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + source_id = results[0]["id"].split("/")[-1] + source_display = results[0].get("display_name", source_name) + logger.info(f"Resolved source '{source_name}' -> {source_id}") + return source_id, True, f"✓ Source resolved: '{source_name}' -> '{source_display}'" + else: + msg = ( + f"⚠️ Source/Journal '{source_name}' not found. " + f"Suggestions: (1) Use full journal name (e.g., 'Nature' or 'Science'), " + f"(2) Try alternative names (e.g., 'JAMA' vs 'Journal of the American Medical Association'), " + f"(3) Check spelling." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve source '{source_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _fetch_all_pages(self, params: dict[str, str], max_results: int) -> list[dict[str, Any]]: + """ + Paginate through all results up to max_results + + Args: + params: Base query parameters + max_results: Maximum number of results to fetch + + Returns: + List of work objects from API + """ + all_works: list[dict[str, Any]] = [] + page = 1 + + while len(all_works) < max_results: + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/works" + page_params = {**params, "page": str(page)} + response = await self._request_with_retry(url, page_params) + + works = response.get("results", []) + if not works: + break + + all_works.extend(works) + logger.info(f"Fetched page {page}: {len(works)} works") + + # Check if there are more pages + meta = response.get("meta", {}) + total_count = meta.get("count", 0) + if len(all_works) >= total_count: + break + + page += 1 + + except Exception as e: + logger.error(f"Error fetching page {page}: {e}") + break + + return all_works[:max_results] + + async def _request_with_retry(self, url: str, params: dict[str, str]) -> dict[str, Any]: + """ + HTTP request with exponential backoff retry + + Implements best practices: + - Retry on 403 (rate limit) with exponential backoff + - Retry on 5xx (server error) with exponential backoff + - Don't retry on 4xx (except 403) + - Retry on timeout + + Args: + url: Request URL + params: Query parameters + + Returns: + JSON response + + Raises: + Exception: If all retries fail + """ + for attempt in range(self.MAX_RETRIES): + try: + response = await self.client.get(url, params=params) + + if response.status_code == 200: + return response.json() + elif response.status_code == 403: + # Rate limited + wait_time = 2**attempt + logger.warning(f"Rate limited (403), waiting {wait_time}s... 
(attempt {attempt + 1})") + await asyncio.sleep(wait_time) + elif response.status_code >= 500: + # Server error + wait_time = 2**attempt + logger.warning( + f"Server error ({response.status_code}), waiting {wait_time}s... (attempt {attempt + 1})" + ) + await asyncio.sleep(wait_time) + else: + # Other error, don't retry + response.raise_for_status() + + except httpx.TimeoutException: + if attempt < self.MAX_RETRIES - 1: + wait_time = 2**attempt + logger.warning(f"Timeout, retrying in {wait_time}s... (attempt {attempt + 1})") + await asyncio.sleep(wait_time) + else: + raise + except Exception as e: + logger.error(f"Request failed: {e}") + if attempt < self.MAX_RETRIES - 1: + wait_time = 2**attempt + await asyncio.sleep(wait_time) + else: + raise + + raise Exception(f"Failed after {self.MAX_RETRIES} retries") + + def _transform_work(self, work: dict[str, Any]) -> LiteratureWork: + """ + Transform OpenAlex work data to standard format + + Args: + work: Raw work object from OpenAlex API + + Returns: + Standardized LiteratureWork object + """ + # Extract authors + authors: list[dict[str, str | None]] = [] + for authorship in work.get("authorships", []): + author = authorship.get("author", {}) + authors.append( + { + "name": author.get("display_name", "Unknown"), + "id": author.get("id", "").split("/")[-1] if author.get("id") else None, + } + ) + + # Extract journal/source + journal = None + if primary_location := work.get("primary_location"): + if source := primary_location.get("source"): + journal = source.get("display_name") + + # Extract open access info + oa_info = work.get("open_access", {}) + is_oa = oa_info.get("is_oa", False) + oa_url = oa_info.get("oa_url") + + # Extract abstract (reconstruct from inverted index) + abstract = self._reconstruct_abstract(work.get("abstract_inverted_index")) + + # Extract DOI (remove prefix) + doi = None + if doi_raw := work.get("doi"): + doi = doi_raw.replace("https://doi.org/", "") + + return LiteratureWork( + id=work["id"].split("/")[-1], + doi=doi, + title=work.get("title", "Untitled"), + authors=authors, + publication_year=work.get("publication_year"), + cited_by_count=work.get("cited_by_count", 0), + abstract=abstract, + journal=journal, + is_oa=is_oa, + oa_url=oa_url, + source="openalex", + raw_data=work, + ) + + def _reconstruct_abstract(self, inverted_index: dict[str, list[int]] | None) -> str | None: + """ + Reconstruct abstract from inverted index + + OpenAlex stores abstracts as inverted index for efficiency. + Format: {"word": [position1, position2, ...], ...} + + Args: + inverted_index: Inverted index from OpenAlex + + Returns: + Reconstructed abstract text or None + + Examples: + >>> index = {"Hello": [0], "world": [1], "!": [2]} + >>> _reconstruct_abstract(index) + "Hello world !" 
+ """ + if not inverted_index: + return None + + # Expand inverted index to (position, word) pairs + word_positions: list[tuple[int, str]] = [] + for word, positions in inverted_index.items(): + for pos in positions: + word_positions.append((pos, word)) + + # Sort by position and join + word_positions.sort() + return " ".join(word for _, word in word_positions) + + async def close(self) -> None: + """Close the HTTP client""" + await self.client.aclose() + + async def __aenter__(self) -> "OpenAlexClient": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + await self.close() diff --git a/service/app/utils/literature/work_distributor.py b/service/app/utils/literature/work_distributor.py new file mode 100644 index 00000000..f4d7139c --- /dev/null +++ b/service/app/utils/literature/work_distributor.py @@ -0,0 +1,161 @@ +""" +Work distributor for coordinating multiple literature data sources +""" + +import inspect +import logging +from typing import Any + +from .doi_cleaner import deduplicate_by_doi +from .models import LiteratureWork, SearchRequest + +logger = logging.getLogger(__name__) + + +class WorkDistributor: + """ + Distribute search requests to multiple literature data sources + and aggregate results + """ + + def __init__(self, openalex_email: str | None = None) -> None: + """ + Initialize distributor with available clients + + Args: + openalex_email: Email for OpenAlex polite pool (required for OpenAlex) + """ + self.clients: dict[str, Any] = {} + self.openalex_email = openalex_email + self._register_clients() + + def _register_clients(self) -> None: + """Register available data source clients""" + # Import here to avoid circular dependencies + try: + from .openalex_client import OpenAlexClient + + self.clients["openalex"] = OpenAlexClient(email=self.openalex_email) + logger.info("Registered OpenAlex client") + except ImportError as e: + logger.warning(f"Failed to register OpenAlex client: {e}") + + # Future: Add more clients + # from .semantic_scholar_client import SemanticScholarClient + # self.clients["semantic_scholar"] = SemanticScholarClient() + + async def close(self) -> None: + """Close any underlying HTTP clients""" + for client in self.clients.values(): + close_method = getattr(client, "close", None) + if callable(close_method): + result = close_method() + if inspect.isawaitable(result): + await result + + async def __aenter__(self) -> "WorkDistributor": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + await self.close() + + async def search(self, request: SearchRequest) -> dict[str, Any]: + """ + Execute search across multiple data sources and aggregate results + + Args: + request: Standardized search request + + Returns: + Dictionary containing: + - total_count: Total number of works fetched (before dedup) + - unique_count: Number of unique works (after dedup) + - sources: Dict of source name -> count + - works: List of deduplicated LiteratureWork objects + - warnings: List of warning/info messages for LLM feedback + + Examples: + >>> distributor = WorkDistributor() + >>> request = SearchRequest(query="machine learning", max_results=50) + >>> result = await distributor.search(request) + >>> print(f"Found {result['unique_count']} unique works") + """ + # Clamp max_results to 50/1000 with warnings + all_warnings: list[str] = [] + if request.max_results < 1: + 
all_warnings.append("⚠️ max_results < 1; using default 50") + request.max_results = 50 + elif request.max_results > 1000: + all_warnings.append("⚠️ max_results > 1000; using 1000") + request.max_results = 1000 + + # Determine which data sources to use + sources = request.data_sources or ["openalex"] + + # Collect works and warnings from all sources + all_works: list[LiteratureWork] = [] + source_counts: dict[str, int] = {} + + for source_name in sources: + if client := self.clients.get(source_name): + try: + logger.info(f"Fetching from {source_name}...") + works, warnings_data = await client.search(request) + all_warnings.extend(warnings_data) + + all_works.extend(works) + source_counts[source_name] = len(works) + logger.info(f"Fetched {len(works)} works from {source_name}") + except Exception as e: + logger.error(f"Error fetching from {source_name}: {e}", exc_info=True) + source_counts[source_name] = 0 + all_warnings.append(f"⚠️ Error fetching from {source_name}: {str(e)}") + else: + logger.warning(f"Data source '{source_name}' not available") + + # Deduplicate by DOI + logger.info(f"Deduplicating {len(all_works)} works...") + unique_works = deduplicate_by_doi(all_works) + logger.info(f"After deduplication: {len(unique_works)} unique works") + + # Sort results + unique_works = self._sort_works(unique_works, request.sort_by) + + # Limit to max_results + unique_works = unique_works[: request.max_results] + + return { + "total_count": len(all_works), + "unique_count": len(unique_works), + "sources": source_counts, + "works": unique_works, + "warnings": all_warnings, + } + + def _sort_works(self, works: list[LiteratureWork], sort_by: str) -> list[LiteratureWork]: + """ + Sort works by specified criteria + + Args: + works: List of works to sort + sort_by: Sort method - "relevance", "cited_by_count", "publication_date" + + Returns: + Sorted list of works + """ + if sort_by == "cited_by_count": + return sorted(works, key=lambda w: w.cited_by_count, reverse=True) + elif sort_by == "publication_date": + return sorted( + works, + key=lambda w: w.publication_year if w.publication_year else float("-inf"), + reverse=True, + ) + else: # relevance or default + # For relevance, keep original order (API returns by relevance) + return works diff --git a/service/tests/unit/test_utils/__init__.py b/service/tests/unit/test_literature/__init__.py similarity index 100% rename from service/tests/unit/test_utils/__init__.py rename to service/tests/unit/test_literature/__init__.py diff --git a/service/tests/unit/test_literature/test_base_client.py b/service/tests/unit/test_literature/test_base_client.py new file mode 100644 index 00000000..8271b87a --- /dev/null +++ b/service/tests/unit/test_literature/test_base_client.py @@ -0,0 +1,326 @@ +"""Tests for base literature client.""" + +import pytest + +from app.utils.literature.base_client import BaseLiteratureClient +from app.utils.literature.models import LiteratureWork, SearchRequest + + +class ConcreteClient(BaseLiteratureClient): + """Concrete implementation of BaseLiteratureClient for testing.""" + + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """Dummy search implementation.""" + return [], [] + +
+class TestBaseLiteratureClientProtocol: + """Test BaseLiteratureClient protocol and abstract methods.""" + + def test_cannot_instantiate_abstract_class(self) -> None: + """Test that BaseLiteratureClient cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseLiteratureClient() # type: ignore + + def test_concrete_implementation(self) -> None: + """Test that concrete implementation can be instantiated.""" + client = ConcreteClient() + assert client is not None + assert isinstance(client, BaseLiteratureClient) + + @pytest.mark.asyncio + async def test_search_method_required(self) -> None: + """Test that search method is required.""" + request = SearchRequest(query="test") + result = await ConcreteClient().search(request) + assert result == ([], []) + + +class TestSearchRequestDataclass: + """Test SearchRequest data model.""" + + def test_search_request_required_field(self) -> None: + """Test SearchRequest with required query field.""" + request = SearchRequest(query="machine learning") + assert request.query == "machine learning" + + def test_search_request_default_values(self) -> None: + """Test SearchRequest default values.""" + request = SearchRequest(query="test") + assert request.query == "test" + assert request.author is None + assert request.institution is None + assert request.source is None + assert request.year_from is None + assert request.year_to is None + assert request.is_oa is None + assert request.work_type is None + assert request.language is None + assert request.is_retracted is None + assert request.has_abstract is None + assert request.has_fulltext is None + assert request.sort_by == "relevance" + assert request.max_results == 50 + assert request.data_sources is None + + def test_search_request_all_fields(self) -> None: + """Test SearchRequest with all fields specified.""" + request = SearchRequest( + query="machine learning", + author="John Doe", + institution="MIT", + source="Nature", + year_from=2015, + year_to=2021, + is_oa=True, + work_type="journal-article", + language="en", + is_retracted=False, + has_abstract=True, + has_fulltext=True, + sort_by="cited_by_count", + max_results=100, + data_sources=["openalex", "semantic_scholar"], + ) + + assert request.query == "machine learning" + assert request.author == "John Doe" + assert request.institution == "MIT" + assert request.source == "Nature" + assert request.year_from == 2015 + assert request.year_to == 2021 + assert request.is_oa is True + assert request.work_type == "journal-article" + assert request.language == "en" + assert request.is_retracted is False + assert request.has_abstract is True + assert request.has_fulltext is True + assert request.sort_by == "cited_by_count" + assert request.max_results == 100 + assert request.data_sources == ["openalex", "semantic_scholar"] + + def 
test_search_request_partial_year_range(self) -> None: + """Test SearchRequest with only year_from.""" + request = SearchRequest(query="test", year_from=2015) + assert request.year_from == 2015 + assert request.year_to is None + + def test_search_request_partial_year_range_to_only(self) -> None: + """Test SearchRequest with only year_to.""" + request = SearchRequest(query="test", year_to=2021) + assert request.year_from is None + assert request.year_to == 2021 + + +class TestLiteratureWorkDataclass: + """Test LiteratureWork data model.""" + + def test_literature_work_minimal(self) -> None: + """Test LiteratureWork with minimal required fields.""" + work = LiteratureWork( + id="W123", + doi=None, + title="Test Paper", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + assert work.id == "W123" + assert work.title == "Test Paper" + assert work.cited_by_count == 0 + assert work.source == "openalex" + + def test_literature_work_complete(self) -> None: + """Test LiteratureWork with all fields.""" + authors = [ + {"name": "John Doe", "id": "A1"}, + {"name": "Jane Smith", "id": "A2"}, + ] + + work = LiteratureWork( + id="W2741809807", + doi="10.1038/nature12345", + title="Machine Learning Fundamentals", + authors=authors, + publication_year=2020, + cited_by_count=150, + abstract="This is a comprehensive review of machine learning concepts.", + journal="Nature", + is_oa=True, + oa_url="https://example.com/paper.pdf", + source="openalex", + ) + + assert work.id == "W2741809807" + assert work.doi == "10.1038/nature12345" + assert work.title == "Machine Learning Fundamentals" + assert len(work.authors) == 2 + assert work.authors[0]["name"] == "John Doe" + assert work.publication_year == 2020 + assert work.cited_by_count == 150 + assert work.abstract is not None + assert work.journal == "Nature" + assert work.is_oa is True + assert work.oa_url is not None + + def test_literature_work_raw_data_default(self) -> None: + """Test LiteratureWork raw_data defaults to empty dict.""" + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + assert work.raw_data == {} + + def test_literature_work_raw_data_custom(self) -> None: + """Test LiteratureWork with custom raw_data.""" + raw_data = {"custom_field": "value", "api_response": {"status": "ok"}} + + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + raw_data=raw_data, + ) + + assert work.raw_data == raw_data + assert work.raw_data["custom_field"] == "value" + + def test_literature_work_multiple_authors(self) -> None: + """Test LiteratureWork with multiple authors.""" + authors = [ + {"name": "Author 1", "id": "A1"}, + {"name": "Author 2", "id": None}, # Author without ID + {"name": "Author 3", "id": "A3"}, + ] + + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=authors, + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + assert len(work.authors) == 3 + assert work.authors[1]["id"] is None + + def test_literature_work_comparison(self) -> None: + """Test LiteratureWork equality comparison.""" + work1 = LiteratureWork( + id="W123", + 
doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal="Nature", + is_oa=True, + oa_url=None, + source="openalex", + ) + + work2 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal="Nature", + is_oa=True, + oa_url=None, + source="openalex", + ) + + # DataclassesObjects with same values should be equal + assert work1 == work2 + + def test_literature_work_inequality(self) -> None: + """Test LiteratureWork inequality.""" + work1 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + work2 = LiteratureWork( + id="W456", + doi="10.1038/nature67890", + title="Paper 2", + authors=[], + publication_year=2021, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + assert work1 != work2 diff --git a/service/tests/unit/test_literature/test_doi_cleaner.py b/service/tests/unit/test_literature/test_doi_cleaner.py new file mode 100644 index 00000000..c12e3839 --- /dev/null +++ b/service/tests/unit/test_literature/test_doi_cleaner.py @@ -0,0 +1,403 @@ +"""Tests for DOI normalization and deduplication utilities.""" + +import pytest + +from app.utils.literature.doi_cleaner import deduplicate_by_doi, normalize_doi +from app.utils.literature.models import LiteratureWork + + +class TestNormalizeDOI: + """Test DOI normalization functionality.""" + + def test_normalize_doi_with_https_prefix(self) -> None: + """Test normalizing DOI with https:// prefix.""" + result = normalize_doi("https://doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_http_prefix(self) -> None: + """Test normalizing DOI with http:// prefix.""" + result = normalize_doi("http://doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_dx_prefix(self) -> None: + """Test normalizing DOI with dx.doi.org prefix.""" + result = normalize_doi("https://dx.doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_doi_colon_prefix(self) -> None: + """Test normalizing DOI with 'doi:' prefix.""" + result = normalize_doi("doi:10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_doi_prefix_uppercase(self) -> None: + """Test normalizing DOI with 'DOI:' prefix (uppercase).""" + result = normalize_doi("DOI: 10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_whitespace(self) -> None: + """Test normalizing DOI with leading/trailing whitespace.""" + result = normalize_doi(" 10.1038/nature12345 ") + assert result == "10.1038/nature12345" + + def test_normalize_doi_case_insensitive(self) -> None: + """Test that DOI normalization converts to lowercase.""" + result = normalize_doi("10.1038/NATURE12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_mixed_case_with_prefix(self) -> None: + """Test normalizing DOI with mixed case and prefix.""" + result = normalize_doi("https://DOI.ORG/10.1038/NATURE12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_none_input(self) -> None: + """Test normalizing None DOI returns None.""" + result = normalize_doi(None) + assert result is None 
+ + def test_normalize_doi_empty_string(self) -> None: + """Test normalizing empty string returns None.""" + result = normalize_doi("") + assert result is None + + def test_normalize_doi_whitespace_only(self) -> None: + """Test normalizing whitespace-only string returns None.""" + result = normalize_doi(" ") + assert result is None + + def test_normalize_doi_invalid_format(self) -> None: + """Test normalizing invalid DOI format returns None.""" + result = normalize_doi("not-a-valid-doi") + assert result is None + + def test_normalize_doi_missing_prefix(self) -> None: + """Test normalizing DOI missing the '10.' prefix returns None.""" + result = normalize_doi("1038/nature12345") + assert result is None + + def test_normalize_doi_missing_suffix(self) -> None: + """Test normalizing DOI missing the suffix returns None.""" + result = normalize_doi("10.1038/") + assert result is None + + def test_normalize_doi_complex_suffix(self) -> None: + """Test normalizing DOI with complex suffix.""" + result = normalize_doi("10.1145/3580305.3599315") + assert result == "10.1145/3580305.3599315" + + def test_normalize_doi_with_version(self) -> None: + """Test normalizing DOI with version suffix.""" + result = normalize_doi("https://doi.org/10.1038/nature.2020.27710") + assert result == "10.1038/nature.2020.27710" + + +class TestDeduplicateByDOI: + """Test DOI-based deduplication functionality.""" + + @pytest.fixture + def sample_work(self) -> LiteratureWork: + """Create a sample literature work.""" + return LiteratureWork( + id="W2741809807", + doi="10.1038/nature12345", + title="Test Paper", + authors=[{"name": "John Doe", "id": "A1"}], + publication_year=2020, + cited_by_count=100, + abstract="Test abstract", + journal="Nature", + is_oa=True, + oa_url="https://example.com/paper.pdf", + source="openalex", + ) + + def test_deduplicate_empty_list(self) -> None: + """Test deduplicating empty list returns empty list.""" + result = deduplicate_by_doi([]) + assert result == [] + + def test_deduplicate_single_work(self, sample_work: LiteratureWork) -> None: + """Test deduplicating single work returns same work.""" + result = deduplicate_by_doi([sample_work]) + assert len(result) == 1 + assert result[0].id == sample_work.id + + def test_deduplicate_duplicate_doi_keeps_higher_citations(self, sample_work: LiteratureWork) -> None: + """Test deduplication keeps work with higher citation count.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W1" # Higher citation count + + def test_deduplicate_duplicate_doi_equal_citations_keeps_newer(self) -> None: + """Test deduplication keeps more recently published work when citation count is equal.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2019, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + 
cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W2" # More recent publication + + def test_deduplicate_without_doi(self) -> None: + """Test deduplicating works without DOI.""" + work1 = LiteratureWork( + id="W1", + doi=None, + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi=None, + title="Paper 2", + authors=[], + publication_year=2020, + cited_by_count=20, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 2 # Both kept since no DOI + + def test_deduplicate_invalid_doi_treated_as_no_doi(self) -> None: + """Test deduplicating works with invalid DOI treats them as without DOI.""" + work1 = LiteratureWork( + id="W1", + doi="invalid-doi-format", + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Paper 2", + authors=[], + publication_year=2020, + cited_by_count=20, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 2 + # Invalid DOI work should be in the results + assert any(w.id == "W1" for w in result) + assert any(w.id == "W2" for w in result) + + def test_deduplicate_doi_with_versions_deduplicated(self) -> None: + """Test deduplicating DOIs with version info.""" + work1 = LiteratureWork( + id="W1", + doi="https://doi.org/10.1038/nature.2020.27710", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature.2020.27710", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W2" # Higher citation count + + def test_deduplicate_preserves_order_with_doi(self) -> None: + """Test that deduplication preserves order: DOI works first, then non-DOI.""" + work_no_doi = LiteratureWork( + id="W_no_doi", + doi=None, + title="No DOI", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + work_with_doi = LiteratureWork( + id="W_with_doi", + doi="10.1038/nature12345", + title="With DOI", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work_no_doi, work_with_doi]) + assert len(result) == 2 + assert result[0].id == "W_with_doi" # DOI works come first + assert result[1].id == "W_no_doi" + + def test_deduplicate_complex_scenario(self) -> None: + """Test deduplication with complex mix of works.""" + works = [ + # Duplicate pair with same DOI + LiteratureWork( + id="W1", + doi="10.1038/A", + title="A", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + 
is_oa=False, + oa_url=None, + source="openalex", + ), + LiteratureWork( + id="W2", + doi="10.1038/A", + title="A", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ), + # Another unique DOI + LiteratureWork( + id="W3", + doi="10.1038/B", + title="B", + authors=[], + publication_year=2021, + cited_by_count=75, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ), + # No DOI works + LiteratureWork( + id="W4", + doi=None, + title="C", + authors=[], + publication_year=2022, + cited_by_count=30, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ), + LiteratureWork( + id="W5", + doi=None, + title="D", + authors=[], + publication_year=2022, + cited_by_count=40, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ), + ] + + result = deduplicate_by_doi(works) + assert len(result) == 4 # W1 removed (duplicate), others kept + result_ids = {w.id for w in result} + assert result_ids == {"W2", "W3", "W4", "W5"} + # Verify W2 (higher citations) was kept over W1 + assert "W2" in result_ids + assert "W1" not in result_ids diff --git a/service/tests/unit/test_literature/test_openalex_client.py b/service/tests/unit/test_literature/test_openalex_client.py new file mode 100644 index 00000000..cddcab67 --- /dev/null +++ b/service/tests/unit/test_literature/test_openalex_client.py @@ -0,0 +1,594 @@ +"""Tests for OpenAlex API client.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.utils.literature.models import SearchRequest +from app.utils.literature.openalex_client import OpenAlexClient + + +class TestOpenAlexClientInit: + """Test OpenAlex client initialization.""" + + def test_client_initialization(self) -> None: + """Test client initializes with correct parameters.""" + email = "test@example.com" + rate_limit = 5 + timeout = 15.0 + + client = OpenAlexClient(email=email, rate_limit=rate_limit, timeout=timeout) + + assert client.email == email + assert client.rate_limit == rate_limit + assert client.pool_type == "polite" + assert pytest.approx(client.rate_limiter._min_interval, rel=0.01) == 1 / rate_limit + + def test_client_initialization_defaults(self) -> None: + """Test client initializes with default parameters.""" + email = "test@example.com" + client = OpenAlexClient(email=email) + + assert client.email == email + assert client.rate_limit == 10 + assert client.pool_type == "polite" + # Verify timeout was set (httpx Timeout object) + assert client.client.timeout is not None + + def test_client_initialization_default_pool(self) -> None: + """Test client initializes default pool when email is missing.""" + client = OpenAlexClient(email=None) + + assert client.email is None + assert client.rate_limit == 1 + assert client.pool_type == "default" + assert pytest.approx(client.rate_limiter._min_interval, rel=0.01) == 1.0 + + +class TestOpenAlexClientSearch: + """Test OpenAlex search functionality.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + @pytest.fixture + def mock_response(self) -> dict: + """Create a mock OpenAlex API response.""" + return { + "meta": {"count": 1, "page": 1}, + "results": [ + { + "id": "https://openalex.org/W2741809807", + "title": "Machine Learning Fundamentals", + "doi": "https://doi.org/10.1038/nature12345", + 
"publication_year": 2020, + "cited_by_count": 150, + "abstract_inverted_index": { + "Machine": [0], + "learning": [1], + "is": [2], + "fundamental": [3], + }, + "authorships": [ + { + "author": { + "id": "https://openalex.org/A5023888391", + "display_name": "Jane Smith", + } + } + ], + "primary_location": { + "source": { + "id": "https://openalex.org/S137773608", + "display_name": "Nature", + } + }, + "open_access": { + "is_oa": True, + "oa_url": "https://example.com/paper.pdf", + }, + } + ], + } + + @pytest.mark.asyncio + async def test_search_basic_query(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test basic search with simple query.""" + request = SearchRequest(query="machine learning", max_results=10) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + assert works[0].title == "Machine Learning Fundamentals" + assert works[0].doi == "10.1038/nature12345" + assert isinstance(warnings, list) + + @pytest.mark.asyncio + async def test_search_with_author_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with author filter.""" + request = SearchRequest(query="machine learning", author="Jane Smith", max_results=10) + + with patch.object(client, "_resolve_author_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("A5023888391", True, "✓ Author resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Jane Smith") +<<<<<<< HEAD:service/tests/unit/test_literature/test_openalex_client.py + assert any("Author resolved" in msg for msg in warnings) +======= +>>>>>>> 1794485dfcc48ad6b089f2b31eb788a021eadea5:service/tests/unit/test_utils/test_openalex_client.py + + @pytest.mark.asyncio + async def test_search_with_institution_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with institution filter.""" + request = SearchRequest(query="machine learning", institution="Harvard University", max_results=10) + + with patch.object(client, "_resolve_institution_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("I136199984", True, "✓ Institution resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Harvard University") +<<<<<<< HEAD:service/tests/unit/test_literature/test_openalex_client.py + assert any("Institution resolved" in msg for msg in warnings) +======= +>>>>>>> 1794485dfcc48ad6b089f2b31eb788a021eadea5:service/tests/unit/test_utils/test_openalex_client.py + + @pytest.mark.asyncio + async def test_search_with_source_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with source (journal) filter.""" + request = SearchRequest(query="machine learning", source="Nature", max_results=10) + + with patch.object(client, "_resolve_source_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("S137773608", True, "✓ Source resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + 
+                mock_request.return_value = mock_response
+
+                works, warnings = await client.search(request)
+
+                assert len(works) == 1
+                mock_resolve.assert_called_once_with("Nature")
+                assert any("Source resolved" in msg for msg in warnings)
+
+    @pytest.mark.asyncio
+    async def test_search_with_year_range(self, client: OpenAlexClient, mock_response: dict) -> None:
+        """Test search with year range filter."""
+        request = SearchRequest(query="machine learning", year_from=2015, year_to=2021, max_results=10)
+
+        with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request:
+            mock_request.return_value = mock_response
+
+            works, warnings = await client.search(request)
+
+            assert len(works) == 1
+            # Verify year filter was applied
+            call_args = mock_request.call_args
+            params = call_args[0][1] if call_args else {}
+            assert "2015-2021" in params.get("filter", "")
+
+    @pytest.mark.asyncio
+    async def test_search_max_results_clamping_low(self, client: OpenAlexClient) -> None:
+        """Test that search handles low max_results correctly."""
+        request = SearchRequest(query="test", max_results=0)
+
+        with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request:
+            mock_request.return_value = {"meta": {"count": 0}, "results": []}
+
+            # Should not raise an error even with 0 max_results
+            works, warnings = await client.search(request)
+            assert isinstance(works, list)
+            assert isinstance(warnings, list)
+
+    @pytest.mark.asyncio
+    async def test_search_max_results_clamping_high(self, client: OpenAlexClient) -> None:
+        """Test that search handles high max_results correctly."""
+        request = SearchRequest(query="test", max_results=5000)
+
+        with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request:
+            mock_request.return_value = {"meta": {"count": 0}, "results": []}
+
+            # Should not raise an error even with high max_results
+            works, warnings = await client.search(request)
+            assert isinstance(works, list)
+            assert isinstance(warnings, list)
+
+
+class TestOpenAlexClientGetByDOI:
+    """Test OpenAlex get_by_doi functionality."""
+
+    @pytest.fixture
+    def client(self) -> OpenAlexClient:
+        """Create an OpenAlex client for testing."""
+        return OpenAlexClient(email="test@example.com")
+
+    @pytest.mark.asyncio
+    async def test_get_by_doi_success(self, client: OpenAlexClient) -> None:
+        """Test successful retrieval by DOI."""
+        mock_work = {
+            "id": "https://openalex.org/W2741809807",
+            "title": "Test Paper",
+            "doi": "https://doi.org/10.1038/nature12345",
+            "publication_year": 2020,
+            "cited_by_count": 100,
+            "abstract_inverted_index": None,
+            "authorships": [],
+            "primary_location": None,
+            "open_access": {"is_oa": False},
+        }
+
+        with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request:
+            mock_request.return_value = mock_work
+
+            result = await client.get_by_doi("10.1038/nature12345")
+
+            assert result is not None
+            assert result.title == "Test Paper"
+            assert result.doi == "10.1038/nature12345"
+
+    @pytest.mark.asyncio
+    async def test_get_by_doi_with_full_url(self, client: OpenAlexClient) -> None:
+        """Test retrieval by DOI with full URL."""
+        mock_work = {
+            "id": "https://openalex.org/W2741809807",
+            "title": "Test Paper",
+            "doi": "https://doi.org/10.1038/nature12345",
"https://doi.org/10.1038/nature12345", + "publication_year": 2020, + "cited_by_count": 100, + "abstract_inverted_index": None, + "authorships": [], + "primary_location": None, + "open_access": {"is_oa": False}, + } + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_work + + result = await client.get_by_doi("https://doi.org/10.1038/nature12345") + + assert result is not None + mock_request.assert_called_once() + + @pytest.mark.asyncio + async def test_get_by_doi_not_found(self, client: OpenAlexClient) -> None: + """Test retrieval by DOI when work not found.""" + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.side_effect = Exception("Not found") + + result = await client.get_by_doi("10.1038/invalid") + + assert result is None + + +class TestOpenAlexClientGetByID: + """Test OpenAlex get_by_id functionality.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + @pytest.mark.asyncio + async def test_get_by_id_success(self, client: OpenAlexClient) -> None: + """Test successful retrieval by OpenAlex ID.""" + mock_work = { + "id": "https://openalex.org/W2741809807", + "title": "Test Paper", + "doi": None, + "publication_year": 2020, + "cited_by_count": 50, + "abstract_inverted_index": None, + "authorships": [], + "primary_location": None, + "open_access": {"is_oa": False}, + } + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_work + + result = await client.get_by_id("W2741809807") + + assert result is not None + assert result.title == "Test Paper" + assert result.id == "W2741809807" + + @pytest.mark.asyncio + async def test_get_by_id_with_w_prefix(self, client: OpenAlexClient) -> None: + """Test retrieval by OpenAlex ID with W prefix.""" + mock_work = { + "id": "https://openalex.org/W2741809807", + "title": "Test Paper", + "doi": None, + "publication_year": 2020, + "cited_by_count": 50, + "abstract_inverted_index": None, + "authorships": [], + "primary_location": None, + "open_access": {"is_oa": False}, + } + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_work + + result = await client.get_by_id("W2741809807") + + assert result is not None + + +>>>>>>> 1794485dfcc48ad6b089f2b31eb788a021eadea5:service/tests/unit/test_utils/test_openalex_client.py +class TestOpenAlexClientPrivateMethods: + """Test OpenAlex client private methods.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + def test_build_query_params_basic(self, client: OpenAlexClient) -> None: + """Test building basic query parameters.""" + request = SearchRequest(query="machine learning", max_results=50) + params = client._build_query_params(request, None, None, None) + + assert params["search"] == "machine learning" + assert params["per-page"] == "200" + assert params["mailto"] == "test@example.com" + + def test_build_query_params_with_filters(self, client: OpenAlexClient) -> None: + """Test building query parameters with filters.""" + request = SearchRequest( + query="machine learning", + year_from=2015, + year_to=2021, + is_oa=True, + work_type="journal-article", + ) + params = client._build_query_params(request, None, None, None) + 
+ assert "filter" in params + assert "publication_year:2015-2021" in params["filter"] + assert "is_oa:true" in params["filter"] + assert "type:journal-article" in params["filter"] + + def test_build_query_params_with_resolved_ids(self, client: OpenAlexClient) -> None: + """Test building query parameters with resolved author/institution/source IDs.""" + request = SearchRequest(query="test") + params = client._build_query_params(request, "A123", "I456", "S789") + + assert "filter" in params + assert "authorships.author.id:A123" in params["filter"] + assert "authorships.institutions.id:I456" in params["filter"] + assert "primary_location.source.id:S789" in params["filter"] + + def test_build_query_params_sorting_by_citations(self, client: OpenAlexClient) -> None: + """Test building query parameters with citation sorting.""" + request = SearchRequest(query="test", sort_by="cited_by_count") + params = client._build_query_params(request, None, None, None) + + assert params.get("sort") == "cited_by_count:desc" + + def test_build_query_params_sorting_by_date(self, client: OpenAlexClient) -> None: + """Test building query parameters with date sorting.""" + request = SearchRequest(query="test", sort_by="publication_date") + params = client._build_query_params(request, None, None, None) + + assert params.get("sort") == "publication_date:desc" + + def test_reconstruct_abstract_normal(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction from inverted index.""" + inverted_index = { + "Machine": [0], + "learning": [1], + "is": [2], + "fundamental": [3], + } + + result = client._reconstruct_abstract(inverted_index) + + assert result == "Machine learning is fundamental" + + def test_reconstruct_abstract_with_duplicates(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction with duplicate words.""" + inverted_index = { + "The": [0, 5], + "quick": [1], + "brown": [2], + "fox": [3], + "jumps": [4], + } + + result = client._reconstruct_abstract(inverted_index) + + assert result == "The quick brown fox jumps The" + + def test_reconstruct_abstract_none(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction returns None for empty input.""" + result = client._reconstruct_abstract(None) + + assert result is None + + def test_reconstruct_abstract_empty(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction returns None for empty dict.""" + result = client._reconstruct_abstract({}) + + assert result is None + + def test_transform_work_complete(self, client: OpenAlexClient) -> None: + """Test transforming complete OpenAlex work object.""" + work_data = { + "id": "https://openalex.org/W2741809807", + "title": "Machine Learning Fundamentals", + "doi": "https://doi.org/10.1038/nature12345", + "publication_year": 2020, + "cited_by_count": 150, + "abstract_inverted_index": {"Machine": [0], "learning": [1]}, + "authorships": [ + { + "author": { + "id": "https://openalex.org/A5023888391", + "display_name": "Jane Smith", + } + }, + { + "author": { + "id": "https://openalex.org/A5023888392", + "display_name": "John Doe", + } + }, + ], + "primary_location": { + "source": { + "id": "https://openalex.org/S137773608", + "display_name": "Nature", + } + }, + "open_access": { + "is_oa": True, + "oa_url": "https://example.com/paper.pdf", + }, + } + + result = client._transform_work(work_data) + + assert result.id == "W2741809807" + assert result.title == "Machine Learning Fundamentals" + assert result.doi == "10.1038/nature12345" + assert result.publication_year 
== 2020 + assert result.cited_by_count == 150 + assert len(result.authors) == 2 + assert result.authors[0]["name"] == "Jane Smith" + assert result.journal == "Nature" + assert result.is_oa is True + assert result.oa_url == "https://example.com/paper.pdf" + assert result.source == "openalex" + + def test_transform_work_minimal(self, client: OpenAlexClient) -> None: + """Test transforming minimal OpenAlex work object.""" + work_data = { + "id": "https://openalex.org/W123", + "title": "Minimal Paper", + "authorships": [], + } + + result = client._transform_work(work_data) + + assert result.id == "W123" + assert result.title == "Minimal Paper" + assert result.doi is None + assert result.authors == [] + assert result.journal is None + assert result.is_oa is False + + +class TestOpenAlexClientRequestWithRetry: + """Test OpenAlex client request retry logic.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + @pytest.mark.asyncio + async def test_request_with_retry_success(self, client: OpenAlexClient) -> None: + """Test successful request without retry.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.return_value = mock_response + + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + mock_get.assert_called_once() + + @pytest.mark.asyncio + async def test_request_with_retry_timeout(self, client: OpenAlexClient) -> None: + """Test request retry on timeout.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + # First call timeout, second call success + mock_get.side_effect = [httpx.TimeoutException("timeout"), mock_response] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def test_request_with_retry_rate_limit(self, client: OpenAlexClient) -> None: + """Test request retry on rate limit (403).""" + mock_response_403 = MagicMock() + mock_response_403.status_code = 403 + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + mock_response_200.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.side_effect = [mock_response_403, mock_response_200] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def test_request_with_retry_server_error(self, client: OpenAlexClient) -> None: + """Test request retry on server error (5xx).""" + mock_response_500 = MagicMock() + mock_response_500.status_code = 500 + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + mock_response_200.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.side_effect = [mock_response_500, mock_response_200] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", 
{}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + +class TestOpenAlexClientContextManager: + """Test OpenAlex client context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_enter_exit(self) -> None: + """Test client works as async context manager.""" + async with OpenAlexClient(email="test@example.com") as client: + assert client is not None + assert client.email == "test@example.com" + + @pytest.mark.asyncio + async def test_close_method(self) -> None: + """Test client close method.""" + client = OpenAlexClient(email="test@example.com") + with patch.object(client.client, "aclose", new_callable=AsyncMock) as mock_close: + await client.close() + mock_close.assert_called_once() diff --git a/service/tests/unit/test_literature/test_work_distributor.py b/service/tests/unit/test_literature/test_work_distributor.py new file mode 100644 index 00000000..0c8aa2fb --- /dev/null +++ b/service/tests/unit/test_literature/test_work_distributor.py @@ -0,0 +1,428 @@ +"""Tests for work distributor.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.utils.literature.models import LiteratureWork, SearchRequest +from app.utils.literature.work_distributor import WorkDistributor + + +class TestWorkDistributorInit: + """Test WorkDistributor initialization.""" + + def test_init_with_openalex_email(self) -> None: + """Test initialization with OpenAlex email.""" + distributor = WorkDistributor(openalex_email="test@example.com") + + assert distributor.openalex_email == "test@example.com" + # OpenAlex client should be registered (polite pool) + assert "openalex" in distributor.clients + + def test_init_without_openalex_email(self) -> None: + """Test initialization without OpenAlex email.""" + distributor = WorkDistributor() + + assert distributor.openalex_email is None + # OpenAlex client should still be registered (default pool) + assert "openalex" in distributor.clients + + def test_init_with_import_error(self) -> None: + """Test initialization when OpenAlex client import fails.""" + # This test would require mocking the import, which is complex + # Instead, just verify initialization works without email + distributor = WorkDistributor() + + assert distributor.openalex_email is None + assert "openalex" in distributor.clients + + +class TestWorkDistributorSearch: + """Test WorkDistributor search functionality.""" + + @pytest.fixture + def sample_work(self) -> LiteratureWork: + """Create a sample literature work.""" + return LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[{"name": "John Doe", "id": "A1"}], + publication_year=2020, + cited_by_count=100, + abstract="Test abstract", + journal="Nature", + is_oa=True, + oa_url="https://example.com/paper.pdf", + source="openalex", + ) + + @pytest.fixture + def mock_openalex_client(self, sample_work: LiteratureWork) -> MagicMock: + """Create a mock OpenAlex client.""" + client = AsyncMock() + client.search = AsyncMock(return_value=([sample_work], ["✓ Search completed"])) + return client + + @pytest.mark.asyncio + async def test_search_basic(self, sample_work: LiteratureWork, mock_openalex_client: MagicMock) -> None: + """Test basic search with default source.""" + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_openalex_client} + distributor.openalex_email = "test@example.com" + + result = await distributor.search(request) + + assert 
result["total_count"] == 1 + assert result["unique_count"] == 1 + assert "openalex" in result["sources"] + assert len(result["works"]) == 1 + assert result["works"][0].id == "W1" + + @pytest.mark.asyncio + async def test_search_multiple_sources(self, sample_work: LiteratureWork) -> None: + """Test search with multiple data sources.""" + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature67890", + title="Another Paper", + authors=[], + publication_year=2021, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="semantic_scholar", + ) + + mock_client1 = AsyncMock() + mock_client1.search = AsyncMock(return_value=([sample_work], [])) + + mock_client2 = AsyncMock() + mock_client2.search = AsyncMock(return_value=([work2], [])) + + request = SearchRequest(query="test", max_results=50, data_sources=["openalex", "semantic_scholar"]) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client1, "semantic_scholar": mock_client2} + + result = await distributor.search(request) + + assert result["total_count"] == 2 + assert result["unique_count"] == 2 + assert "openalex" in result["sources"] + assert "semantic_scholar" in result["sources"] + + @pytest.mark.asyncio + async def test_search_deduplication(self) -> None: + """Test search deduplicates results by DOI.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="other", + ) + + mock_client = AsyncMock() + mock_client.search = AsyncMock(return_value=([work1, work2], [])) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert result["total_count"] == 2 + assert result["unique_count"] == 1 # Deduplicated + assert result["works"][0].id == "W1" # Higher citation count + + @pytest.mark.asyncio + async def test_search_with_client_error(self, sample_work: LiteratureWork) -> None: + """Test search handles client errors gracefully.""" + mock_client = AsyncMock() + mock_client.search = AsyncMock(side_effect=Exception("API Error")) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert result["total_count"] == 0 + assert result["unique_count"] == 0 + assert result["sources"]["openalex"] == 0 + assert any("Error" in w for w in result["warnings"]) + + @pytest.mark.asyncio + async def test_search_unavailable_source(self) -> None: + """Test search with unavailable data source.""" + request = SearchRequest(query="test", max_results=50, data_sources=["unavailable_source"]) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert result["total_count"] == 0 + assert result["unique_count"] == 0 + assert result["works"] == [] + + @pytest.mark.asyncio + async def test_search_max_results_clamping_low(self) -> None: + """Test search clamps max_results to minimum.""" + request 
= SearchRequest(query="test", max_results=0) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert any("max_results < 1" in w for w in result["warnings"]) + assert request.max_results == 50 + + @pytest.mark.asyncio + async def test_search_max_results_clamping_high(self) -> None: + """Test search clamps max_results to maximum.""" + request = SearchRequest(query="test", max_results=5000) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert any("max_results > 1000" in w for w in result["warnings"]) + assert request.max_results == 1000 + + @pytest.mark.asyncio + async def test_search_result_limiting(self) -> None: + """Test search limits results to max_results.""" + works = [ + LiteratureWork( + id=f"W{i}", + doi=f"10.1038/paper{i}", + title=f"Paper {i}", + authors=[], + publication_year=2020, + cited_by_count=100 - i, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + for i in range(20) + ] + + mock_client = AsyncMock() + mock_client.search = AsyncMock(return_value=(works, [])) + + request = SearchRequest(query="test", max_results=10) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert len(result["works"]) == 10 + + @pytest.mark.asyncio + async def test_search_with_warnings(self) -> None: + """Test search collects warnings from clients.""" + work = LiteratureWork( + id="W1", + doi=None, + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ) + + mock_client = AsyncMock() + mock_client.search = AsyncMock( + return_value=( + [work], + ["⚠️ Author not found", "✓ Search completed"], + ) + ) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert "⚠️ Author not found" in result["warnings"] + assert "✓ Search completed" in result["warnings"] + + +class TestWorkDistributorSorting: + """Test WorkDistributor sorting functionality.""" + + @pytest.fixture + def sample_works(self) -> list[LiteratureWork]: + """Create sample works for sorting tests.""" + return [ + LiteratureWork( + id="W1", + doi=None, + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ), + LiteratureWork( + id="W2", + doi=None, + title="Paper 2", + authors=[], + publication_year=2021, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ), + LiteratureWork( + id="W3", + doi=None, + title="Paper 3", + authors=[], + publication_year=2019, + cited_by_count=75, + abstract=None, + journal=None, + is_oa=False, + oa_url=None, + source="openalex", + ), + ] + + def test_sort_by_relevance(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by relevance (default, maintains order).""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "relevance") + + # Should maintain original order for relevance + assert result[0].id == "W1" + assert result[1].id == "W2" + assert result[2].id == "W3" + + def 
test_sort_by_cited_by_count(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by citation count.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "cited_by_count") + + assert result[0].id == "W2" # 100 citations + assert result[1].id == "W3" # 75 citations + assert result[2].id == "W1" # 50 citations + + def test_sort_by_publication_date(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by publication date.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "publication_date") + + assert result[0].id == "W2" # 2021 + assert result[1].id == "W1" # 2020 + assert result[2].id == "W3" # 2019 + + def test_sort_with_missing_year(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by publication date with missing years.""" + sample_works[1].publication_year = None + + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "publication_date") + + # Works with missing year should go to the end + assert result[0].id == "W1" # 2020 + assert result[1].id == "W3" # 2019 + assert result[2].publication_year is None + + +class TestWorkDistributorContextManager: + """Test WorkDistributor context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_enter_exit(self) -> None: + """Test context manager functionality.""" + async with WorkDistributor(openalex_email="test@example.com") as distributor: + assert distributor is not None + + @pytest.mark.asyncio + async def test_close_method(self) -> None: + """Test close method.""" + distributor = WorkDistributor(openalex_email="test@example.com") + + # Replace the actual client with a mock + mock_client = MagicMock() + mock_client.close = AsyncMock() + distributor.clients["openalex"] = mock_client + + await distributor.close() + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_sync_close(self) -> None: + """Test close method with synchronous close.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + mock_client = MagicMock() + # Synchronous close (returns None, not awaitable) + mock_client.close = MagicMock(return_value=None) + distributor.clients = {"openalex": mock_client} + + await distributor.close() + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_no_close_method(self) -> None: + """Test close method with client that has no close method.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + mock_client = MagicMock(spec=[]) # No close method + distributor.clients = {"openalex": mock_client} + + # Should not raise an error + await distributor.close() diff --git a/service/tests/unit/test_utils/test_built_in_tools.py b/service/tests/unit/test_utils/test_built_in_tools.py deleted file mode 100644 index 61ca5ebb..00000000 --- a/service/tests/unit/test_utils/test_built_in_tools.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Tests for built-in tools utilities.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastmcp import FastMCP - -from app.mcp.builtin_tools import register_built_in_tools - - -class TestBuiltInTools: - """Test built-in tools registration and functionality.""" - - @pytest.fixture - def mock_mcp(self): - """Create a mock FastMCP instance.""" - mcp = MagicMock(spec=FastMCP) - mcp.tool = MagicMock() - mcp.resource = MagicMock() - return mcp - - def 
test_register_built_in_tools(self, mock_mcp: MagicMock) -> None: - """Test that built-in tools are registered properly.""" - register_built_in_tools(mock_mcp) - - # Verify that the decorators were called (tools were registered) - assert mock_mcp.tool.call_count >= 4 # We have at least 4 tools - assert mock_mcp.resource.call_count >= 1 # We have at least 1 resource - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_success(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search tool with successful response.""" - # Mock response data - mock_response_data = { - "items": [ - { - "full_name": "test/repo1", - "html_url": "https://github.com/test/repo1", - "description": "Test repository 1", - "stargazers_count": 100, - "forks_count": 20, - "language": "Python", - "updated_at": "2024-01-01T00:00:00Z", - "topics": ["test", "demo"], - }, - { - "full_name": "test/repo2", - "html_url": "https://github.com/test/repo2", - "description": "Test repository 2", - "stargazers_count": 50, - "forks_count": 10, - "language": "Python", - "updated_at": "2024-01-02T00:00:00Z", - "topics": [], - }, - ] - } - - # Mock the context manager and JSON loading - mock_response = MagicMock() - mock_response.__enter__ = MagicMock(return_value=mock_response) - mock_response.__exit__ = MagicMock(return_value=None) - mock_urlopen.return_value = mock_response - - with patch("app.mcp.builtin_tools.json.load") as mock_json_load: - mock_json_load.return_value = mock_response_data - - register_built_in_tools(mock_mcp) - - # Get the search_github function from the registered tools - # Since we can't easily extract it, we'll test the logic directly - # by calling the function that would be registered - - # For this test, we'll verify the mock was set up correctly - assert mock_mcp.tool.called - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_empty_query(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search with empty query.""" - register_built_in_tools(mock_mcp) - - # The actual test would need access to the registered function - # For now, we verify the registration happened - assert mock_mcp.tool.called - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_api_error(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search with API error.""" - # Mock URL open to raise an exception - mock_urlopen.side_effect = Exception("API Error") - - register_built_in_tools(mock_mcp) - - # Verify registration still happened despite the error not occurring yet - assert mock_mcp.tool.called - - def test_search_github_parameters(self, mock_mcp: MagicMock) -> None: - """Test GitHub search with different parameters.""" - register_built_in_tools(mock_mcp) - - # Verify the tool was registered with proper signature - assert mock_mcp.tool.called - - # The actual function would accept parameters like query, max_results, sort_by - # Since we can't easily test the registered function directly, - # we verify the registration process - - async def test_llm_web_search_no_auth(self, mock_mcp: MagicMock) -> None: - """Test LLM web search without authentication.""" - with patch("app.mcp.builtin_tools.get_access_token") as mock_get_token: - mock_get_token.return_value = None - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_llm_web_search_with_auth(self, mock_mcp: MagicMock) -> None: - """Test LLM web search with authentication.""" 
- with ( - patch("fastmcp.server.dependencies.get_access_token") as mock_get_token, - patch("app.middleware.auth.AuthProvider") as mock_auth_provider, - patch("app.core.providers.get_user_provider_manager") as mock_get_manager, - patch("app.infra.database.connection.AsyncSessionLocal") as mock_session, - ): - # Mock authentication - mock_token = MagicMock() - mock_token.claims = {"user_id": "test-user"} - mock_get_token.return_value = mock_token - - mock_user_info = MagicMock() - mock_user_info.id = "test-user" - mock_auth_provider.parse_user_info.return_value = mock_user_info - - # Mock database session - mock_db = AsyncMock() - mock_session.return_value.__aenter__.return_value = mock_db - - # Mock provider manager - mock_provider_manager = AsyncMock() - mock_get_manager.return_value = mock_provider_manager - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_refresh_tools_success(self, mock_mcp: MagicMock) -> None: - """Test refresh tools functionality.""" - with ( - patch("app.mcp.builtin_tools.get_access_token") as mock_get_token, - patch("app.mcp.builtin_tools.AuthProvider") as mock_auth_provider, - patch("app.mcp.builtin_tools.tool_loader") as mock_tool_loader, - ): - # Mock authentication - mock_token = MagicMock() - mock_token.claims = {"user_id": "test-user"} - mock_get_token.return_value = mock_token - - mock_user_info = MagicMock() - mock_user_info.id = "test-user" - mock_auth_provider.parse_user_info.return_value = mock_user_info - - # Mock tool loader - mock_tool_loader.refresh_tools.return_value = { - "added": ["tool1", "tool2"], - "removed": ["old_tool"], - "updated": ["updated_tool"], - } - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_refresh_tools_no_auth(self, mock_mcp: MagicMock) -> None: - """Test refresh tools without authentication.""" - with patch("app.mcp.builtin_tools.get_access_token") as mock_get_token: - mock_get_token.return_value = None - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - def test_get_server_status(self, mock_mcp: MagicMock) -> None: - """Test get server status tool.""" - with patch("app.mcp.builtin_tools.tool_loader") as mock_tool_loader: - mock_proxy_manager = MagicMock() - mock_proxy_manager.list_proxies.return_value = ["proxy1", "proxy2"] - mock_tool_loader.proxy_manager = mock_proxy_manager - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - @pytest.mark.parametrize("sort_by", ["stars", "forks", "updated"]) - def test_search_github_sort_options(self, mock_mcp: MagicMock, sort_by: str) -> None: - """Test GitHub search with different sort options.""" - register_built_in_tools(mock_mcp) - - # Verify the tool registration happened - assert mock_mcp.tool.called - - def test_tools_registration_count(self, mock_mcp: MagicMock) -> None: - """Test that the expected number of tools are registered.""" - register_built_in_tools(mock_mcp) - - # We expect at least these tools: - # - search_github - # - llm_web_search - # - refresh_tools - # - get_server_status - expected_min_tools = 4 - - assert mock_mcp.tool.call_count >= expected_min_tools - - def test_resource_registration_count(self, mock_mcp: MagicMock) -> None: - """Test that the expected number of resources are registered.""" - register_built_in_tools(mock_mcp) - - # We expect at least these resources: - # - config://server - 
expected_min_resources = 1 - - assert mock_mcp.resource.call_count >= expected_min_resources
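
The doi_cleaner tests above pin down behaviour that the diff itself does not show. A minimal sketch satisfying the assertions in TestNormalizeDOI and TestDeduplicateByDOI follows; it is a reconstruction for reading purposes, not the actual app.utils.literature.doi_cleaner implementation, and the prefix list, regex bounds, and citation/year tie-break are assumptions inferred from the tests.

```python
# Sketch only: behaviour reverse-engineered from TestNormalizeDOI / TestDeduplicateByDOI.
import re

from app.utils.literature.models import LiteratureWork

_DOI_RE = re.compile(r"^10\.\d{4,9}/\S+$")
_PREFIXES = (
    "https://doi.org/",
    "http://doi.org/",
    "https://dx.doi.org/",
    "http://dx.doi.org/",
    "doi:",
)


def normalize_doi(doi: str | None) -> str | None:
    """Strip doi.org/dx.doi.org/'doi:' prefixes, lowercase, and validate '10.x/y'."""
    if not doi or not doi.strip():
        return None
    cleaned = doi.strip().lower()
    for prefix in _PREFIXES:
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):].strip()
            break
    return cleaned if _DOI_RE.match(cleaned) else None


def deduplicate_by_doi(works: list[LiteratureWork]) -> list[LiteratureWork]:
    """Collapse works sharing a normalized DOI; keep higher citations, then newer year."""
    best: dict[str, LiteratureWork] = {}
    without_doi: list[LiteratureWork] = []
    for work in works:
        key = normalize_doi(work.doi)
        if key is None:
            without_doi.append(work)  # missing or invalid DOI: never deduplicated
            continue
        kept = best.get(key)
        if kept is None or (
            (work.cited_by_count or 0, work.publication_year or 0)
            > (kept.cited_by_count or 0, kept.publication_year or 0)
        ):
            best[key] = work
    # DOI-bearing works first (insertion order), then works without a usable DOI.
    return list(best.values()) + without_doi
```

The _reconstruct_abstract tests likewise imply rebuilding plain text from OpenAlex's abstract_inverted_index mapping of word to positions; one plausible reading, again an assumed standalone sketch rather than the client's actual private helper:

```python
def reconstruct_abstract(inverted_index: dict[str, list[int]] | None) -> str | None:
    """Rebuild abstract text by emitting each word at every recorded position."""
    if not inverted_index:
        return None
    placed = [(pos, word) for word, positions in inverted_index.items() for pos in positions]
    return " ".join(word for _, word in sorted(placed))
```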