|
| 1 | +# Copyright (c) 2023-2024 Datalayer, Inc. |
| 2 | +# |
| 3 | +# BSD 3-Clause License |
| 4 | + |
| 5 | +"""cite from a notebook.""" |
| 6 | + |
| 7 | +from typing import Any, Optional |
| 8 | +from mcp.server.fastmcp.prompts.base import UserMessage |
| 9 | +from jupyter_server_client import JupyterServerClient |
| 10 | +from jupyter_mcp_server.tools._base import BaseTool, ServerMode |
| 11 | +from jupyter_mcp_server.notebook_manager import NotebookManager |
| 12 | +from jupyter_mcp_server.models import Notebook |
| 13 | + |
| 14 | + |
| 15 | +class JupyterCitePrompt(BaseTool): |
| 16 | + """Tool to cite specific cells from specified notebook.""" |
| 17 | + |
| 18 | + def _parse_cell_indices(self, cell_indices_str: str, max_cells: int) -> list[int]: |
| 19 | + """ |
| 20 | + Parse cell indices from a string with flexible format. |
| 21 | + |
| 22 | + Supports formats like: |
| 23 | + - '0,1,2' for individual indices |
| 24 | + - '0-2' for ranges |
| 25 | + - '0-2,4' for mixed format |
| 26 | + - '3-' for from index 3 to end |
| 27 | + |
| 28 | + Args: |
| 29 | + cell_indices_str: String with cell indices |
| 30 | + max_cells: Maximum number of cells in the notebook |
| 31 | + |
| 32 | + Returns: |
| 33 | + List of integer cell indices |
| 34 | + |
| 35 | + Raises: |
| 36 | + ValueError: If indices are invalid or out of range |
| 37 | + """ |
| 38 | + if not cell_indices_str or not cell_indices_str.strip(): |
| 39 | + raise ValueError("Cell indices cannot be empty") |
| 40 | + |
| 41 | + # Check if notebook is empty |
| 42 | + if max_cells <= 0: |
| 43 | + raise ValueError("Notebook has no cells") |
| 44 | + |
| 45 | + result = set() |
| 46 | + parts = cell_indices_str.split(',') |
| 47 | + |
| 48 | + for part in parts: |
| 49 | + part = part.strip() |
| 50 | + if not part: |
| 51 | + continue |
| 52 | + |
| 53 | + if '-' in part: |
| 54 | + # Handle range format |
| 55 | + range_parts = part.split('-', 1) |
| 56 | + |
| 57 | + if len(range_parts) == 2: |
| 58 | + start_str, end_str = range_parts |
| 59 | + |
| 60 | + if not start_str: |
| 61 | + raise ValueError(f"Invalid range format: {part}") |
| 62 | + |
| 63 | + try: |
| 64 | + start = int(start_str) |
| 65 | + except ValueError: |
| 66 | + raise ValueError(f"Invalid start index: {start_str}") |
| 67 | + |
| 68 | + if start < 0: |
| 69 | + raise ValueError(f"Start index cannot be negative: {start}") |
| 70 | + |
| 71 | + if not end_str: |
| 72 | + # Case: '3-' means from 3 to end |
| 73 | + end = max_cells - 1 |
| 74 | + # Check if start is within range |
| 75 | + if start >= max_cells: |
| 76 | + raise ValueError(f"Cell index {start} is out of range. Notebook has {max_cells} cells.") |
| 77 | + else: |
| 78 | + try: |
| 79 | + end = int(end_str) |
| 80 | + except ValueError: |
| 81 | + raise ValueError(f"Invalid end index: {end_str}") |
| 82 | + |
| 83 | + if end < start: |
| 84 | + raise ValueError(f"End index ({end}) must be greater than or equal to start index ({start})") |
| 85 | + else: |
| 86 | + raise ValueError(f"Invalid range format: {part}") |
| 87 | + |
| 88 | + # Add all indices in the range |
| 89 | + for i in range(start, end + 1): |
| 90 | + if i >= max_cells: |
| 91 | + raise ValueError(f"Cell index {i} is out of range. Notebook has {max_cells} cells.") |
| 92 | + result.add(i) |
| 93 | + else: |
| 94 | + # Handle single index |
| 95 | + try: |
| 96 | + index = int(part) |
| 97 | + except ValueError: |
| 98 | + raise ValueError(f"Invalid cell index: {part}") |
| 99 | + |
| 100 | + if index < 0: |
| 101 | + raise ValueError(f"Cell index cannot be negative: {index}") |
| 102 | + if index >= max_cells: |
| 103 | + raise ValueError(f"Cell index {index} is out of range. Notebook has {max_cells} cells.") |
| 104 | + |
| 105 | + result.add(index) |
| 106 | + |
| 107 | + # Convert to sorted list |
| 108 | + return sorted(result) |
| 109 | + |
| 110 | + async def execute( |
| 111 | + self, |
| 112 | + mode: ServerMode, |
| 113 | + server_client: Optional[JupyterServerClient] = None, |
| 114 | + contents_manager: Optional[Any] = None, |
| 115 | + notebook_manager: Optional[NotebookManager] = None, |
| 116 | + cell_indices: Optional[str] = None, |
| 117 | + notebook_name: Optional[str] = None, |
| 118 | + prompt: Optional[str] = None, |
| 119 | + **kwargs |
| 120 | + ) -> str: |
| 121 | + """Execute the read_notebook tool. |
| 122 | + |
| 123 | + Args: |
| 124 | + mode: Server mode (MCP_SERVER or JUPYTER_SERVER) |
| 125 | + contents_manager: Direct API access for JUPYTER_SERVER mode |
| 126 | + notebook_manager: Notebook manager instance |
| 127 | + notebook_name: Notebook identifier to read |
| 128 | + response_format: Response format (brief or detailed) |
| 129 | + start_index: Starting index for pagination (0-based) |
| 130 | + limit: Maximum number of items to return (0 means no limit) |
| 131 | + **kwargs: Additional parameters |
| 132 | + |
| 133 | + Returns: |
| 134 | + Formatted table with cell information |
| 135 | + """ |
| 136 | + if notebook_name == "": |
| 137 | + notebook_name = notebook_manager._current_notebook |
| 138 | + if notebook_name not in notebook_manager: |
| 139 | + raise ValueError(f"Notebook '{notebook_name}' is not connected. All currently connected notebooks: {list(notebook_manager.list_all_notebooks().keys())}") |
| 140 | + |
| 141 | + if mode == ServerMode.JUPYTER_SERVER and contents_manager is not None: |
| 142 | + # Local mode: read notebook directly from file system |
| 143 | + notebook_path = notebook_manager.get_notebook_path(notebook_name) |
| 144 | + |
| 145 | + model = await contents_manager.get(notebook_path, content=True, type='notebook') |
| 146 | + if 'content' not in model: |
| 147 | + raise ValueError(f"Could not read notebook content from {notebook_path}") |
| 148 | + notebook = Notebook(**model['content']) |
| 149 | + elif mode == ServerMode.MCP_SERVER and notebook_manager is not None: |
| 150 | + # Remote mode: use WebSocket connection to Y.js document |
| 151 | + async with notebook_manager.get_notebook_connection(notebook_name) as notebook_content: |
| 152 | + notebook = Notebook(**notebook_content.as_dict()) |
| 153 | + else: |
| 154 | + raise ValueError(f"Invalid mode or missing required clients: mode={mode}") |
| 155 | + |
| 156 | + # Parse cell indices with flexible format |
| 157 | + parsed_indices = self._parse_cell_indices(cell_indices, len(notebook)) |
| 158 | + |
| 159 | + prompt_list = [f"USER Cite cells {parsed_indices} from notebook {notebook_name}, here are the cells:"] |
| 160 | + for cell_index in parsed_indices: |
| 161 | + cell = notebook.cells[cell_index] |
| 162 | + prompt_list.append(f"=====Cell {cell_index} | type: {cell.cell_type} | execution count: {cell.execution_count if cell.execution_count else 'N/A'}=====") |
| 163 | + prompt_list.append(cell.get_source('readable')) |
| 164 | + |
| 165 | + prompt_list.append("=====End of Cited Cells=====") |
| 166 | + prompt_list.append(f"USER's Instruction are follow: {prompt}") |
| 167 | + |
| 168 | + return [UserMessage(content="\n".join(prompt_list))] |
| 169 | + |
| 170 | + |
| 171 | + |
0 commit comments