#!/usr/bin/env python3
"""
HTTP MCP client for exercising the recommender server tools.

This replaces the old LangChain + Ollama chatbot with a lightweight API tester
so downstream integrators can script calls to `get_recommendations` and
`submit_feedback` directly.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
from typing import Any, Dict, Optional

import httpx

# Default server base URL; override with the MCP_SERVER_URL environment variable.
DEFAULT_SERVER_URL = os.getenv("MCP_SERVER_URL", "https://mcp-server.workzy.co")


class LocalMCPClient:
    """Minimal async HTTP client for FastMCP servers.

    Usable as an async context manager::

        async with LocalMCPClient(url) as client:
            tools = await client.list_tools()

    All network methods require ``connect()`` to have been awaited first
    (the context manager does this automatically).
    """

    def __init__(self, base_url: str, timeout: float = 30.0, headers: Optional[Dict[str, str]] = None):
        # Strip trailing slashes so endpoint paths can always be appended with "/...".
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.headers = headers or {}
        self.http_client: Optional[httpx.AsyncClient] = None
        self.server_info: Dict[str, Any] = {}

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.disconnect()

    async def connect(self) -> None:
        """Create the underlying AsyncClient and probe the server root for metadata."""
        if self.http_client is not None:
            return  # already connected; connect() is idempotent
        self.http_client = httpx.AsyncClient(timeout=self.timeout, headers=self.headers)
        # Best-effort metadata fetch: any failure just leaves server_info empty.
        try:
            probe = await self.http_client.get(f"{self.base_url}/")
            if probe.status_code == 200:
                self.server_info = probe.json()
        except Exception:
            self.server_info = {}

    async def disconnect(self) -> None:
        """Close the underlying AsyncClient, if one is open."""
        if self.http_client is not None:
            await self.http_client.aclose()
            self.http_client = None

    async def list_tools(self) -> Any:
        """Return the server's tool catalog from GET /tools."""
        client = self._require_client()
        resp = await client.get(f"{self.base_url}/tools")
        resp.raise_for_status()
        return resp.json()

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """POST a tool invocation and return the interpreted result."""
        client = self._require_client()
        # fastmcp_http expects the tool name at the root and arguments flattened (not nested under "arguments")
        body: Dict[str, Any] = {"name": name}
        body.update(arguments or {})
        resp = await client.post(f"{self.base_url}/tools/call_tool", json=body)
        resp.raise_for_status()
        return self._interpret_response(resp.json())

    def _require_client(self) -> httpx.AsyncClient:
        """Return the live AsyncClient, or raise if connect() was never awaited."""
        if not self.http_client:
            raise RuntimeError("Client is not connected")
        return self.http_client

    @staticmethod
    def _interpret_response(data: Any) -> Any:
        """Collapse MCP text content blocks into parsed JSON (or the raw joined text)."""
        if not isinstance(data, list):
            return data
        texts = [block["text"] for block in data if isinstance(block, dict) and "text" in block]
        if not texts:
            return data
        combined = "\n".join(texts)
        try:
            return json.loads(combined)
        except json.JSONDecodeError:
            return combined


async def handle_list_tools(client: LocalMCPClient) -> None:
    """Fetch the server's tool catalog and pretty-print it as JSON."""
    print(json.dumps(await client.list_tools(), indent=2))


async def handle_call_tool(client: LocalMCPClient, name: str, payload: str | None) -> None:
    """Call an arbitrary tool with a raw JSON payload string and print the result.

    Args:
        client: Connected MCP client.
        name: Tool name (e.g., "get_recommendations").
        payload: JSON object string of tool arguments, or None for no arguments.

    Raises:
        ValueError: If *payload* is not valid JSON or does not decode to an object.
    """
    if payload:
        # Validate up front: call_tool flattens arguments with **, so a non-dict
        # payload would otherwise surface as a cryptic TypeError deep inside it.
        try:
            arguments = json.loads(payload)
        except json.JSONDecodeError as exc:
            raise ValueError(f"--payload is not valid JSON: {exc}") from exc
        if not isinstance(arguments, dict):
            raise ValueError("--payload must decode to a JSON object (dict).")
    else:
        arguments = {}
    result = await client.call_tool(name, arguments)
    if isinstance(result, (dict, list)):
        print(json.dumps(result, indent=2))
    else:
        print(result)


