#!/usr/bin/env python3
"""
Reusable SSE MCP client tailored for the remote chatbot data tools.
"""

from __future__ import annotations

import asyncio
import json
import os
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import httpx

# Default connection settings for the hosted chatbot MCP server.  Each value is
# read from its environment variable (MCP_URL, MCP_AUTH_KEY, MCP_CHATBOT_KEY)
# when set, and otherwise falls back to the literal below.
#
# SECURITY NOTE(review): the fallback credentials are committed to source
# control.  They should be considered exposed, rotated, and supplied only via
# the environment.
_FALLBACK_MCP_URL = "https://api.recomai.one/mcp"
_FALLBACK_AUTH_KEY = "+FiDOuHBrVo4X0SfW5KrEQ==:y0R6zQZbj7K6wXN5jxgmTg=="
_FALLBACK_CHATBOT_KEY = "c4292acede69fc72e5b06bdea37b2cee45a4ac2190cb8c98e8b2dce6b53f9cee"

DEFAULT_MCP_URL = os.getenv("MCP_URL", _FALLBACK_MCP_URL)
DEFAULT_AUTH_KEY = os.getenv("MCP_AUTH_KEY", _FALLBACK_AUTH_KEY)
DEFAULT_CHATBOT_KEY = os.getenv("MCP_CHATBOT_KEY", _FALLBACK_CHATBOT_KEY)


class MCPClientError(RuntimeError):
    """Single error type raised by :class:`MCPSSEClient`.

    Used for calls made before the client is connected, non-200 HTTP
    responses, malformed response payloads, JSON-RPC ``error`` replies and
    tool output that is not valid JSON when JSON was requested.
    """


