"""
Recommender MCP Server (HTTP transport via fastmcp_http)
--------------------------------------------------------
- Exposes Model Context Protocol tools backed by the in-memory RL recommender.
- Served over plain HTTP so LangChain/Ollama clients can talk to it easily.

Run locally:
    python mcp_server.py  # binds on $PORT, else $MCP_SERVER_PORT, else 7860
Then connect from an MCP-compatible client (e.g., the provided Ollama chatbot).
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional
import json
import time

from pydantic import BaseModel, Field
from fastmcp_http.server import FastMCPHttpServer
from mcp.types import TextContent

from db import DataController
from db.models import FeedbackDB
from rl_recommender.UnifiedRecommendationEnv import UnifiedRecommendationEnv
from rl_recommender.simulation import MABSimulation
from data_sync.sync_service import MCPDataSync
from data_sync.mcp_client import DEFAULT_MCP_URL, DEFAULT_AUTH_KEY, DEFAULT_CHATBOT_KEY
import pandas as pd
import numpy as np
import os
from datetime import datetime
from threading import Lock, Thread

# ----------------------------
# Server
# ----------------------------
server = FastMCPHttpServer(
    "recommender-mcp",
    description="HTTP MCP server for the RL recommender",
)

# Flask app (for health + app export). Resolved defensively: None when the
# server object has no `flask_app` attribute.
_FLASK_APP = getattr(server, "flask_app", None)

if _FLASK_APP is not None:
    def _root():
        """Handle GET / by returning a small status dict together with HTTP 200."""
        body = {
            "status": "ok",
            "server": "recommender-mcp",
            "transport": "http",
        }
        return body, 200

    # Explicit registration; equivalent to decorating `_root` with the route.
    _FLASK_APP.route("/", methods=["GET"])(_root)

# Port resolution order: $PORT, then $MCP_SERVER_PORT, then the literal 7860.
_fallback_port = os.getenv("MCP_SERVER_PORT", "7860")
MCP_SERVER_PORT = int(os.getenv("PORT", _fallback_port))

# --- DB-only initialization ---
# The database path comes from $RECS_DB_PATH; the same path is reused by the
# RL environment and the sync helpers elsewhere in this module.
DB_PATH = os.getenv("RECS_DB_PATH", "recommender_system.db")
data_controller = DataController(DB_PATH)

# Auto-sync configuration
def _env_flag(name: str, default: str) -> bool:
    """Read an integer-valued environment variable and interpret it as a boolean."""
    return bool(int(os.getenv(name, default)))


def _env_optional_int(name: str) -> Optional[int]:
    """Return the variable parsed as int, or None when it is unset or empty."""
    raw = os.getenv(name)
    return int(raw) if raw else None


AUTO_SYNC_ENABLED = _env_flag("RECS_AUTO_SYNC", "0")
# Interval is configured in minutes (default 720 = 12 hours) and floored at 5 minutes.
_sync_interval_minutes = int(os.getenv("RECS_SYNC_INTERVAL_MINUTES", "720"))
AUTO_SYNC_INTERVAL_SEC = max(300, _sync_interval_minutes * 60)
AUTO_SYNC_PRODUCTS = _env_flag("RECS_SYNC_PRODUCTS", "1")
AUTO_SYNC_CUSTOMERS = _env_flag("RECS_SYNC_CUSTOMERS", "1")
AUTO_SYNC_PAGE_SIZE = int(os.getenv("RECS_SYNC_PAGE_SIZE", "250"))
AUTO_SYNC_MAX_PAGES = _env_optional_int("RECS_SYNC_MAX_PAGES")
AUTO_SYNC_CUSTOMER_BATCH_SIZE = int(os.getenv("RECS_SYNC_CUSTOMER_BATCH_SIZE", "250"))
AUTO_SYNC_CUSTOMER_MAX_BATCHES = _env_optional_int("RECS_SYNC_CUSTOMER_MAX_BATCHES")
INITIAL_SYNC_ENABLED = _env_flag("RECS_SYNC_ON_START", "1")

# RL env (DB mode). Every tunable is read from the environment once, at import time.
_ENV_SETTINGS: Dict[str, Any] = {
    "data_source": {"db_path": DB_PATH},
    "max_arms": int(os.getenv("RECS_MAX_ARMS", "50")),
    "max_steps": int(os.getenv("RECS_MAX_STEPS", "100")),
    "n_suggestions": int(os.getenv("RECS_N_SUGGESTIONS", "3")),
    "seed": None,
    "use_clustering": bool(int(os.getenv("RECS_USE_CLUSTERING", "0"))),
    "use_user_context": bool(int(os.getenv("RECS_USE_USER_CONTEXT", "0"))),
    "use_price_optimization": bool(int(os.getenv("RECS_USE_PRICE_OPT", "0"))),
}
_ENV = UnifiedRecommendationEnv(**_ENV_SETTINGS)

# Baseline hyperparameters; also the starting point whenever the algorithm is rebuilt.
_DEFAULT_ALGO_PARAMS: Dict[str, Any] = {"epsilon": 0.1}

# Initialize MABSimulation with the environment - THIS IS CRUCIAL
# (the simulation gets its own copy of the defaults, not the shared dict).
_SIMULATION = MABSimulation(
    environment=_ENV,
    algorithm_type="epsilon_greedy",
    algorithm_params=dict(_DEFAULT_ALGO_PARAMS),
)

# Product-id <-> arm-index lookup tables, rebuilt under _MAPPING_LOCK.
_PRODUCT_ID_TO_ARM: Dict[int, int] = {}
_ARM_TO_PRODUCT_ID: List[int] = []
_MAPPING_LOCK = Lock()

def _refresh_product_arm_mapping(force: bool = False) -> None:
    """
    Ensure the product-id<->arm mappings are up to date.

    Rebuilds ``_ARM_TO_PRODUCT_ID`` (arm index -> product id) and
    ``_PRODUCT_ID_TO_ARM`` (product id -> arm index) from the database, using at
    most ``_ENV.max_arms`` products. The rebuild is skipped when a mapping
    already exists, unless ``force`` is true.

    Args:
        force: Reload from the database even if a non-empty mapping is cached.

    If the product lookup fails the mappings are reset to empty (best effort);
    the failure is printed so it can be told apart from an empty catalogue.
    """
    global _PRODUCT_ID_TO_ARM, _ARM_TO_PRODUCT_ID

    with _MAPPING_LOCK:
        if _PRODUCT_ID_TO_ARM and not force:
            return

        try:
            products = data_controller.get_all_products(limit=_ENV.max_arms)
            product_ids = [int(p["product_id"]) for p in products if "product_id" in p]
        except Exception as exc:
            # Still best-effort, but no longer silent.
            print(f"⚠️  Failed to load products for arm mapping: {exc}")
            product_ids = []

        # Drop repeated ids (first occurrence wins) so that each product owns
        # exactly one arm and the two lookup tables stay exact inverses.
        unique_ids = list(dict.fromkeys(product_ids))

        _ARM_TO_PRODUCT_ID = unique_ids[: _ENV.max_arms]
        _PRODUCT_ID_TO_ARM = {pid: idx for idx, pid in enumerate(_ARM_TO_PRODUCT_ID)}

def _product_id_to_arm(pid: int) -> Optional[int]:
    """
    Translate a product id into its bandit arm index.

    A miss triggers one forced reload of the mapping before giving up, so the
    result is None only if the product is still unmapped after that reload.
    """
    try:
        return _PRODUCT_ID_TO_ARM[pid]
    except KeyError:
        _refresh_product_arm_mapping(force=True)
        return _PRODUCT_ID_TO_ARM.get(pid)

def _arm_to_product_id(arm_idx: int) -> Optional[int]:
    """
    Translate a bandit arm index into the product id it stands for.

    The mapping is loaded on demand when it is empty. Indices outside
    ``[0, len(mapping))`` yield None.
    """
    if not _ARM_TO_PRODUCT_ID:
        _refresh_product_arm_mapping()
    within_bounds = 0 <= arm_idx < len(_ARM_TO_PRODUCT_ID)
    return _ARM_TO_PRODUCT_ID[arm_idx] if within_bounds else None

def _available_products_from_db(limit: int) -> List[int]:
    """Force a reload of the arm mapping and return its first ``limit`` product ids."""
    _refresh_product_arm_mapping(force=True)
    mapped_ids = _ARM_TO_PRODUCT_ID
    return mapped_ids[:limit]


def _as_content(payload: Dict[str, Any]) -> List[TextContent]:
    """Serialize ``payload`` to JSON and return it as a one-element TextContent list."""
    body = json.dumps(payload)
    return [TextContent(type="text", text=body)]


def _ensure_algorithm_matches_inventory(
    algo_type: str, algo_params: Dict[str, Any], arm_candidates: List[int]
) -> None:
    """
    Align the env and the active algorithm with the current number of arms.

    The target arm count is ``len(arm_candidates)``, floored at 1. The env's
    ``max_arms`` is overwritten when it differs, and the algorithm is rebuilt
    via ``_SIMULATION.set_algorithm`` when any of these hold: there is no
    active algorithm, its ``n_arms`` differs from the target, the requested
    type differs from the simulation's, or the caller supplied parameters.

    Parameters for a rebuild are layered: module defaults, then the
    simulation's current params, then the caller's overrides.

    NOTE(review): ``_ENV.max_arms`` is also the limit used when product ids are
    loaded from the DB, so lowering it here appears to cap later catalogue
    loads at the current size — confirm whether the arm space is meant to be
    able to grow again after new products are synced.
    """
    wanted_arms = max(1, len(arm_candidates))

    # Keep env in sync so downstream mappings honor the current inventory size.
    if getattr(_ENV, "max_arms", None) != wanted_arms:
        _ENV.max_arms = wanted_arms

    active = _SIMULATION.current_algorithm
    arm_space_changed = (not active) or getattr(active, "n_arms", None) != wanted_arms

    # Later layers win: defaults < simulation's current params < caller overrides.
    effective_params = dict(_DEFAULT_ALGO_PARAMS)
    for layer in (_SIMULATION.algorithm_params, algo_params):
        if layer:
            effective_params.update(layer)

    nothing_changed = (
        not arm_space_changed
        and algo_type == _SIMULATION.algorithm_type
        and not algo_params
    )
    if nothing_changed:
        return

    _SIMULATION.set_algorithm(algo_type, effective_params)


def _initial_sync_if_needed() -> None:
    """
    Run a single startup sync when it is enabled and no products are stored yet.

    A failing product count is treated as "zero products", so the sync is
    attempted in that case too. Sync errors are printed, never raised.
    """
    if not INITIAL_SYNC_ENABLED:
        return

    try:
        existing_products = data_controller.get_product_count()
    except Exception:
        existing_products = 0
    if existing_products > 0:
        return

    # Connection settings: environment overrides first, packaged defaults otherwise.
    connection = {
        "mcp_url": os.getenv("MCP_URL", DEFAULT_MCP_URL),
        "auth_key": os.getenv("MCP_AUTH_KEY", DEFAULT_AUTH_KEY),
        "chatbot_key": os.getenv("MCP_CHATBOT_KEY", DEFAULT_CHATBOT_KEY),
    }

    try:
        syncer = MCPDataSync(db_path=DB_PATH, **connection)
        if AUTO_SYNC_PRODUCTS:
            syncer.sync_products(
                page_size=AUTO_SYNC_PAGE_SIZE,
                max_pages=AUTO_SYNC_MAX_PAGES,
                search=None,
                dry_run=False,
            )
        if AUTO_SYNC_CUSTOMERS:
            syncer.sync_customers(
                batch_size=AUTO_SYNC_CUSTOMER_BATCH_SIZE,
                max_batches=AUTO_SYNC_CUSTOMER_MAX_BATCHES,
            )
    except Exception as exc:
        print(f"⚠️  Initial data sync failed: {exc}")
    else:
        print("✅ Initial data sync completed (startup)")


def _start_background_sync() -> None:
    """
    Kick off background sync loop if enabled via env.

    Starts a daemon thread that, every ``AUTO_SYNC_INTERVAL_SEC`` seconds, syncs
    products and/or customers according to ``AUTO_SYNC_PRODUCTS`` and
    ``AUTO_SYNC_CUSTOMERS``. Does nothing when ``AUTO_SYNC_ENABLED`` is false.

    Failures inside the loop, including failure to construct the sync client,
    are printed and retried on the next cycle. Failure to start the thread is
    printed and otherwise ignored.
    """
    if not AUTO_SYNC_ENABLED:
        return

    mcp_url = os.getenv("MCP_URL", DEFAULT_MCP_URL)
    auth_key = os.getenv("MCP_AUTH_KEY", DEFAULT_AUTH_KEY)
    chatbot_key = os.getenv("MCP_CHATBOT_KEY", DEFAULT_CHATBOT_KEY)

    def _sync_loop():
        """Sync forever, sleeping between cycles; never lets an exception escape."""
        syncer = None
        while True:
            try:
                # Build the client lazily inside the guarded section: an error
                # raised by the constructor must not end the thread, otherwise
                # auto-sync would stop for good while still reported as enabled.
                if syncer is None:
                    syncer = MCPDataSync(
                        db_path=DB_PATH,
                        mcp_url=mcp_url,
                        auth_key=auth_key,
                        chatbot_key=chatbot_key,
                    )
                if AUTO_SYNC_PRODUCTS:
                    syncer.sync_products(
                        page_size=AUTO_SYNC_PAGE_SIZE,
                        max_pages=AUTO_SYNC_MAX_PAGES,
                        search=None,
                        dry_run=False,
                    )
                if AUTO_SYNC_CUSTOMERS:
                    syncer.sync_customers(
                        batch_size=AUTO_SYNC_CUSTOMER_BATCH_SIZE,
                        max_batches=AUTO_SYNC_CUSTOMER_MAX_BATCHES,
                    )
            except Exception as exc:
                print(f"⚠️  Auto-sync failed: {exc}")
            time.sleep(AUTO_SYNC_INTERVAL_SEC)

    try:
        # Daemon thread: it must not keep the process alive at shutdown.
        t = Thread(target=_sync_loop, name="recs-auto-sync", daemon=True)
        t.start()
        print(
            f"🔄 Auto-sync enabled: products={AUTO_SYNC_PRODUCTS}, customers={AUTO_SYNC_CUSTOMERS}, "
            f"interval={AUTO_SYNC_INTERVAL_SEC}s"
        )
    except Exception as exc:
        print(f"⚠️  Failed to start auto-sync thread: {exc}")


# Import-time data bootstrap, in order:
#   1. one-off initial sync (returns early when disabled or when the product
#      table already has rows), then
#   2. the periodic background sync thread (no-op if disabled).
_initial_sync_if_needed()
_start_background_sync()

# ----------------------------
# Models (schemas auto-exposed in MCP)
# ----------------------------
class RecommendationRequest(BaseModel):
    # Input schema for the `get_recommendations` tool.
    # NOTE(review): documented with `#` comments on purpose — a class docstring
    # would presumably be picked up as the schema description; confirm before
    # converting these notes into a docstring.
    #
    # user_id            - forwarded to the env reset and to the agent call.
    # n_recommendations  - validated to 1..50 here; the handler additionally
    #                      caps it at the number of available arms.
    # algorithm_type     - lower-cased by the handler before use.
    # algorithm_params   - None is treated as "no overrides" by the handler.
    user_id: int = Field(description="Unique user id")
    n_recommendations: int = Field(default=3, ge=1, le=50, description="How many items to recommend")
    algorithm_type: str = Field(default="epsilon_greedy", description="Bandit / policy type")
    algorithm_params: Optional[Dict[str, Any]] = Field(default=None, description="Algorithm hyperparameters")

class FeedbackData(BaseModel):
    # Input schema for the `submit_feedback` tool.
    # NOTE(review): documented with `#` comments on purpose — a class docstring
    # would presumably be picked up as the schema description; confirm before
    # converting these notes into a docstring.
    #
    # user_id    - id of the user who produced the event.
    # item_id    - real product id; the handler maps it to a bandit arm index.
    # event_type - free-form string here (not validated against the listed
    #              values); the handler reports a reward of 0.0 for any value
    #              outside its reward table.
    # timestamp  - optional; parsed by the handler with pandas.Timestamp, and
    #              replaced by the current time when omitted.
    user_id: int
    item_id: int
    event_type: str = Field(description="One of: view, cart, purchase, remove_from_cart")
    timestamp: Optional[str] = None

# ----------------------------
# Tools
# ----------------------------

# The leading "@" matters: a bare `server.tool(...)` call builds the decorator
# and throws it away, leaving the coroutine below unregistered.
@server.tool(
    "status",
    description="Get the current status of the recommender system",
    structured_output=False,
)
async def status() -> List[TextContent]:
    """
    Get the current status of the recommender system.

    Returns:
        A one-element TextContent list whose text is a JSON object with
        ``status``, ``message`` and ``tools`` (a static catalogue of the tools
        this module registers, by their registered names).
    """
    return _as_content({
        "status": "success",
        "message": "Recommender system is running",

        "tools": [
            {
                # Advertised under the name it is registered with, so a client
                # can call exactly what the catalogue lists.
                "name": "status",
                "description": "Get the current status of the recommender system",
                "structured_output": False
            },
            {
                "name": "submit_feedback",
                "description": "Handles the MAB reward by feedback and updates the recommender simulation",
                "structured_output": False
            },
            {
                "name": "get_recommendations",
                "description": "Get personalized recommendations using MABSimulation recommender class",
                "structured_output": False
            },
        ]
    })

@server.tool(
    "submit_feedback",
    description="Handles the MAB reward by feedback and updates the recommender simulation",
    structured_output=False,
)
async def submit_feedback(feedback: FeedbackData) -> List[TextContent]:
    """
    Record a user interaction event (DB + MABSimulation update).

    Steps, in order: map the product id to a bandit arm, parse the event
    timestamp, persist the feedback row, feed the event to the simulation, then
    report the arm's updated statistics.

    Args:
        feedback: The interaction event; ``item_id`` is a real product id.

    Returns:
        A one-element TextContent list with a JSON object. On success it holds
        ``reward`` (from the local reward table, 0.0 for unknown event types),
        ``n`` (pull count of the arm) and ``q`` (mean reward of the arm). On
        any failure it holds ``status: "error"`` and a message; exceptions are
        never propagated to the caller.
    """
    try:
        pid = int(feedback.item_id)
        uid = int(feedback.user_id)

        arm_index = _product_id_to_arm(pid)
        if arm_index is None:
            return _as_content({"status": "error", "message": f"Product {pid} is not available as a bandit arm"})

        # Resolve the event time BEFORE any side effect. A malformed client
        # timestamp raises here and is reported by the outer handler, instead of
        # failing after the row was already written (which would leave the DB and
        # the bandit state out of step while the caller is told the call failed).
        if feedback.timestamp is None:
            event_time = pd.Timestamp.now()
        else:
            event_time = pd.Timestamp(feedback.timestamp)

        # Persist feedback in DB first
        fb = FeedbackDB(
            user_id=uid,
            product_id=pid,
            event_type=feedback.event_type,
            session=datetime.utcnow(),
            active_user=1,
        )
        ok = data_controller.add_feedback(fb)

        if not ok:
            return _as_content({"status": "error", "message": "Failed to save feedback to database"})

        # The simulation works on arm indices, so 'item_id' carries the arm
        # index here rather than the product id.
        event_df = pd.DataFrame([{
            'user_id': uid,
            'item_id': arm_index,
            'event_type': feedback.event_type,
            'timestamp': event_time
        }])

        # Update the simulation's algorithm with the feedback
        _SIMULATION.update_algorithm(event_df)

        # Reward reported back to the caller (informational; the simulation
        # derives its own update from the event frame above).
        reward_mapping = {"purchase": 1.0, "cart": 0.33, "view": 0.1, "remove_from_cart": -0.5}
        reward = reward_mapping.get(feedback.event_type, 0.0)

        # Per-arm stats are optional: only read them when the algorithm exposes
        # both accumulators and the arm index is within range.
        n = 0
        q = 0.0
        algorithm = _SIMULATION.current_algorithm
        if algorithm and hasattr(algorithm, 'arm_counts') and hasattr(algorithm, 'arm_rewards'):
            if arm_index < len(algorithm.arm_counts):
                n = int(algorithm.arm_counts[arm_index])
                if n > 0:
                    q = float(algorithm.arm_rewards[arm_index] / n)

        return _as_content({
            "status": "success",
            "item_id": pid,
            "user_id": uid,
            "reward": reward,
            "n": n,
            "q": q,
            "message": "Feedback processed and MABSimulation updated"
        })
    except Exception as e:
        return _as_content({"status": "error", "message": f"submit_feedback failed: {e}"})

@server.tool(
    "get_recommendations",
    description="Get personalized recommendations using MABSimulation recommender class",
    structured_output=False,
)
async def get_recommendations(request: RecommendationRequest) -> List[TextContent]:
    """
    Get personalized product recommendations using the MABSimulation recommender.

    Steps, in order: reset the env for the user, reload the product/arm
    mapping, align the algorithm with the current arm count, ask the simulation
    for arms, translate arms back to product ids, attach product details and a
    per-item confidence score.

    Args:
        request: User id, number of items wanted, and optional algorithm
            type/params.

    Returns:
        A one-element TextContent list with a JSON object. On success it holds
        ``recommendations`` (product ids), ``products`` (detail dicts, or a
        ``{"product_id", "missing": True}`` stub), and ``confidence_scores``
        (mean arm reward clamped to [0.1, 1.0]; 0.5 when the arm is unknown,
        0.7 when the algorithm exposes no statistics). On any failure it holds
        ``status: "error"`` and a message; exceptions are never propagated.
    """
    try:
        # Reset environment for the user; fall back to a plain reset when the
        # env does not accept the per-user options.
        try:
            obs, _info = _ENV.reset(options={"user_id": request.user_id})
        except Exception:
            obs, _info = _ENV.reset()

        # Get available products from database
        product_candidates = _available_products_from_db(limit=_ENV.max_arms)
        if not product_candidates:
            return _as_content({"status": "error", "message": "No available products in DB"})

        arm_candidates = [
            _product_id_to_arm(pid) for pid in product_candidates
        ]
        arm_candidates = [arm for arm in arm_candidates if arm is not None]
        if not arm_candidates:
            return _as_content({"status": "error", "message": "No valid bandit arms mapped from products"})

        # Set algorithm if different from current and ensure arm space matches inventory
        algo_type = (request.algorithm_type or "epsilon_greedy").lower()
        algo_params = request.algorithm_params or {}
        _ensure_algorithm_matches_inventory(algo_type, algo_params, arm_candidates)

        # Never ask for more items than there are arms.
        n_recommendations = max(1, min(int(request.n_recommendations), len(arm_candidates)))

        # Get recommendations using the simulation's invoke_agent method
        recommended_arms = _SIMULATION.invoke_agent(
            algorithm=_SIMULATION.current_algorithm,
            user_id=request.user_id,
            n_recommendations=n_recommendations,
            input_arms_idx=arm_candidates,
            obs=obs
        )

        # Convert arm indices back to product IDs, dropping unmapped arms.
        recommendations = [
            _arm_to_product_id(arm) for arm in recommended_arms
        ]
        recommendations = [pid for pid in recommendations if pid is not None]
        if not recommendations:
            return _as_content({"status": "error", "message": "Algorithm returned no valid product recommendations"})

        # Fetch product details for the client to render
        product_details: List[Dict[str, Any]] = []
        for pid in recommendations:
            detail = data_controller.get_product(pid)
            if detail:
                product_details.append(detail)
            else:
                product_details.append({"product_id": pid, "missing": True})

        # Calculate confidence scores from algorithm state
        confidence_scores = []
        algorithm = _SIMULATION.current_algorithm
        if algorithm and hasattr(algorithm, 'arm_rewards') and hasattr(algorithm, 'arm_counts'):
            # Coerce both accumulators to float arrays first: integer reward
            # totals would make the float quotient uncastable into an integer
            # `out` buffer, and plain lists have no element-wise `!= 0`.
            reward_totals = np.asarray(algorithm.arm_rewards, dtype=float)
            pull_counts = np.asarray(algorithm.arm_counts, dtype=float)
            # `where` skips never-pulled arms, which keep the 0.0 from `out`.
            avg_rewards = np.divide(
                reward_totals,
                pull_counts,
                out=np.zeros_like(reward_totals),
                where=pull_counts != 0
            )
            for pid in recommendations:
                arm_idx = _product_id_to_arm(pid)
                if arm_idx is not None and arm_idx < len(avg_rewards):
                    confidence = max(0.1, min(1.0, float(avg_rewards[arm_idx])))
                    confidence_scores.append(confidence)
                else:
                    confidence_scores.append(0.5)
        else:
            confidence_scores = [0.7] * len(recommendations)

        return _as_content({
            "status": "success",
            "user_id": request.user_id,
            # Always a plain list of product ids at this point.
            "recommendations": recommendations,
            "products": product_details,
            "algorithm_type": _SIMULATION.algorithm_type,
            "algorithm_class": _SIMULATION.current_algorithm.__class__.__name__ if _SIMULATION.current_algorithm else None,
            "confidence_scores": confidence_scores,
            "total_recommendations": len(recommendations),
            "message": "Recommendations generated using MABSimulation"
        })
    except Exception as e:
        return _as_content({"status": "error", "message": f"get_recommendations failed: {e}"})

# ----------------------------
# HTTP App Export (for uvicorn/Spaces)
# ----------------------------

def http_app():
    """
    Return the Flask app captured from the FastMCP HTTP server so an external
    server process can host the transport.

    Raises:
        RuntimeError: if the server object exposed no Flask app at import time.
    """
    app = _FLASK_APP
    if app is not None:
        return app
    raise RuntimeError("FastMCPHttpServer did not expose a Flask app")

# ----------------------------
# Main (HTTP transport)
# ----------------------------

if __name__ == "__main__":
    # Direct execution: serve over HTTP on all interfaces at the configured port.
    print(f"🚀 Starting MCP HTTP server on :{MCP_SERVER_PORT}")
    server.run_http(
        port=MCP_SERVER_PORT,
        host="0.0.0.0",
        register_server=False,
    )
