#!/usr/bin/env python3
"""
CLI helpers for calling remote MCP tools and syncing their data into the local DB.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
import sqlite3
import sys
from pathlib import Path
from typing import Any, Dict

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from data_sync.mcp_client import MCPSSEClient, MCPClientError  # type: ignore  # noqa: E402
from data_sync.sync_service import MCPDataSync  # type: ignore  # noqa: E402

# Connection defaults; each can be overridden via environment variables.
# NOTE(review): the fallback values below look like real credentials committed
# to source control — confirm these are non-production/test keys, and prefer
# requiring the env vars over shipping secret-looking defaults.
DEFAULT_MCP_URL = os.getenv("MCP_URL", "https://api.recomai.one/mcp")
DEFAULT_AUTH_KEY = os.getenv("MCP_AUTH_KEY", "+FiDOuHBrVo4X0SfW5KrEQ==:y0R6zQZbj7K6wXN5jxgmTg==")
DEFAULT_CHATBOT_KEY = os.getenv(
    "MCP_CHATBOT_KEY",
    "c4292acede69fc72e5b06bdea37b2cee45a4ac2190cb8c98e8b2dce6b53f9cee",
)
# Local SQLite file the sync commands read from / persist into.
DEFAULT_DB_PATH = os.getenv("RECS_DB_PATH", "recommender_system.db")


def _load_payload(payload_str: str | None) -> Dict[str, Any]:
    if not payload_str:
        return {}
    try:
        return json.loads(payload_str)
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Invalid JSON payload: {exc}") from exc


async def run_tool(tool: str, payload: Dict[str, Any]) -> None:
    """Connect to the MCP server, then either list tools or invoke one tool.

    The payload is shallow-copied before the default chatbot key is injected,
    so the caller's dict is never mutated. Output is printed as JSON (or the
    tool's extracted text content, when present).
    """
    # Work on a copy; auto-inject the chatbot key unless the caller set one.
    call_args = dict(payload)
    if DEFAULT_CHATBOT_KEY and "chatbotKey" not in call_args:
        call_args["chatbotKey"] = DEFAULT_CHATBOT_KEY

    client = MCPSSEClient(DEFAULT_MCP_URL, DEFAULT_AUTH_KEY, DEFAULT_CHATBOT_KEY)
    try:
        await client.connect()
        if tool == "list-tools":
            # Special pseudo-tool: dump the remote tool catalog and stop.
            print(json.dumps(await client.list_tools(), indent=2))
            return

        response = await client.call_tool(tool, call_args)
        text = client.extract_text_content(response)
        # Prefer the human-readable text; fall back to the raw JSON response.
        print(text if text else json.dumps(response, indent=2))
    finally:
        await client.disconnect()


def _db_counts(db_path: str) -> Dict[str, Any]:
    """Return lightweight row counts for key tables to verify persistence."""
    tables = ("product_catalog", "user_demographics", "feedback_db")
    counts: Dict[str, Any] = {}
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        for table in tables:
            try:
                cur.execute(f"SELECT COUNT(*) FROM {table}")
                counts[table] = cur.fetchone()[0]
            except Exception as exc:  # table may not exist yet
                counts[table] = f"error: {exc}"
        conn.close()
    except Exception as exc:
        counts["error"] = str(exc)
    return counts


def sync_products_cli(args: argparse.Namespace) -> None:
    """Run a product sync against the remote MCP and print a JSON summary.

    With ``--verify``, appends row counts from the target DB (or a skip
    marker when ``--dry-run`` meant nothing was written).
    """
    syncer = MCPDataSync(
        db_path=args.db_path,
        mcp_url=DEFAULT_MCP_URL,
        auth_key=DEFAULT_AUTH_KEY,
        chatbot_key=DEFAULT_CHATBOT_KEY,
    )
    summary: Dict[str, Any] = {
        "sync": syncer.sync_products(
            page_size=args.page_size,
            max_pages=args.max_pages,
            search=args.search,
            dry_run=args.dry_run,
        )
    }
    if args.verify:
        # Dry runs write nothing, so counting rows would be misleading.
        summary["db_counts"] = "skipped (dry-run)" if args.dry_run else _db_counts(args.db_path)
    print(json.dumps(summary, indent=2))


def sync_customers_cli(args: argparse.Namespace) -> None:
    """Import remote customers into the local DB and print a JSON summary."""
    syncer = MCPDataSync(
        db_path=args.db_path,
        mcp_url=DEFAULT_MCP_URL,
        auth_key=DEFAULT_AUTH_KEY,
        chatbot_key=DEFAULT_CHATBOT_KEY,
    )
    report: Dict[str, Any] = {
        "sync": syncer.sync_customers(
            batch_size=args.batch_size,
            max_batches=args.max_batches,
        )
    }
    if args.verify:
        report["db_counts"] = _db_counts(args.db_path)
    print(json.dumps(report, indent=2))


def sync_all_cli(args: argparse.Namespace) -> None:
    """Sync products and/or customers (per the --skip-* flags) and print a JSON summary."""
    syncer = MCPDataSync(
        db_path=args.db_path,
        mcp_url=DEFAULT_MCP_URL,
        auth_key=DEFAULT_AUTH_KEY,
        chatbot_key=DEFAULT_CHATBOT_KEY,
    )
    report: Dict[str, Any] = {
        "sync": syncer.sync_all(
            sync_products=not args.skip_products,
            sync_customers=not args.skip_customers,
        )
    }
    if args.verify:
        report["db_counts"] = _db_counts(args.db_path)
    print(json.dumps(report, indent=2))


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: one subcommand per operation, subcommand required."""
    parser = argparse.ArgumentParser(description="Call MCP data tools or sync the local DB.")
    commands = parser.add_subparsers(dest="command", required=True)

    def add_db_and_verify(sub: argparse.ArgumentParser, verify_help: str) -> None:
        # Flags shared by every sync subcommand.
        sub.add_argument("--db-path", default=DEFAULT_DB_PATH, help="SQLite file to persist into.")
        sub.add_argument("--verify", action="store_true", help=verify_help)

    call = commands.add_parser("call-tool", help="Invoke a single MCP tool and print the raw response.")
    call.add_argument("--tool", required=True, help="Tool name to invoke (use 'list-tools' to inspect).")
    call.add_argument(
        "--payload",
        help="JSON string with tool arguments (chatbotKey auto-injected if missing).",
    )

    products = commands.add_parser("sync-products", help="Fetch paginated products into the local DB.")
    products.add_argument("--page-size", type=int, default=250, help="Number of remote products per page.")
    products.add_argument("--max-pages", type=int, default=4, help="Maximum remote pages to fetch.")
    products.add_argument("--search", help="Optional remote search term.")
    products.add_argument("--dry-run", action="store_true", help="Run without writing to the database.")
    add_db_and_verify(
        products,
        "Also print row counts from the target DB after sync (skipped on dry-run).",
    )

    customers = commands.add_parser("sync-customers", help="Import remote customers into the DB.")
    customers.add_argument("--batch-size", type=int, default=250, help="Remote page size.")
    customers.add_argument("--max-batches", type=int, help="Limit how many batches are fetched.")
    add_db_and_verify(customers, "Also print row counts from the target DB after sync.")

    everything = commands.add_parser("sync-all", help="Sync both products and customers.")
    everything.add_argument("--skip-products", action="store_true", help="Skip syncing products.")
    everything.add_argument("--skip-customers", action="store_true", help="Skip syncing customers.")
    add_db_and_verify(everything, "Also print row counts from the target DB after sync.")

    return parser


def main() -> int:
    """Parse CLI arguments, dispatch the chosen subcommand, and map errors to exit codes.

    Returns 0 on success, 130 on Ctrl-C, 1 on MCP or unexpected errors.
    """
    args = _build_parser().parse_args()

    # Synchronous subcommands dispatch through this table; call-tool is the
    # lone async path and is handled separately below.
    handlers = {
        "sync-products": sync_products_cli,
        "sync-customers": sync_customers_cli,
        "sync-all": sync_all_cli,
    }

    try:
        if args.command == "call-tool":
            asyncio.run(run_tool(args.tool, _load_payload(args.payload)))
        else:
            handler = handlers.get(args.command)
            if handler is not None:
                handler(args)
    except KeyboardInterrupt:
        print("\nCancelled by user.", file=sys.stderr)
        return 130
    except MCPClientError as exc:
        print(f"❌ MCP error: {exc}", file=sys.stderr)
        return 1
    except Exception as exc:  # pragma: no cover - top-level guard rail
        print(f"❌ Error: {exc}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer return value to the shell as the exit code.
    raise SystemExit(main())