class MCPSSEClient:
    """Minimal JSON-RPC-over-HTTP client for MCP servers such as api.recomai.one.

    Every request is a single HTTP POST.  The server may answer with an SSE
    stream (``text/event-stream``) or with a plain JSON body; both are accepted.
    Use it as an async context manager (``async with``) or call
    :meth:`connect` / :meth:`disconnect` explicitly.
    """

    def __init__(
        self,
        base_url: str = DEFAULT_MCP_URL,
        auth_key: str = DEFAULT_AUTH_KEY,
        chatbot_key: Optional[str] = DEFAULT_CHATBOT_KEY,
        timeout: float = 30.0,
    ):
        """Prepare (but do not open) a client.

        Args:
            base_url: MCP endpoint.  Trailing ``/`` characters are removed, and
                it may already carry a query string.
            auth_key: Sent as the ``authKey`` query parameter of every request.
            chatbot_key: Stored on the instance and falls back to ``auth_key``
                when falsy.  It is not read anywhere inside this class.
            timeout: Timeout in seconds passed to ``httpx.AsyncClient``.
        """
        base_url = base_url.rstrip("/")
        separator = "&" if "?" in base_url else "?"
        # Percent-encode the key: base64-style keys contain "+", "/" and "=",
        # and standard query parsers decode an unescaped "+" as a space.
        encoded_key = quote(auth_key, safe="")
        self.url = f"{base_url}{separator}authKey={encoded_key}"
        self.base_url = base_url
        self.auth_key = auth_key
        self.chatbot_key = chatbot_key or auth_key
        self.timeout = timeout

        self.request_id = 0
        self.http_client: Optional[httpx.AsyncClient] = None
        # Session id assigned by the server (Mcp-Session-Id response header), if any.
        self.session_id: Optional[str] = None
        self.server_info: Dict[str, Any] = {}
        self.capabilities: Dict[str, Any] = {}
        self.mcp_tools: List[Dict[str, Any]] = []
        self.initialized = False

    async def __aenter__(self):
        """Connect and return ``self`` for use in ``async with``."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Disconnect on leaving the ``async with`` block; exceptions propagate."""
        await self.disconnect()

    def _next_id(self) -> int:
        """Return a fresh, monotonically increasing JSON-RPC request id."""
        self.request_id += 1
        return self.request_id

    @staticmethod
    def _parse_sse_response(
        text: str, request_id: Optional[int] = None
    ) -> Optional[Dict[str, Any]]:
        """Extract a JSON-RPC message from an SSE response body.

        Events are separated by blank lines.  The ``data:`` lines of one event
        (with at most one leading space removed) are joined with newlines and
        decoded as JSON.  Other fields (``event:``, ``id:``) and comment lines
        are ignored.  When ``request_id`` is given, the message whose ``id``
        matches is preferred, so interleaved server notifications are skipped.
        Otherwise the first decoded message is returned.

        Returns:
            The selected message, or ``None`` when the body has no data events.

        Raises:
            json.JSONDecodeError: If an event's data is not valid JSON.
        """
        events: List[List[str]] = [[]]
        for line in text.strip().splitlines():
            if not line:
                events.append([])  # a blank line terminates the current event
                continue
            field, _, value = line.partition(":")
            if field == "data":
                events[-1].append(value[1:] if value.startswith(" ") else value)

        messages: List[Any] = []
        for data in events:
            payload = "\n".join(data)
            if payload.strip():
                messages.append(json.loads(payload))

        if not messages:
            return None
        if request_id is not None:
            for message in messages:
                if isinstance(message, dict) and message.get("id") == request_id:
                    return message
        return messages[0]

    async def _send_request(self, method: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """POST one JSON-RPC request and return its ``result`` object.

        Raises:
            MCPClientError: If the client is not connected, the HTTP status is
                not 200, the body cannot be decoded into a JSON-RPC message, or
                the server replied with a JSON-RPC ``error``.
        """
        if not self.http_client:
            raise MCPClientError("Client is not connected")

        request_id = self._next_id()
        payload = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": params or {},
        }
        headers = {
            "Accept": "application/json, text/event-stream",
            "Content-Type": "application/json",
        }
        if self.session_id:
            # Echo the session the server assigned (MCP Streamable HTTP transport).
            headers["Mcp-Session-Id"] = self.session_id

        response = await self.http_client.post(self.url, json=payload, headers=headers)
        if response.status_code != 200:
            raise MCPClientError(f"HTTP {response.status_code}: {response.text}")

        session_id = response.headers.get("mcp-session-id")
        if session_id:
            self.session_id = session_id

        try:
            parsed = self._parse_sse_response(response.text, request_id)
            if parsed is None and response.text.strip():
                # We advertise application/json too, so accept a plain JSON body.
                parsed = json.loads(response.text)
        except json.JSONDecodeError as exc:
            raise MCPClientError(f"Malformed SSE payload: {response.text}") from exc
        if not isinstance(parsed, dict):
            raise MCPClientError(f"Malformed SSE payload: {response.text}")

        error = parsed.get("error")
        if error is not None:
            if isinstance(error, dict):
                raise MCPClientError(f"MCP error {error.get('code')}: {error.get('message')}")
            raise MCPClientError(f"MCP error: {error}")
        return parsed.get("result") or {}

    async def connect(self) -> None:
        """Open the HTTP client and perform the MCP ``initialize`` handshake.

        This is a no-op when the client is already initialized.  If the
        handshake fails, the HTTP client is closed before the exception
        propagates.
        """
        if self.initialized:
            return

        self.http_client = httpx.AsyncClient(timeout=self.timeout)
        try:
            result = await self._send_request(
                "initialize",
                {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {"roots": {}, "sampling": {}},
                    "clientInfo": {"name": "recommender-data-loader", "version": "0.2.0"},
                },
            )
        except BaseException:
            # __aexit__ does not run when __aenter__ raises, so release the
            # HTTP client here rather than leaking it.
            await self.disconnect()
            raise
        self.server_info = result.get("serverInfo", {})
        self.capabilities = result.get("capabilities", {})
        self.initialized = True

    async def list_tools(self) -> List[Dict[str, Any]]:
        """Fetch every tool the server exposes, following ``nextCursor`` pages.

        The result is cached in ``self.mcp_tools`` and also returned.

        Raises:
            MCPClientError: If called before :meth:`connect`, or on request failure.
        """
        if not self.initialized:
            raise MCPClientError("connect() must be called before listing tools")
        tools: List[Dict[str, Any]] = []
        cursor: Optional[str] = None
        seen_cursors = set()
        while True:
            params = {"cursor": cursor} if cursor else {}
            result = await self._send_request("tools/list", params)
            tools.extend(result.get("tools") or [])
            cursor = result.get("nextCursor")
            # Stop on the last page, or if a misbehaving server repeats a cursor.
            if not cursor or cursor in seen_cursors:
                break
            seen_cursors.add(cursor)
        self.mcp_tools = tools
        return self.mcp_tools

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke ``tool_name`` with ``arguments`` and return the raw ``result`` object.

        Raises:
            MCPClientError: If called before :meth:`connect`, or on request failure.
        """
        if not self.initialized:
            raise MCPClientError("connect() must be called before using tools")
        return await self._send_request("tools/call", {"name": tool_name, "arguments": arguments})

    async def call_tool_json(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke a tool and decode its text content as JSON.

        If the result carries no text content (see :meth:`extract_text_content`),
        the raw result object is returned instead.  The decoded value is whatever
        ``json.loads`` produces, so it may be a list despite the annotation.

        Raises:
            MCPClientError: If the text content is not valid JSON, or on request failure.
        """
        result = await self.call_tool(tool_name, arguments)
        text = self.extract_text_content(result)
        if not text:
            return result
        try:
            return json.loads(text)
        except json.JSONDecodeError as exc:
            raise MCPClientError(f"Tool {tool_name} returned non-JSON content: {text[:200]}") from exc

    async def disconnect(self) -> None:
        """Close the HTTP client (if open) and reset the connection state.

        Safe to call repeatedly.  The state is reset even if closing the client raises.
        """
        client, self.http_client = self.http_client, None
        self.initialized = False
        self.session_id = None
        if client is not None:
            await client.aclose()

    @staticmethod
    def extract_text_content(result: Dict[str, Any]) -> Optional[str]:
        """Return the first textual payload in a tool result's ``content`` list.

        Non-dict blocks are skipped.  Each block is checked in order: a
        non-empty ``text`` wins; otherwise a block whose ``type`` is
        ``"application/json"`` contributes its ``data`` when that is a string.
        Returns ``None`` when ``content`` is missing or not a list, or when no
        block matches.
        """
        content = result.get("content")
        if not isinstance(content, list):
            return None
        for block in content:
            if not isinstance(block, dict):
                continue
            text = block.get("text")
            if text:
                return text
            if block.get("type") == "application/json":
                data = block.get("data")
                if isinstance(data, str):
                    return data
        return None


