"""
Custom MCP client using direct subprocess communication.
This bypasses the buggy stdio_client from mcp.client.stdio.
"""

import asyncio
import json
import logging
import queue
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class MCPClient:
    """Custom MCP client using direct subprocess communication.

    The server is launched as ``sys.executable <server_script>`` and spoken to
    with newline-delimited JSON-RPC 2.0 messages over its stdin/stdout.

    Args:
        server_script: Path of the Python script that implements the server.
        server_name: Human-readable name used in log and error messages.
        request_timeout: Seconds to wait for the response to one request.
            Defaults to 60 s, an extended timeout for LlamaIndex/RAG server
            initialization.
    """

    def __init__(self, server_script: str, server_name: str,
                 request_timeout: float = 60.0):
        self.server_script = server_script
        self.server_name = server_name
        self.request_timeout = request_timeout
        self.process: Optional[subprocess.Popen] = None
        self.message_id = 0
        self._initialized = False
        self.script_path = server_script  # Store for potential restart
        # Lines read from the server's stdout by the reader thread; ``None``
        # marks end-of-file. Replaced with a fresh queue on every start().
        self._stdout_lines: queue.Queue = queue.Queue()
        # Created in start() so it belongs to the event loop that uses it.
        self._request_lock: Optional[asyncio.Lock] = None

    async def start(self):
        """Start the MCP server subprocess and initialize the session.

        Raises:
            OSError: If the subprocess cannot be launched.
            Exception: If the initialize handshake fails; the subprocess is
                terminated before the error propagates.
        """
        if self.process is not None:
            # Starting twice would orphan the previous child process.
            await self.close()

        logger.info(f"Starting MCP server: {self.server_name}")

        process = subprocess.Popen(
            [sys.executable, self.server_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            encoding="utf-8",  # MCP stdio messages are UTF-8, not locale-encoded
            errors="replace",
            bufsize=1,  # Line-buffered I/O to prevent 8KB truncation
        )
        self.process = process
        self._stdout_lines = queue.Queue()
        self._request_lock = asyncio.Lock()

        # One dedicated reader per pipe. stdout lines are queued for
        # _read_line(); stderr is drained so the server can never block on a
        # full stderr pipe.
        threading.Thread(
            target=self._pump_stdout,
            args=(process.stdout, self._stdout_lines),
            name=f"mcp-{self.server_name}-stdout",
            daemon=True,
        ).start()
        threading.Thread(
            target=self._pump_stderr,
            args=(process.stderr, self.server_name),
            name=f"mcp-{self.server_name}-stderr",
            daemon=True,
        ).start()

        # Initialize the session
        try:
            await self._initialize()
        except BaseException:
            # Includes cancellation: never leave a half-started server behind.
            await self.close()
            raise
        logger.info(f"Successfully started MCP server: {self.server_name}")

    @staticmethod
    def _pump_stdout(stream, sink: queue.Queue) -> None:
        """Copy every line of *stream* into *sink*, then a ``None`` EOF marker."""
        try:
            for line in iter(stream.readline, ""):
                sink.put(line)
        except (OSError, ValueError) as exc:
            logger.debug("stdout reader stopped: %s", exc)
        finally:
            sink.put(None)
            try:
                stream.close()
            except OSError:
                pass

    @staticmethod
    def _pump_stderr(stream, server_name: str) -> None:
        """Drain *stream*, forwarding each line to the debug log."""
        try:
            for line in iter(stream.readline, ""):
                logger.debug("[%s stderr] %s", server_name, line.rstrip())
        except (OSError, ValueError) as exc:
            logger.debug("stderr reader stopped: %s", exc)
        finally:
            try:
                stream.close()
            except OSError:
                pass

    async def _initialize(self):
        """Perform the MCP initialize handshake.

        Raises:
            Exception: If the server does not answer with a ``result``.
        """
        init_message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "als-research-agent",
                    "version": "1.0.0",
                },
            },
        }

        response = await self._send_request(init_message)
        if "result" not in response:
            raise RuntimeError(f"Initialization failed: {response}")

        # The MCP lifecycle requires the client to confirm a successful
        # initialize with this notification before normal operation.
        self._write_message(
            {"jsonrpc": "2.0", "method": "notifications/initialized"}
        )

        self._initialized = True
        result = response["result"]
        server_info = result.get("serverInfo", {}) if isinstance(result, dict) else {}
        logger.info(f"Initialized {self.server_name}: {server_info}")

    def _next_id(self) -> int:
        """Get next message ID"""
        self.message_id += 1
        return self.message_id

    def _write_message(self, message: Dict[str, Any]) -> None:
        """Write one JSON-RPC message (request or notification) to the server.

        Raises:
            RuntimeError: If the server is not running or its stdin is broken.
        """
        process = self.process
        if process is None or process.stdin is None:
            raise RuntimeError("Server not started")
        try:
            process.stdin.write(json.dumps(message) + "\n")
            process.stdin.flush()
        except (OSError, ValueError) as exc:
            raise RuntimeError(
                f"Server {self.server_name} has terminated unexpectedly"
            ) from exc

    async def _read_line(self, timeout: float) -> str:
        """Return the next stdout line, waiting at most *timeout* seconds.

        Raises:
            TimeoutError: If no line arrives in time.
            ConnectionError: If the server closed its stdout.
        """
        lines = self._stdout_lines
        try:
            # The timed get() ends on its own, so a timeout never leaves a
            # thread behind that could steal a later response.
            line = await asyncio.to_thread(lines.get, True, max(0.0, timeout))
        except queue.Empty:
            raise TimeoutError("Request timed out") from None
        if line is None:
            lines.put(None)  # keep the EOF marker for subsequent readers
            raise ConnectionError("Server closed stdout")
        return line

    async def _read_response(self, request_id: Any) -> Dict[str, Any]:
        """Read stdout until the response carrying *request_id* arrives.

        Non-JSON lines, notifications, server-initiated requests and
        responses to other (e.g. previously timed-out) requests are skipped.

        Raises:
            TimeoutError: If ``request_timeout`` elapses first.
            ConnectionError: If the server closed its stdout.
        """
        deadline = time.monotonic() + self.request_timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError("Request timed out")
            line = await self._read_line(remaining)
            try:
                reply = json.loads(line)
            except json.JSONDecodeError:
                logger.debug(
                    "Ignoring non-JSON output from %s: %s",
                    self.server_name, line.strip(),
                )
                continue
            if (
                isinstance(reply, dict)
                and "method" not in reply
                and reply.get("id") == request_id
            ):
                return reply
            logger.debug(
                "Ignoring unsolicited message from %s: %s",
                self.server_name, line.strip(),
            )

    async def _send_request(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Send a JSON-RPC request and wait for its response.

        Requests are serialized because they share a single pipe pair.

        Raises:
            RuntimeError: If the server was never started or has terminated.
            TimeoutError: If no matching response arrives in time.
            ConnectionError: If the server closed its stdout.
        """
        lock = self._request_lock
        if not self.process or lock is None:
            raise RuntimeError("Server not started")

        async with lock:
            process = self.process
            if process is None:
                raise RuntimeError("Server not started")
            # Check if process is still alive
            if process.poll() is not None:
                raise RuntimeError(
                    f"Server {self.server_name} has terminated unexpectedly"
                )
            self._write_message(message)
            return await self._read_response(message.get("id"))

    async def list_tools(self) -> List[Dict[str, Any]]:
        """List available tools.

        Raises:
            RuntimeError: If the client has not been initialized.
            Exception: If the server answers without a ``result``.
        """
        if not self._initialized:
            raise RuntimeError("Client not initialized")

        message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "tools/list",
            "params": {},
        }

        response = await self._send_request(message)
        if "result" not in response:
            raise RuntimeError(f"List tools failed: {response}")
        return response["result"].get("tools", [])

    @staticmethod
    def _extract_tool_text(result: Any) -> Any:
        """Reduce a ``tools/call`` result to the value handed to the caller.

        Supports the ``{"result": ...}`` format and the ``{"content": [...]}``
        format (text of the first item); anything else is stringified.
        """
        if not isinstance(result, dict):
            return str(result)
        if "result" in result:
            return result["result"]
        if "content" in result:
            content = result["content"]
            if isinstance(content, list) and content:
                first = content[0]
                if isinstance(first, dict):
                    return first.get("text", str(content))
            return str(content)
        return str(result)

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Call a tool and return its result.

        Raises:
            RuntimeError: If the client has not been initialized.
            Exception: If the server answers with an error.
        """
        if not self._initialized:
            raise RuntimeError("Client not initialized")

        message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments,
            },
        }

        response = await self._send_request(message)
        if "result" in response:
            return self._extract_tool_text(response["result"])

        error = response.get("error")
        detail = error.get("message", response) if isinstance(error, dict) else response
        raise RuntimeError(f"Tool call failed: {detail}")

    async def close(self):
        """Close the MCP client and terminate the server. Safe to call twice."""
        process = self.process
        if process is None:
            return

        logger.info(f"Closing MCP server: {self.server_name}")
        self.process = None
        self._initialized = False

        # Closing stdin first lets a well-behaved server see EOF.
        try:
            if process.stdin is not None:
                process.stdin.close()
        except OSError:
            pass

        process.terminate()
        try:
            # Wait in a worker thread so the event loop is not blocked.
            await asyncio.to_thread(process.wait, 5)
        except subprocess.TimeoutExpired:
            process.kill()
            await asyncio.to_thread(process.wait)


