"""
Data Manager for the MAB Recommendation System.
Manages the data needed by the recommender and the client, and helps set up users in the engine.
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple
from .data_controller import DataController
from .models import FeedbackDB, ProductCatalog, UserDemographics, SearchHistory


class DataManager:
    """Manages data operations and user setup for the recommendation engine."""
    
    def __init__(self, db_path: str = "recommender_system.db"):
        """Create the manager and its in-memory bookkeeping.

        Args:
            db_path: Path passed straight through to ``DataController``.
        """
        self.data_controller = DataController(db_path)
        # user_id -> list of session records, in creation order.
        self.user_sessions: Dict[int, List[Dict[str, Any]]] = {}
        # user_id -> state dict handed to the RL engine.
        self.user_states: Dict[int, Dict[str, Any]] = {}
    
    # ==================== USER SESSION MANAGEMENT ====================
    
    def create_user_session(self, user_id: int, session_id: int = None) -> int:
        """Create a new user session for tracking interactions."""
        if session_id is None:
            session_id = int(datetime.now().timestamp())
        
        if user_id not in self.user_sessions:
            self.user_sessions[user_id] = []
        
        session_data = {
            'session_id': session_id,
            'start_time': datetime.now(),
            'interactions': [],
            'current_state': 'active'
        }
        
        self.user_sessions[user_id].append(session_data)
        return session_id
    
    def get_active_session(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Return the latest session of a user that is still marked active.

        Returns:
            The last session record in the user's list whose
            ``'current_state'`` is ``'active'``, or ``None`` when the user is
            unknown, has no sessions, or has no active one.
        """
        history = self.user_sessions.get(user_id) or []
        # Scan from the end of the list so the most recently appended match wins.
        return next(
            (entry for entry in reversed(history) if entry['current_state'] == 'active'),
            None,
        )
    
    def end_user_session(self, user_id: int, session_id: int):
        """End a user session."""
        if user_id in self.user_sessions:
            for session in self.user_sessions[user_id]:
                if session['session_id'] == session_id:
                    session['current_state'] = 'ended'
                    session['end_time'] = datetime.now()
                    break
    
    # ==================== USER STATE MANAGEMENT FOR RL ====================
    
    def setup_user_in_engine(self, user_id: int) -> Dict[str, Any]:
        """Build and cache the engine-facing state for a user.

        Loads the user's profile (creating a default one when the lookup
        returns nothing), the behavior summary and up to 50 feedback rows from
        the data controller, derives preference, engagement and exploration
        features from them, and stores the result in ``self.user_states``.

        Args:
            user_id: Identifier of the user to set up.

        Returns:
            The newly built state dict, or ``{}`` if any step raised.
        """
        try:
            user_info = self.data_controller.get_user(user_id)
            if not user_info:
                # Unknown user: create a placeholder profile. Fall back to an
                # empty dict so the `.get()` calls below still work if the
                # freshly created row cannot be read back.
                user_info = self._create_default_user(user_id) or {}

            # A controller that reports "no data" as None must not abort the
            # whole setup; treat it as an empty summary / empty history.
            behavior_summary = self.data_controller.get_user_behavior_summary(user_id) or {}

            # Copy the rows: update_user_state() inserts into this list in
            # place and must not mutate a list owned by the data controller.
            recent_feedback = list(self.data_controller.get_user_feedback(user_id, limit=50) or [])

            user_state = {
                'user_id': user_id,
                'demographics': user_info,
                'behavior_summary': behavior_summary,
                'recent_interactions': recent_feedback,
                'current_session': self.get_active_session(user_id),
                'preferences': self._extract_user_preferences(recent_feedback),
                'cluster_id': user_info.get('user_cluster_id'),
                'last_active': user_info.get('last_active'),
                'engagement_level': self._calculate_engagement_level(behavior_summary),
                'exploration_score': self._calculate_exploration_score(recent_feedback)
            }

            self.user_states[user_id] = user_state
            return user_state

        except Exception as e:
            # Best-effort by design: callers treat an empty dict as "no state".
            print(f"❌ Error setting up user in engine: {e}")
            return {}
    
    def update_user_state(self, user_id: int, new_interaction: Dict[str, Any]):
        """Update user state with new interaction."""
        if user_id in self.user_states:
            user_state = self.user_states[user_id]
            
            # Add new interaction to recent interactions
            user_state['recent_interactions'].insert(0, new_interaction)
            
            # Keep only last 100 interactions
            if len(user_state['recent_interactions']) > 100:
                user_state['recent_interactions'] = user_state['recent_interactions'][:100]
            
            # Update preferences
            user_state['preferences'] = self._extract_user_preferences(user_state['recent_interactions'])
            
            # Update engagement level
            behavior_summary = self.data_controller.get_user_behavior_summary(user_id)
            user_state['engagement_level'] = self._calculate_engagement_level(behavior_summary)
            user_state['exploration_score'] = self._calculate_exploration_score(user_state['recent_interactions'])
            
            # Update last active
            user_state['last_active'] = datetime.now()
    
    def get_user_state(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Return the RL-engine state for a user, building it on first access.

        A state already present in ``self.user_states`` is returned as is;
        otherwise the call is delegated to ``setup_user_in_engine``.
        """
        if user_id not in self.user_states:
            # Nothing cached for this user yet.
            return self.setup_user_in_engine(user_id)
        return self.user_states[user_id]
    
    # ==================== DATA PREPARATION FOR RECOMMENDER ====================
    
    def get_recommendation_context(self, user_id: int, n_products: int = 10) -> Dict[str, Any]:
        """Assemble the context dict the recommender needs for one user.

        Args:
            user_id: User to build the context for.
            n_products: Maximum number of candidate products to include.

        Returns:
            A dict with the user's cluster, preferences, engagement level,
            exploration score, session context, the 10 most recent
            interactions, the candidate products and a ``time_context``
            (hour, day of week, weekend flag). ``{}`` is returned when no
            user state is available or any step raised.
        """
        try:
            user_state = self.get_user_state(user_id)
            if not user_state:
                return {}

            # Candidate products (arms for the MAB).
            available_products = self._get_available_products_for_user(user_id, n_products)

            # Single clock read: separate now() calls could straddle an hour or
            # midnight boundary and yield an inconsistent hour/day/weekend triple.
            now = datetime.now()
            weekday = now.weekday()

            return {
                'user_id': user_id,
                'user_cluster': user_state.get('cluster_id'),
                'preferences': user_state.get('preferences', {}),
                'engagement_level': user_state.get('engagement_level', 'low'),
                'exploration_score': user_state.get('exploration_score', 0.5),
                # The key can be present with a None value (no active session),
                # in which case a .get() default would never apply.
                'session_context': user_state.get('current_session') or {},
                'recent_interactions': (user_state.get('recent_interactions') or [])[:10],
                'available_products': available_products,
                'time_context': {
                    'hour': now.hour,
                    'day_of_week': weekday,
                    'is_weekend': weekday >= 5
                }
            }

        except Exception as e:
            # Best-effort by design: callers treat an empty dict as "no context".
            print(f"❌ Error getting recommendation context: {e}")
            return {}
    
    def get_training_data_for_rl(self, user_id: int, limit: int = 1000) -> pd.DataFrame:
        """Return a user's feedback history as a feature-enriched DataFrame.

        Args:
            user_id: User whose feedback is loaded.
            limit: Maximum number of feedback rows requested from the data
                controller.

        Returns:
            The feedback rows sorted by ``timestamp`` with the derived columns
            ``timestamp``, ``hour``, ``day_of_week``, ``is_weekend``,
            ``price_range``, ``interaction_sequence`` and ``time_since_last``
            (seconds), plus ``user_cluster`` / ``engagement_level`` when a user
            state is available. An empty DataFrame is returned when there is
            no feedback or any step raised.
        """
        price_labels = ['very_low', 'low', 'medium', 'high', 'very_high']
        try:
            feedback_data = self.data_controller.get_user_feedback(user_id, limit=limit)

            if not feedback_data:
                return pd.DataFrame()

            df = pd.DataFrame(feedback_data)

            # NOTE(review): assumes `session` holds datetime-like values. If it
            # stores integer epoch seconds, pd.to_datetime would read them as
            # nanoseconds — confirm against the feedback schema.
            df['timestamp'] = pd.to_datetime(df['session'])
            df['hour'] = df['timestamp'].dt.hour
            df['day_of_week'] = df['timestamp'].dt.dayofweek
            df['is_weekend'] = df['day_of_week'].isin([5, 6])

            # Broadcast user-level features onto every row.
            user_state = self.get_user_state(user_id)
            if user_state:
                df['user_cluster'] = user_state.get('cluster_id')
                df['engagement_level'] = user_state.get('engagement_level', 'low')

            # Five equal-width price bins over this user's own price range.
            # pd.cut raises when there is no usable price at all, which would
            # otherwise discard the whole frame via the handler below; emit an
            # all-missing categorical column with the same labels instead.
            if 'price' in df.columns and df['price'].notna().any():
                df['price_range'] = pd.cut(df['price'], bins=5, labels=price_labels)
            else:
                df['price_range'] = pd.Categorical(
                    [None] * len(df), categories=price_labels, ordered=True
                )

            # Sequence features are defined on the chronological order.
            df = df.sort_values('timestamp')
            df['interaction_sequence'] = range(len(df))
            df['time_since_last'] = df['timestamp'].diff().dt.total_seconds()

            return df

        except Exception as e:
            # Best-effort by design: callers treat an empty frame as "no data".
            print(f"❌ Error getting training data for RL: {e}")
            return pd.DataFrame()
    
    # ==================== HELPER METHODS ====================
    
    def _create_default_user(self, user_id: int) -> Dict[str, Any]:
        """Insert a placeholder profile for ``user_id`` and return the stored row.

        The profile is named ``User_<id>`` and has no gender, age or cluster.

        Returns:
            Whatever ``data_controller.get_user`` yields for the id reported by
            ``add_user`` when that id is positive, otherwise ``{}``.
        """
        placeholder = UserDemographics(
            user_id=user_id,
            name=f"User_{user_id}",
            gender=None,
            age=None,
            user_cluster_id=None
        )

        # Look the row up under the id returned by add_user, not the requested one.
        stored_id = self.data_controller.add_user(placeholder)
        return self.data_controller.get_user(stored_id) if stored_id > 0 else {}
    
    def _extract_user_preferences(self, interactions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Summarise an interaction history into preference statistics.

        Args:
            interactions: Interaction dicts; the keys ``category``,
                ``brand_name``, ``price``, ``event_type`` and ``product_id``
                are read when present and truthy.

        Returns:
            ``{}`` for an empty history, otherwise a dict with:

            * ``favorite_categories`` / ``favorite_brands``: the five most
              frequent values mapped to their counts, most frequent first.
            * ``price_preference``: ``min`` / ``max`` / ``avg`` over positive
              prices, all ``0`` when no interaction carries one.
            * ``event_preferences`` / ``product_affinity``: full count maps.
        """
        from collections import Counter  # local import keeps the file's import block untouched

        if not interactions:
            return {}

        categories = Counter()
        brands = Counter()
        events = Counter()
        products = Counter()
        prices = []

        for interaction in interactions:
            category = interaction.get('category')
            if category:
                categories[category] += 1

            brand = interaction.get('brand_name')
            if brand:
                brands[brand] += 1

            # Missing, zero and negative prices carry no preference signal.
            price = interaction.get('price')
            if price and price > 0:
                prices.append(price)

            event_type = interaction.get('event_type')
            if event_type:
                events[event_type] += 1

            product_id = interaction.get('product_id')
            if product_id:
                products[product_id] += 1

        if prices:
            price_preference = {
                'min': min(prices),
                'max': max(prices),
                'avg': sum(prices) / len(prices),
            }
        else:
            # Previously `min` was left at float('inf') here, which is
            # inconsistent with max/avg and is not valid JSON when exported.
            price_preference = {'min': 0, 'max': 0, 'avg': 0}

        return {
            # most_common() keeps first-seen order among equal counts.
            'favorite_categories': dict(categories.most_common(5)),
            'favorite_brands': dict(brands.most_common(5)),
            'price_preference': price_preference,
            'event_preferences': dict(events),
            'product_affinity': dict(products),
        }
    
    def _calculate_engagement_level(self, behavior_summary: Dict[str, Any]) -> str:
        """Bucket a user's recent activity into an engagement level.

        Args:
            behavior_summary: Summary dict whose ``'recent_activity'`` count is
                read. A missing summary, a missing key or a ``None`` value is
                treated as zero activity.

        Returns:
            ``'high'`` for 20 or more, ``'medium'`` for 10 to 19, else ``'low'``.
        """
        # `or` guards: a None summary or a None count would otherwise raise on
        # `.get` / on the comparison below.
        recent_activity = (behavior_summary or {}).get('recent_activity') or 0

        if recent_activity >= 20:
            return 'high'
        if recent_activity >= 10:
            return 'medium'
        return 'low'
    
    def _calculate_exploration_score(self, interactions: List[Dict[str, Any]]) -> float:
        """Score how varied a user's interactions are, from 0 to 1.

        The score is the mean of two ratios: distinct categories per
        interaction and distinct brands per interaction. Higher means more
        exploratory.

        Args:
            interactions: Interaction dicts; ``category`` and ``brand_name``
                are read when present and truthy.

        Returns:
            ``0.5`` for an empty history, otherwise the diversity score
            clamped to ``[0.0, 1.0]``.
        """
        # The empty check also guarantees a non-zero divisor below.
        if not interactions:
            return 0.5

        categories = set()
        brands = set()
        for interaction in interactions:
            category = interaction.get('category')
            if category:
                categories.add(category)
            brand = interaction.get('brand_name')
            if brand:
                brands.add(brand)

        total_interactions = len(interactions)
        category_diversity = len(categories) / total_interactions
        brand_diversity = len(brands) / total_interactions

        exploration_score = (category_diversity + brand_diversity) / 2
        return min(1.0, max(0.0, exploration_score))
    
    def _get_available_products_for_user(self, user_id: int, n_products: int) -> List[Dict[str, Any]]:
        """Pick the candidate products (arms for the MAB) for one user.

        Up to 1000 products are fetched from the data controller and scored
        against the user's cached preferences:

        * ``0.3`` x the category's count in ``favorite_categories``
        * ``0.2`` x the brand's count in ``favorite_brands``
        * ``0.2`` when the price is within 20% of the preferred average

        Args:
            user_id: User whose preferences drive the scoring.
            n_products: Maximum number of products to return; values below
                zero are treated as zero.

        Returns:
            The highest-scoring products, best first (ties keep catalog
            order). ``[]`` when there are no products or any step raised.
        """
        try:
            # get_user_state() may hand back None / {} when setup failed; score
            # against empty preferences instead of failing.
            user_state = self.get_user_state(user_id) or {}
            preferences = user_state.get('preferences') or {}

            all_products = self.data_controller.get_all_products(limit=1000)
            if not all_products:
                return []

            # Loop-invariant lookups, hoisted out of the scoring loop.
            favorite_categories = preferences.get('favorite_categories') or {}
            favorite_brands = preferences.get('favorite_brands') or {}
            avg_price = (preferences.get('price_preference') or {}).get('avg') or 0

            scored_products = []
            for product in all_products:
                score = 0

                category = product.get('category')
                if category in favorite_categories:
                    score += favorite_categories[category] * 0.3

                brand = product.get('brand_name')
                if brand in favorite_brands:
                    score += favorite_brands[brand] * 0.2

                # A product without a price gets no price bonus. Previously a
                # None price raised TypeError and emptied the whole result.
                price = product.get('price')
                if avg_price > 0 and price is not None:
                    if abs(price - avg_price) / avg_price < 0.2:
                        score += 0.2

                scored_products.append((product, score))

            scored_products.sort(key=lambda pair: pair[1], reverse=True)
            # Clamp: a negative slice bound would drop items from the tail
            # instead of returning none.
            return [product for product, _ in scored_products[:max(n_products, 0)]]

        except Exception as e:
            # Best-effort by design: callers treat an empty list as "no arms".
            print(f"❌ Error getting available products: {e}")
            return []
    
    # ==================== DATA EXPORT AND UTILITIES ====================
    
    def export_user_data(self, user_id: int, format: str = 'json') -> Any:
        """Export user data in specified format."""
        try:
            user_state = self.get_user_state(user_id)
            if not user_state:
                return None
            
            if format.lower() == 'json':
                return user_state
            elif format.lower() == 'csv':
                # Convert to DataFrame and export
                df = self.get_training_data_for_rl(user_id)
                return df.to_csv(index=False)
            else:
                return user_state
                
        except Exception as e:
            print(f"❌ Error exporting user data: {e}")
            return None
    
    def get_system_statistics(self) -> Dict[str, Any]:
        """Return a small snapshot of system-wide counters.

        Returns:
            A dict with ``total_users`` (users with a cached state),
            ``active_sessions`` (number of users that have at least one
            session marked ``'active'``), ``total_products`` (size of the
            list returned by ``data_controller.get_all_products()``) and a
            fixed ``system_status`` of ``'operational'``. ``{}`` when any
            step raised.
        """
        try:
            # Placeholder implementation; extend with real metrics as needed.
            tracked_users = len(self.user_states)
            # Counts users, not individual sessions: each user contributes at
            # most one, however many of their sessions are active.
            users_with_active = sum(
                1
                for sessions in self.user_sessions.values()
                if any(entry['current_state'] == 'active' for entry in sessions)
            )
            catalog_size = len(self.data_controller.get_all_products())

            return {
                'total_users': tracked_users,
                'active_sessions': users_with_active,
                'total_products': catalog_size,
                'system_status': 'operational'
            }
        except Exception as e:
            print(f"❌ Error getting system statistics: {e}")
            return {}
