"""
High-level orchestrator that downloads MCP data and stores it in the local DB.
"""

from __future__ import annotations

import asyncio
import logging
import os
from typing import Any, Dict, List, Optional

from db import DataController
from db.models import ProductCatalog, UserDemographics
from .mcp_client import (
    MCPSSEClient,
    DEFAULT_MCP_URL,
    DEFAULT_AUTH_KEY,
    DEFAULT_CHATBOT_KEY,
)
from .normalizers import normalize_customer, normalize_product

# Default local database file used by MCPDataSync; can be overridden via the
# RECS_DB_PATH environment variable (read once, at import time).
DEFAULT_DB_PATH = os.getenv("RECS_DB_PATH", "recommender_system.db")


class MCPDataSync:
    """Syncs remote chatbot data (products and customers) into the local DB.

    Data is fetched through the MCP tools ``get_products`` (page/limit
    pagination) and ``get_customers`` (cursor pagination via ``first`` /
    ``after``). It is normalized with the helpers from ``.normalizers`` and
    written through a ``DataController`` bound to ``db_path``.

    The public ``sync_*`` methods are synchronous wrappers around
    ``asyncio.run``. They therefore cannot be called from code that is already
    running inside an event loop, because ``asyncio.run`` raises
    ``RuntimeError`` there.
    """

    def __init__(
        self,
        db_path: str = DEFAULT_DB_PATH,
        mcp_url: str = DEFAULT_MCP_URL,
        auth_key: str = DEFAULT_AUTH_KEY,
        chatbot_key: str = DEFAULT_CHATBOT_KEY,
    ) -> None:
        """Store connection settings and create a ``DataController`` for ``db_path``.

        Args:
            db_path: Local database path handed to ``DataController``.
            mcp_url: MCP endpoint URL handed to ``MCPSSEClient``.
            auth_key: Auth key handed to ``MCPSSEClient``.
            chatbot_key: Chatbot key handed to ``MCPSSEClient`` and also sent
                as ``chatbotKey`` in every tool payload.
        """
        self.db_path = db_path
        self.mcp_url = mcp_url
        self.auth_key = auth_key
        self.chatbot_key = chatbot_key
        self.data_controller = DataController(db_path)

    # -----------------------------
    # Public entry points
    # -----------------------------
    def sync_products(
        self,
        page_size: int = 250,
        max_pages: Optional[int] = None,
        search: Optional[str] = None,
        dry_run: bool = False,
    ) -> Dict[str, Any]:
        """Sync products in a dedicated MCP session.

        Args:
            page_size: ``limit`` sent with each ``get_products`` call.
            max_pages: Stop after this many non-empty pages (``None``/``0`` = no limit).
            search: Optional ``search`` filter forwarded to the tool.
            dry_run: Fetch and normalize, but do not write to the DB.

        Returns:
            The summary dict built by ``_sync_products_with_client``.
        """
        return asyncio.run(
            self._sync_products_async(
                page_size=page_size,
                max_pages=max_pages,
                search=search,
                dry_run=dry_run,
            )
        )

    def sync_customers(
        self,
        batch_size: int = 250,
        max_batches: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Sync customers in a dedicated MCP session.

        Args:
            batch_size: ``first`` sent with each ``get_customers`` call.
            max_batches: Stop after this many batches (``None``/``0`` = no limit).

        Returns:
            The summary dict built by ``_sync_customers_with_client``.
        """
        return asyncio.run(self._sync_customers_async(batch_size=batch_size, max_batches=max_batches))

    def sync_all(
        self,
        *,
        sync_products: bool = True,
        sync_customers: bool = True,
        page_size: int = 250,
        max_pages: Optional[int] = None,
        batch_size: int = 250,
        max_batches: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Sync products and/or customers over a single shared MCP session.

        Args:
            sync_products: Whether to run the product sync.
            sync_customers: Whether to run the customer sync.
            page_size: Product page size (see ``sync_products``).
            max_pages: Product page limit (see ``sync_products``).
            batch_size: Customer batch size (see ``sync_customers``).
            max_batches: Customer batch limit (see ``sync_customers``).

        Returns:
            A dict with a ``"products"`` and/or ``"customers"`` summary, one key
            per sync that was enabled.
        """
        return asyncio.run(
            self._sync_all_async(
                sync_products=sync_products,
                sync_customers=sync_customers,
                page_size=page_size,
                max_pages=max_pages,
                batch_size=batch_size,
                max_batches=max_batches,
            )
        )

    # -----------------------------
    # Async implementations
    # -----------------------------
    async def _sync_all_async(
        self,
        sync_products: bool,
        sync_customers: bool,
        *,
        page_size: int = 250,
        max_pages: Optional[int] = None,
        batch_size: int = 250,
        max_batches: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run the enabled syncs sequentially on one ``MCPSSEClient`` session."""
        results: Dict[str, Any] = {}
        async with MCPSSEClient(self.mcp_url, self.auth_key, self.chatbot_key) as client:
            if sync_products:
                results["products"] = await self._sync_products_with_client(
                    client,
                    page_size=page_size,
                    max_pages=max_pages,
                    search=None,
                    dry_run=False,
                )
            if sync_customers:
                results["customers"] = await self._sync_customers_with_client(
                    client,
                    batch_size=batch_size,
                    max_batches=max_batches,
                )
        return results

    async def _sync_products_async(
        self,
        page_size: int,
        max_pages: Optional[int],
        search: Optional[str],
        dry_run: bool,
    ) -> Dict[str, Any]:
        """Open an MCP session and run the product sync on it."""
        async with MCPSSEClient(self.mcp_url, self.auth_key, self.chatbot_key) as client:
            return await self._sync_products_with_client(
                client,
                page_size=page_size,
                max_pages=max_pages,
                search=search,
                dry_run=dry_run,
            )

    async def _sync_customers_async(self, batch_size: int, max_batches: Optional[int]) -> Dict[str, Any]:
        """Open an MCP session and run the customer sync on it."""
        async with MCPSSEClient(self.mcp_url, self.auth_key, self.chatbot_key) as client:
            return await self._sync_customers_with_client(client, batch_size=batch_size, max_batches=max_batches)

    # -----------------------------
    # Core sync helpers
    # -----------------------------
    async def _sync_products_with_client(
        self,
        client: MCPSSEClient,
        *,
        page_size: int,
        max_pages: Optional[int],
        search: Optional[str],
        dry_run: bool = False,
    ) -> Dict[str, Any]:
        """Page through ``get_products`` and persist the normalized products.

        Pagination starts at page 1 and stops on whichever comes first:
        - an empty page;
        - a page with fewer than ``page_size`` products;
        - ``max_pages`` non-empty pages (``None``/``0`` = no limit).

        The short-page test uses the count after ``_extract_products`` dropped
        non-dict entries, so malformed entries can also end pagination early.

        Returns:
            A dict with these keys:
            - ``pages_processed``: number of non-empty pages handled.
            - ``remote_items``: number of products received.
            - ``normalized``: number of products that survived normalization.
            - ``persisted``: number of products for which ``add_product``
              returned a truthy value, or every normalized product when
              ``dry_run`` is set.
            - ``dry_run``: the flag as passed in.
        """
        total_remote = 0
        total_normalized = 0
        total_persisted = 0
        pages_processed = 0
        page = 1

        while True:
            payload: Dict[str, Any] = {
                "chatbotKey": self.chatbot_key,
                "page": page,
                "limit": page_size,
            }
            if search:
                payload["search"] = search

            response = await client.call_tool_json("get_products", payload)
            products = self._extract_products(self._extract_data(response))
            if not products:
                break

            # Count only pages that actually carried data, so the terminating
            # empty page is not reported as processed.
            pages_processed += 1
            total_remote += len(products)
            normalized = _normalize_products(products)
            total_normalized += len(normalized)

            if dry_run:
                # Nothing is written; report what would have been persisted.
                total_persisted += len(normalized)
            else:
                for product in normalized:
                    if self.data_controller.add_product(product):
                        total_persisted += 1

            # A short page means the remote has no further products.
            if len(products) < page_size:
                break
            if max_pages and pages_processed >= max_pages:
                break
            page += 1

        return {
            "pages_processed": pages_processed,
            "remote_items": total_remote,
            "normalized": total_normalized,
            "persisted": total_persisted,
            "dry_run": dry_run,
        }

    async def _sync_customers_with_client(
        self,
        client: MCPSSEClient,
        *,
        batch_size: int,
        max_batches: Optional[int],
    ) -> Dict[str, Any]:
        """Walk ``get_customers`` by cursor and insert customers not yet stored.

        Each request sends ``first=batch_size`` and, after the first batch,
        ``after=<endCursor>`` from the previous ``pageInfo``.

        Iteration stops on whichever comes first:
        - an empty batch;
        - ``hasNextPage`` being falsy;
        - ``max_batches`` batches (``None``/``0`` = no limit);
        - the server reporting more pages without supplying a new
          ``endCursor``. Continuing there would re-request the same batch
          forever, so a warning is logged instead.

        Customers are de-duplicated by name via ``_user_exists``. A customer
        whose normalized form is falsy is skipped.

        Returns:
            A dict with these keys:
            - ``batches_processed``: number of batches handled.
            - ``remote_items``: number of customers received.
            - ``inserted``: number of ``add_user`` calls made. The return value
              of ``add_user`` is not inspected.
        """
        total_remote = 0
        inserted = 0
        batches = 0
        cursor: Optional[str] = None

        while True:
            payload: Dict[str, Any] = {"chatbotKey": self.chatbot_key, "first": batch_size}
            if cursor:
                payload["after"] = cursor

            response = await client.call_tool_json("get_customers", payload)
            data = self._extract_data(response)
            customers = data.get("customers") or []
            # pageInfo may be missing or explicitly null in the payload.
            page_info = data.get("pageInfo")
            if not isinstance(page_info, dict):
                page_info = {}

            if not customers:
                break

            total_remote += len(customers)
            for customer in customers:
                user_model = normalize_customer(customer)
                if not user_model:
                    # Mirrors the product path, which drops falsy normalizations.
                    continue
                if not self._user_exists(user_model):
                    self.data_controller.add_user(user_model)
                    inserted += 1

            batches += 1
            if not page_info.get("hasNextPage"):
                break
            next_cursor = page_info.get("endCursor")
            if not next_cursor or next_cursor == cursor:
                logging.getLogger(__name__).warning(
                    "get_customers reported hasNextPage without a new endCursor (%r); "
                    "stopping after %d batches",
                    next_cursor,
                    batches,
                )
                break
            cursor = next_cursor
            if max_batches and batches >= max_batches:
                break

        return {
            "batches_processed": batches,
            "remote_items": total_remote,
            "inserted": inserted,
        }

    # -----------------------------
    # Helper utilities
    # -----------------------------
    def _user_exists(self, user: UserDemographics) -> bool:
        """Return True if a user with the same (non-empty) name is already stored.

        Users without a name are never considered duplicates.
        """
        if not user.name:
            return False
        existing = self.data_controller.get_user_by_name(user.name)
        return existing is not None

    @staticmethod
    def _extract_data(response: Dict[str, Any]) -> Dict[str, Any]:
        """Unwrap the payload dict from a tool response.

        Lookup order:
        1. ``response["output"]["data"]``, if that is a dict;
        2. ``response["output"]``, if that is a dict;
        3. ``response["data"]``, if that is a dict;
        4. the response itself.

        A non-dict response yields ``{}``.
        """
        if not isinstance(response, dict):
            return {}
        if "output" in response and isinstance(response["output"], dict):
            data = response["output"].get("data")
            if isinstance(data, dict):
                return data
            return response["output"]
        if "data" in response and isinstance(response["data"], dict):
            return response["data"]
        return response

    @staticmethod
    def _extract_products(data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the product dicts found under ``data["products"]``.

        Two shapes are accepted:
        - a plain list, from which non-dict items are dropped;
        - a connection-style ``{"edges": [{"node": {...}}, ...]}``. An edge
          with a falsy ``node`` is used as the product itself.

        Anything else yields ``[]``.
        """
        products = data.get("products")
        if isinstance(products, list):
            return [p for p in products if isinstance(p, dict)]
        if isinstance(products, dict):
            edges = products.get("edges")
            if isinstance(edges, list):
                normalized: List[Dict[str, Any]] = []
                for edge in edges:
                    if isinstance(edge, dict):
                        node = edge.get("node") or edge
                        if isinstance(node, dict):
                            normalized.append(node)
                return normalized
        return []


def _normalize_products(products: List[Dict[str, Any]]) -> List[ProductCatalog]:
    """Apply ``normalize_product`` to each raw product, keeping truthy results.

    Products are processed in input order, and the surviving results keep
    that order. Any product for which ``normalize_product`` returns a falsy
    value (e.g. ``None``) is silently dropped.
    """
    return [item for item in map(normalize_product, products) if item]
