"""
Data Controller for the MAB Recommendation System.
Handles all ORM queries and database operations.
"""

import sqlite3
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple
from .models import (
    FeedbackDB, ProductCatalog, UserDemographics, 
    SearchHistory, DatabaseManager
)


class DataController:
    """Handles all database operations and ORM queries.

    Every public method opens a connection through ``self.db_manager`` (used as
    a context manager), runs one or more parameterized SQL statements and
    converts result rows into plain dicts. Methods are best-effort: any
    exception is printed with a ``❌`` prefix and a neutral fallback value is
    returned (``False``, ``-1``, ``None``, ``[]``, ``{}`` or ``0``) instead of
    being raised.

    NOTE(review): ``dict(row)`` and ``row['col']`` assume ``DatabaseManager``
    configures a mapping-style ``row_factory`` such as ``sqlite3.Row``; that
    setup lives in ``.models`` and is not visible here.
    """

    def __init__(self, db_path: str = "recommender_system.db"):
        """Create the database manager for ``db_path`` and initialise the schema.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_manager = DatabaseManager(db_path)
        self.db_manager.init_database()

    # ==================== INTERNAL HELPERS ====================

    def _fetch_all(
        self, sql: str, params: Tuple[Any, ...], action: str
    ) -> List[Dict[str, Any]]:
        """Run a read-only query and return every result row as a dict.

        Args:
            sql: Parameterized SQL statement.
            params: Values bound to the ``?`` placeholders.
            action: Phrase used in the error message, e.g.
                ``"getting user feedback"``.

        Returns:
            One dict per row, or ``[]`` if the query fails.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute(sql, params)
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            print(f"❌ Error {action}: {e}")
            return []

    def _fetch_one(
        self, sql: str, params: Tuple[Any, ...], action: str
    ) -> Optional[Dict[str, Any]]:
        """Run a read-only query and return its first row as a dict.

        Args:
            sql: Parameterized SQL statement.
            params: Values bound to the ``?`` placeholders.
            action: Phrase used in the error message.

        Returns:
            The first row as a dict, or ``None`` when nothing matched or the
            query failed.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute(sql, params)
                row = cursor.fetchone()
                return dict(row) if row else None
        except Exception as e:
            print(f"❌ Error {action}: {e}")
            return None

    def _count(self, sql: str, action: str) -> int:
        """Return the ``count`` column of a ``SELECT COUNT(*) as count`` query, or 0 on failure."""
        row = self._fetch_one(sql, (), action)
        return row['count'] if row else 0

    # ==================== FEEDBACK OPERATIONS ====================

    def add_feedback(self, feedback: FeedbackDB) -> bool:
        """Insert one feedback event into ``feedback_db``.

        Args:
            feedback: Record supplying ``user_id``, ``product_id``,
                ``event_type``, ``session`` and ``active_user``.

        Returns:
            ``True`` once the insert is committed, ``False`` on any error.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO feedback_db (user_id, product_id, event_type, session, active_user)
                    VALUES (?, ?, ?, ?, ?)
                ''', (
                    feedback.user_id, feedback.product_id, feedback.event_type,
                    feedback.session, feedback.active_user
                ))
                conn.commit()
                # Connection lifetime is owned by ``db_manager``'s context exit,
                # as in every other method; closing it here would hand a closed
                # connection to the manager's exit logic.
                return True
        except Exception as e:
            print(f"❌ Error adding feedback: {e}")
            return False

    def get_user_feedback(self, user_id: int, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` feedback rows for ``user_id``, newest session first.

        Each row holds the feedback columns plus the product's name, brand,
        category and price. The inner join drops feedback whose product is not
        in ``product_catalog``.
        """
        return self._fetch_all('''
            SELECT f.*, p.product_name, p.brand_name, p.category, p.price
            FROM feedback_db f
            JOIN product_catalog p ON f.product_id = p.product_id
            WHERE f.user_id = ?
            ORDER BY f.session DESC
            LIMIT ?
        ''', (user_id, limit), "getting user feedback")

    def get_product_feedback(self, product_id: str, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` feedback rows for ``product_id``, newest session first.

        Each row holds the feedback columns plus the user's name as
        ``user_name``. The inner join drops feedback from users missing in
        ``user_demographics``.
        """
        return self._fetch_all('''
            SELECT f.*, u.name as user_name
            FROM feedback_db f
            JOIN user_demographics u ON f.user_id = u.user_id
            WHERE f.product_id = ?
            ORDER BY f.session DESC
            LIMIT ?
        ''', (product_id, limit), "getting product feedback")

    def get_recent_feedback(self, hours: int = 24) -> List[Dict[str, Any]]:
        """Return feedback whose ``session`` falls within the last ``hours`` hours.

        Rows include the product name and user name and are ordered newest
        first.
        """
        try:
            cutoff_time = datetime.now() - timedelta(hours=hours)
            # Bind the same "YYYY-MM-DD HH:MM:SS.ffffff" text that sqlite3's
            # implicit datetime adapter produced; that adapter is deprecated
            # since Python 3.12.
            # NOTE(review): datetime.now() is local time, while
            # get_user_behavior_summary compares against SQLite's UTC
            # datetime('now'); confirm which clock ``session`` values use.
            cutoff_text = cutoff_time.isoformat(" ")
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    SELECT f.*, p.product_name, u.name as user_name
                    FROM feedback_db f
                    JOIN product_catalog p ON f.product_id = p.product_id
                    JOIN user_demographics u ON f.user_id = u.user_id
                    WHERE f.session >= ?
                    ORDER BY f.session DESC
                ''', (cutoff_text,))
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            print(f"❌ Error getting recent feedback: {e}")
            return []

    # ==================== PRODUCT OPERATIONS ====================

    def add_product(self, product: ProductCatalog) -> bool:
        """Insert a product, replacing any existing row with the same key.

        Uses ``INSERT OR REPLACE``, so re-adding a product overwrites it.

        Returns:
            ``True`` once committed, ``False`` on any error.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT OR REPLACE INTO product_catalog 
                    (product_id, product_name, brand_name, category, price, description, image_url)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                ''', (
                    product.product_id, product.product_name, product.brand_name,
                    product.category, product.price, product.description, product.image_url
                ))
                conn.commit()
                return True
        except Exception as e:
            print(f"❌ Error adding product: {e}")
            return False

    def get_product(self, product_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``product_catalog`` row for ``product_id`` as a dict, or ``None``."""
        return self._fetch_one(
            'SELECT * FROM product_catalog WHERE product_id = ?',
            (product_id,),
            "getting product",
        )

    def get_products_by_category(self, category: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to ``limit`` products in ``category``, ordered by ``product_id``."""
        return self._fetch_all('''
            SELECT * FROM product_catalog 
            WHERE category = ? 
            ORDER BY product_id 
            LIMIT ?
        ''', (category, limit), "getting products by category")

    def get_products_by_brand(self, brand: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to ``limit`` products whose ``brand_name`` equals ``brand``, ordered by ``product_id``."""
        return self._fetch_all('''
            SELECT * FROM product_catalog 
            WHERE brand_name = ? 
            ORDER BY product_id 
            LIMIT ?
        ''', (brand, limit), "getting products by brand")

    def get_all_products(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return the first ``limit`` products ordered by ``product_id``."""
        return self._fetch_all('''
            SELECT * FROM product_catalog 
            ORDER BY product_id 
            LIMIT ?
        ''', (limit,), "getting all products")

    # ==================== USER OPERATIONS ====================

    def add_user(self, user: UserDemographics) -> int:
        """Insert a user and return the new ``user_id``.

        Returns:
            The row id assigned by SQLite (``cursor.lastrowid``), or ``-1`` on
            any error.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO user_demographics (name, gender, age, user_cluster_id)
                    VALUES (?, ?, ?, ?)
                ''', (user.name, user.gender, user.age, user.user_cluster_id))

                user_id = cursor.lastrowid
                conn.commit()
                return user_id
        except Exception as e:
            print(f"❌ Error adding user: {e}")
            return -1

    def get_user(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Return the ``user_demographics`` row for ``user_id`` as a dict, or ``None``."""
        return self._fetch_one(
            'SELECT * FROM user_demographics WHERE user_id = ?',
            (user_id,),
            "getting user",
        )

    def get_user_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """Lookup a user by name (used when importing external customers).

        Returns:
            The first matching row as a dict, or ``None`` if no user has that
            exact name.
        """
        return self._fetch_one(
            'SELECT * FROM user_demographics WHERE name = ?',
            (name,),
            "getting user by name",
        )

    def update_user_cluster(self, user_id: int, cluster_id: int) -> bool:
        """Set a user's cluster and refresh ``last_active`` to ``CURRENT_TIMESTAMP``.

        Returns:
            ``True`` if a user row was updated, ``False`` if no user has
            ``user_id`` or an error occurred.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    UPDATE user_demographics 
                    SET user_cluster_id = ?, last_active = CURRENT_TIMESTAMP
                    WHERE user_id = ?
                ''', (cluster_id, user_id))
                conn.commit()
                # An UPDATE that matches nothing is not a success.
                return cursor.rowcount > 0
        except Exception as e:
            print(f"❌ Error updating user cluster: {e}")
            return False

    def get_users_by_cluster(self, cluster_id: int) -> List[Dict[str, Any]]:
        """Return every user assigned to ``cluster_id``, ordered by ``user_id``."""
        return self._fetch_all('''
            SELECT * FROM user_demographics 
            WHERE user_cluster_id = ?
            ORDER BY user_id
        ''', (cluster_id,), "getting users by cluster")

    # ==================== SEARCH HISTORY OPERATIONS ====================

    def add_search_history(self, search: SearchHistory) -> bool:
        """Insert one search-history entry.

        Returns:
            ``True`` once committed, ``False`` on any error.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO search_history (user_id, query, product_id, session_id)
                    VALUES (?, ?, ?, ?)
                ''', (search.user_id, search.query, search.product_id, search.session_id))
                conn.commit()
                return True
        except Exception as e:
            print(f"❌ Error adding search history: {e}")
            return False

    def get_user_search_history(self, user_id: int, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to ``limit`` search entries for ``user_id``, newest ``timestamp`` first."""
        return self._fetch_all('''
            SELECT * FROM search_history 
            WHERE user_id = ? 
            ORDER BY timestamp DESC 
            LIMIT ?
        ''', (user_id, limit), "getting user search history")

    # ==================== ANALYTICS QUERIES ====================

    def get_user_behavior_summary(self, user_id: int) -> Dict[str, Any]:
        """Summarise one user's activity.

        Returns:
            A dict with keys:
            ``user_info``: the ``user_demographics`` row as a dict, or ``None``.
            ``event_counts``: mapping of ``event_type`` to number of events.
            ``recent_activity``: count of feedback rows whose ``session`` is on
            or after SQLite's ``datetime('now', '-7 days')``, which is UTC.
            ``favorite_categories``: up to 5 ``{'category', 'count'}`` dicts,
            most frequent first.
            ``{}`` is returned if any query fails.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()

                cursor.execute('SELECT * FROM user_demographics WHERE user_id = ?', (user_id,))
                user_info = cursor.fetchone()

                cursor.execute('''
                    SELECT event_type, COUNT(*) as count
                    FROM feedback_db 
                    WHERE user_id = ?
                    GROUP BY event_type
                ''', (user_id,))
                event_counts = {row['event_type']: row['count'] for row in cursor.fetchall()}

                cursor.execute('''
                    SELECT COUNT(*) as recent_count
                    FROM feedback_db 
                    WHERE user_id = ? AND session >= datetime('now', '-7 days')
                ''', (user_id,))
                # COUNT(*) without GROUP BY always yields exactly one row.
                recent_activity = cursor.fetchone()['recent_count']

                cursor.execute('''
                    SELECT p.category, COUNT(*) as count
                    FROM feedback_db f
                    JOIN product_catalog p ON f.product_id = p.product_id
                    WHERE f.user_id = ?
                    GROUP BY p.category
                    ORDER BY count DESC
                    LIMIT 5
                ''', (user_id,))
                favorite_categories = [dict(row) for row in cursor.fetchall()]

                return {
                    'user_info': dict(user_info) if user_info else None,
                    'event_counts': event_counts,
                    'recent_activity': recent_activity,
                    'favorite_categories': favorite_categories
                }
        except Exception as e:
            print(f"❌ Error getting user behavior summary: {e}")
            return {}

    def get_product_performance(self, product_id: str) -> Dict[str, Any]:
        """Summarise engagement with one product.

        Returns:
            A dict with keys:
            ``product_info``: the ``product_catalog`` row as a dict, or ``None``.
            ``event_counts``: mapping of ``event_type`` to number of events.
            ``unique_users``: number of distinct users with feedback on it.
            ``{}`` is returned if any query fails.
        """
        try:
            with self.db_manager as conn:
                cursor = conn.cursor()

                cursor.execute('SELECT * FROM product_catalog WHERE product_id = ?', (product_id,))
                product_info = cursor.fetchone()

                cursor.execute('''
                    SELECT event_type, COUNT(*) as count
                    FROM feedback_db 
                    WHERE product_id = ?
                    GROUP BY event_type
                ''', (product_id,))
                event_counts = {row['event_type']: row['count'] for row in cursor.fetchall()}

                cursor.execute('''
                    SELECT COUNT(DISTINCT user_id) as unique_users
                    FROM feedback_db 
                    WHERE product_id = ?
                ''', (product_id,))
                unique_users = cursor.fetchone()['unique_users']

                return {
                    'product_info': dict(product_info) if product_info else None,
                    'event_counts': event_counts,
                    'unique_users': unique_users
                }
        except Exception as e:
            print(f"❌ Error getting product performance: {e}")
            return {}

    def get_all_users(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return the first ``limit`` users ordered by ascending ``user_id``."""
        return self._fetch_all('''
            SELECT * FROM user_demographics 
            ORDER BY user_id ASC
            LIMIT ?
        ''', (limit,), "getting all users")

    def get_total_feedback(self) -> int:
        """Return the number of rows in ``feedback_db`` (0 on error)."""
        return self._count(
            'SELECT COUNT(*) as count FROM feedback_db',
            "getting total feedback count",
        )

    def get_user_count(self) -> int:
        """Return the number of rows in ``user_demographics`` (0 on error)."""
        return self._count(
            'SELECT COUNT(*) as count FROM user_demographics',
            "getting user count",
        )

    def get_product_count(self) -> int:
        """Return the number of rows in ``product_catalog`` (0 on error)."""
        return self._count(
            'SELECT COUNT(*) as count FROM product_catalog',
            "getting product count",
        )