class MCPClientManager:
    """Manage multiple MCP clients"""

    def __init__(self):
        self.clients: Dict[str, MCPClient] = {}

    async def add_server(self, name: str, script_path: str):
        """Add and start an MCP server.

        If a client is already registered under *name* it is closed once the
        replacement has started successfully.
        """
        client = MCPClient(script_path, name)
        await client.start()

        replaced = self.clients.get(name)
        self.clients[name] = client
        if replaced is not None:
            await replaced.close()
        logger.info(f"Added MCP server: {name}")

    async def call_tool(self, server_name: str, tool_name: str,
                        arguments: Dict[str, Any]) -> str:
        """Call a tool on a specific server.

        Raises:
            ValueError: If *server_name* is not registered.
        """
        if server_name not in self.clients:
            raise ValueError(f"Server not found: {server_name}")

        return await self.clients[server_name].call_tool(tool_name, arguments)

    @staticmethod
    async def _list_tagged_tools(client: MCPClient, name: str) -> List[Dict[str, Any]]:
        """List *client*'s tools and tag each one with its server name."""
        tools = await client.list_tools()
        for tool in tools:
            tool["server"] = name  # Add server info to each tool
        return tools

    async def list_all_tools(self) -> Dict[str, List[Dict[str, Any]]]:
        """List tools from all servers, handling failures gracefully.

        A server that fails to answer is restarted once. If the restart also
        fails, the server is closed and removed, and its entry stays ``[]``.
        """
        all_tools: Dict[str, List[Dict[str, Any]]] = {}
        failed_servers: List[str] = []

        # Snapshot: the registry may change while we await.
        for name, client in list(self.clients.items()):
            try:
                all_tools[name] = await self._list_tagged_tools(client, name)
            except Exception as e:
                logger.error(f"Failed to list tools from server {name}: {e}")
                failed_servers.append(name)
                # Continue with other servers instead of failing entirely
                all_tools[name] = []

        if failed_servers:
            logger.warning(
                f"Some servers failed to respond: {', '.join(failed_servers)}"
            )

        # Try to restart failed servers
        for server_name in failed_servers:
            try:
                stale = self.clients[server_name]
                script_path = stale.script_path
                if script_path:
                    logger.info(f"Attempting to restart {server_name} server...")
                    await stale.close()

                    # Re-add the server (which will restart it)
                    await self.add_server(server_name, script_path)

                    # Try listing tools again after restart
                    all_tools[server_name] = await self._list_tagged_tools(
                        self.clients[server_name], server_name
                    )
                    logger.info(f"Successfully restarted {server_name} server")
            except Exception as restart_error:
                logger.error(f"Failed to restart {server_name}: {restart_error}")
                # Remove the failed server from clients to prevent further
                # errors, and close it so no subprocess is left running.
                dead = self.clients.pop(server_name, None)
                if dead is not None:
                    try:
                        await dead.close()
                    except Exception:
                        logger.exception(f"Failed to close MCP server: {server_name}")

        return all_tools

    async def close_all(self):
        """Close all MCP clients, continuing past individual failures."""
        for name, client in list(self.clients.items()):
            try:
                await client.close()
            except Exception:
                logger.exception(f"Failed to close MCP server: {name}")

        self.clients.clear()
        logger.info("All MCP servers closed")