async def main() -> None:
    """Small CLI helper for manual testing.

    Usage::

        python <this file> --tool TOOL_NAME --payload '{"key": "value"}'

    Connects with the module defaults (the ``MCP_URL``, ``MCP_AUTH_KEY`` and
    ``MCP_CHATBOT_KEY`` environment variables, with built-in fallbacks).  It
    calls the tool once and prints the tool's text content, or the raw result
    as indented JSON when the result has no text block.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Call an MCP SSE tool directly.")
    parser.add_argument("--tool", required=True, help="Tool name to invoke.")
    parser.add_argument("--payload", default="{}", help="JSON string payload.")
    args = parser.parse_args()

    # Report bad input as a usage error instead of a traceback.  Tool
    # arguments are sent as the "arguments" mapping, so require a JSON object.
    try:
        payload = json.loads(args.payload or "{}")
    except json.JSONDecodeError as exc:
        parser.error(f"--payload is not valid JSON: {exc}")
    if not isinstance(payload, dict):
        parser.error("--payload must be a JSON object")

    # The DEFAULT_* constants already honour the environment overrides, so the
    # fallback credentials are not duplicated here.
    client = MCPSSEClient(
        base_url=DEFAULT_MCP_URL,
        auth_key=DEFAULT_AUTH_KEY,
        chatbot_key=DEFAULT_CHATBOT_KEY,
    )

    try:
        await client.connect()
        result = await client.call_tool(args.tool, payload)
        text = client.extract_text_content(result)
        print(text or json.dumps(result, indent=2))
    finally:
        # Runs even if connect() fails, so the HTTP client is always released.
        await client.disconnect()


if __name__ == "__main__":
    asyncio.run(main())