async def handle_recommend(client: LocalMCPClient, args: argparse.Namespace) -> None:
    """Call get_recommendations using the friendly CLI flags on *args*."""
    # Only pass algorithm_params when an epsilon was actually supplied.
    params = {"epsilon": args.epsilon} if args.epsilon is not None else None
    request = {
        "user_id": args.user_id,
        "n_recommendations": args.count,
        "algorithm_type": args.algorithm,
        "algorithm_params": params,
    }
    result = await client.call_tool("get_recommendations", {"request": request})
    print(json.dumps(result, indent=2))


async def handle_feedback(client: LocalMCPClient, args: argparse.Namespace) -> None:
    """Send a submit_feedback event built from the CLI flags and print the reply."""
    feedback = dict(
        user_id=args.user_id,
        item_id=args.item_id,
        event_type=args.event,
        timestamp=args.timestamp,
    )
    result = await client.call_tool("submit_feedback", {"feedback": feedback})
    print(json.dumps(result, indent=2))


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: global flags plus the four subcommands."""
    parser = argparse.ArgumentParser(description="Interact with the recommender MCP HTTP server.")
    parser.add_argument("--server-url", default=DEFAULT_SERVER_URL, help="Base URL of the FastMCP HTTP server.")
    parser.add_argument(
        "--auth-token",
        default=os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN"),
        help="Optional bearer token for private endpoints (e.g., private HF Space).",
    )

    sub = parser.add_subparsers(dest="command", required=True)

    # Simple subcommand with no options of its own.
    sub.add_parser("list-tools", help="List every tool exposed by the MCP server.")

    call = sub.add_parser("call-tool", help="Call any tool using a JSON payload.")
    call.add_argument("--name", required=True, help="Tool name (e.g., get_recommendations).")
    call.add_argument("--payload", help="JSON payload string passed directly to the tool.")

    rec = sub.add_parser("recommend", help="Call get_recommendations with friendly flags.")
    rec.add_argument("--user-id", type=int, required=True)
    rec.add_argument("--count", type=int, default=3, help="Number of recommendations to request.")
    rec.add_argument("--algorithm", default="epsilon_greedy", help="Algorithm name passed through to the recommender.")
    rec.add_argument("--epsilon", type=float, help="Optional epsilon value for epsilon-greedy style algorithms.")

    fb = sub.add_parser("feedback", help="Send a submit_feedback event.")
    fb.add_argument("--user-id", type=int, required=True)
    fb.add_argument("--item-id", type=int, required=True)
    fb.add_argument(
        "--event",
        required=True,
        choices=["view", "cart", "purchase", "remove_from_cart"],
        help="Event type recorded as feedback.",
    )
    fb.add_argument("--timestamp", help="Optional ISO timestamp.")

    return parser


async def _dispatch(args: argparse.Namespace, headers: Optional[Dict[str, str]]) -> None:
    """Open a connected client and route the parsed CLI command to its handler."""
    # Map each subcommand to a coroutine factory taking the connected client.
    handlers = {
        "list-tools": lambda c: handle_list_tools(c),
        "call-tool": lambda c: handle_call_tool(c, args.name, args.payload),
        "recommend": lambda c: handle_recommend(c, args),
        "feedback": lambda c: handle_feedback(c, args),
    }
    async with LocalMCPClient(args.server_url, headers=headers) as client:
        runner = handlers.get(args.command)
        if runner is None:  # pragma: no cover
            raise ValueError(f"Unknown command {args.command}")
        await runner(client)


def main() -> int:
    """CLI entry point.

    Returns a process exit code: 0 on success, 130 on Ctrl-C, 1 on any other error.
    """
    args = _build_parser().parse_args()
    token = args.auth_token
    headers = {"Authorization": f"Bearer {token}"} if token else None
    try:
        asyncio.run(_dispatch(args, headers=headers))
    except KeyboardInterrupt:
        print("\nCancelled by user.")
        return 130
    except Exception as exc:
        # Top-level boundary: report and signal failure instead of a traceback.
        print(f"❌ Error: {exc!r}")
        return 1
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
