"""
Enhanced Multi-Armed Bandit Algorithms with Data and Clustering Integration.
These algorithms use the new data and clustering layers for improved recommendations.
"""

import random
from abc import ABC, abstractmethod
from collections import deque
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd

from .mabalgorithms import (
    RandomMAB, EpsilonGreedyMAB, UpperConfidenceBoundMAB,
    ThompsonSamplingMAB
)
from db import DataManager
from clustering import ClusteringManager


class EnhancedMABBase(ABC):
    """Base class for enhanced multi-armed bandit recommenders.

    Tracks, per arm, the number of pulls, the cumulative reward and the
    running mean reward, and keeps references to the data and clustering
    layers that subclasses use to build recommendation context.
    """

    # The manager types are quoted so this class can be defined without
    # importing the project layers (e.g. in isolated tests); at runtime the
    # module-level imports resolve them for type checkers as before.
    def __init__(self, n_arms: int, n_suggestions: int,
                 data_manager: 'DataManager', clustering_manager: 'ClusteringManager'):
        """Initialise per-arm statistics.

        Args:
            n_arms: Number of arms (products) available.
            n_suggestions: Number of arms a recommendation should contain.
            data_manager: Data-layer facade, stored for subclasses.
            clustering_manager: Clustering-layer facade, stored for subclasses.
        """
        self.n_arms = n_arms
        self.n_suggestions = n_suggestions
        self.data_manager = data_manager
        self.clustering_manager = clustering_manager

        # Algorithm state: pulls, summed reward and mean reward per arm.
        self.arm_counts = np.zeros(n_arms)
        self.arm_rewards = np.zeros(n_arms)
        self.arm_values = np.zeros(n_arms)

        # Enhanced-feature caches; not read or written by this base class.
        self.user_contexts = {}
        self.cluster_insights = {}
        self.price_optimization = {}

    @abstractmethod
    def select_multiple_arms(self, user_id: Optional[int] = None,
                             context: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """Select multiple arms (products) for recommendation."""

    def update(self, arms: List[int], rewards: List[float], user_id: Optional[int] = None):
        """Record observed rewards for the given arms.

        Inputs are validated before any state changes, so a bad call leaves
        the statistics untouched.

        Args:
            arms: Arm indices, paired positionally with ``rewards``.
            rewards: Observed reward for each entry of ``arms``.
            user_id: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If ``arms`` and ``rewards`` differ in length
                (``zip`` would otherwise silently drop the excess).
            IndexError: If an arm index lies outside ``[0, n_arms)``
                (negative indices would otherwise wrap onto the last arms).
        """
        arms = list(arms)
        rewards = list(rewards)
        if len(arms) != len(rewards):
            raise ValueError(
                f"arms and rewards must have the same length "
                f"(got {len(arms)} and {len(rewards)})"
            )
        out_of_range = [arm for arm in arms if not 0 <= arm < self.n_arms]
        if out_of_range:
            raise IndexError(f"arm indices out of range [0, {self.n_arms}): {out_of_range}")

        for arm, reward in zip(arms, rewards):
            self.arm_counts[arm] += 1
            self.arm_rewards[arm] += reward
            self.arm_values[arm] = self.arm_rewards[arm] / self.arm_counts[arm]

    def get_algorithm_info(self) -> Dict[str, Any]:
        """Return a snapshot of the algorithm state as plain Python values."""
        return {
            'n_arms': self.n_arms,
            'n_suggestions': self.n_suggestions,
            'arm_counts': self.arm_counts.tolist(),
            'arm_rewards': self.arm_rewards.tolist(),
            'arm_values': self.arm_values.tolist(),
            'total_pulls': float(np.sum(self.arm_counts)),
            'total_reward': float(np.sum(self.arm_rewards)),
        }


class EnhancedEpsilonGreedyMAB(EnhancedMABBase):
    """Epsilon-greedy recommender enriched with data- and clustering-layer context.

    With probability ``epsilon`` it explores, scoring rarely-pulled arms
    higher. Otherwise it exploits each arm's running mean reward. Context from
    the data and clustering layers adds bonuses on top of those scores.

    NOTE(review): the cluster, price and preference bonuses are placeholder
    heuristics that add the same constant to every arm. They therefore do not
    yet change which arms are selected; they need per-arm product data to
    matter.
    """

    # Placeholder per-cluster bonuses keyed by cluster id
    # (0: tech-savvy, 1: budget-conscious, 2: premium); other ids get none.
    _CLUSTER_EXPLORATION_BONUS = {0: 0.3, 1: 0.2, 2: 0.4}
    _CLUSTER_EXPLOITATION_BONUS = {0: 0.2, 1: 0.1, 2: 0.3}

    def __init__(self, n_arms: int, n_suggestions: int,
                 data_manager: DataManager, clustering_manager: ClusteringManager,
                 epsilon: float = 0.1, exploration_bonus: float = 0.2):
        """Create the recommender.

        Args:
            n_arms: Number of arms (products).
            n_suggestions: Number of arms returned per call.
            data_manager: Source of per-user state (``get_user_state``).
            clustering_manager: Source of cluster assignments, insights and
                price ranges; clustering lookups are skipped when it is falsy.
            epsilon: Probability of taking the exploration branch.
            exploration_bonus: Stored on the instance; not read by this class.
        """
        super().__init__(n_arms, n_suggestions, data_manager, clustering_manager)
        self.epsilon = epsilon
        self.exploration_bonus = exploration_bonus

    def select_multiple_arms(self, user_id: Optional[int] = None,
                             context: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """Return up to ``n_suggestions`` arm indices, best first.

        Without a ``user_id`` this is plain epsilon-greedy over the learned
        values. Otherwise the enhanced context drives both branches.
        """
        if user_id is None:
            return self._basic_arm_selection()

        enhanced_context = self._get_enhanced_context(user_id, context)
        if random.random() < self.epsilon:
            return self._exploration_selection(enhanced_context)
        return self._exploitation_selection(enhanced_context)

    def _get_enhanced_context(self, user_id: Optional[int],
                              context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Assemble recommendation context from the data and clustering layers.

        Each lookup is attempted independently, so one failing layer does not
        discard the others. Caller-supplied ``context`` is always merged last,
        and its keys override looked-up ones.

        Args:
            user_id: User to look up; user-specific lookups are skipped when None.
            context: Optional extra context from the caller.

        Returns:
            Dict that may contain 'user_state', 'user_cluster',
            'cluster_insights' and 'optimal_prices', plus any caller keys.
        """
        enhanced_context: Dict[str, Any] = {}

        def lookup(key, fetch):
            # Best effort: report and skip a failing layer instead of aborting.
            try:
                enhanced_context[key] = fetch()
            except Exception as e:
                print(f"⚠️  Error getting enhanced context ({key}): {e}")

        if user_id is not None:
            lookup('user_state', lambda: self.data_manager.get_user_state(user_id))

        if self.clustering_manager:
            if user_id is not None:
                lookup('user_cluster', lambda: self.clustering_manager.get_user_cluster(user_id))
            lookup('cluster_insights', lambda: self.clustering_manager.get_clustering_insights())
            lookup('optimal_prices',
                   lambda: self.clustering_manager.get_optimal_price_range('purchase', min_score=0.3))

        if context:
            enhanced_context.update(context)
        return enhanced_context

    def _exploration_selection(self, context: Dict[str, Any]) -> np.ndarray:
        """Pick arms for exploration, favouring rarely-pulled arms plus context bonuses."""
        try:
            user_cluster = context.get('user_cluster', -1)
            cluster_insights = context.get('cluster_insights', {})

            # Base score 1 / (pulls + 1): unvisited arms score 1, shrinking with use.
            exploration_scores = 1.0 / (self.arm_counts + 1)

            if user_cluster is not None and user_cluster >= 0 and cluster_insights:
                exploration_scores = exploration_scores + self._get_cluster_exploration_bonus(
                    user_cluster, cluster_insights)

            optimal_prices = context.get('optimal_prices', [])
            if optimal_prices:
                exploration_scores = exploration_scores + self._get_price_exploration_bonus(optimal_prices)

            return self._top_k(exploration_scores)

        except Exception as e:
            print(f"⚠️  Error in exploration selection: {e}")
            return self._basic_arm_selection()

    def _exploitation_selection(self, context: Dict[str, Any]) -> np.ndarray:
        """Pick arms for exploitation: learned mean rewards plus context bonuses."""
        try:
            combined_scores = np.copy(self.arm_values)

            # A missing or None user state / preferences means "no preferences",
            # instead of an AttributeError that would discard all context.
            user_state = context.get('user_state') or {}
            preferences = user_state.get('preferences') or {}
            combined_scores += self._get_preference_bonus(preferences)

            user_cluster = context.get('user_cluster', -1)
            if user_cluster is not None and user_cluster >= 0:
                combined_scores += self._get_cluster_exploitation_bonus(user_cluster, context)

            optimal_prices = context.get('optimal_prices', [])
            if optimal_prices:
                combined_scores += self._get_price_exploitation_bonus(optimal_prices)

            return self._top_k(combined_scores)

        except Exception as e:
            print(f"⚠️  Error in exploitation selection: {e}")
            return self._basic_arm_selection()

    def _cluster_bonus(self, user_cluster: int, table: Dict[int, float], label: str) -> np.ndarray:
        """Return a uniform per-arm bonus looked up from ``table`` by cluster id."""
        bonus = np.zeros(self.n_arms)
        try:
            bonus += table.get(user_cluster, 0.0)
        except Exception as e:  # e.g. an unhashable cluster id
            print(f"⚠️  Error getting cluster {label} bonus: {e}")
        return bonus

    def _get_cluster_exploration_bonus(self, user_cluster: int, insights: Dict[str, Any]) -> np.ndarray:
        """Exploration bonus for ``user_cluster``; ``insights`` is not used yet."""
        return self._cluster_bonus(user_cluster, self._CLUSTER_EXPLORATION_BONUS, 'exploration')

    def _get_cluster_exploitation_bonus(self, user_cluster: int, context: Dict[str, Any]) -> np.ndarray:
        """Exploitation bonus for ``user_cluster``; ``context`` is not used yet."""
        return self._cluster_bonus(user_cluster, self._CLUSTER_EXPLOITATION_BONUS, 'exploitation')

    def _get_price_exploration_bonus(self, optimal_prices: List[Tuple[float, float]]) -> np.ndarray:
        """Uniform placeholder bonus of 0.1 per arm; ``optimal_prices`` is not used yet."""
        return np.full(self.n_arms, 0.1)

    def _get_price_exploitation_bonus(self, optimal_prices: List[Tuple[float, float]]) -> np.ndarray:
        """Uniform placeholder bonus of 0.2 per arm; ``optimal_prices`` is not used yet."""
        return np.full(self.n_arms, 0.2)

    def _get_preference_bonus(self, preferences: Dict[str, Any]) -> np.ndarray:
        """Uniform bonus of 0.1 each for having favourite categories and favourite brands."""
        bonus = np.zeros(self.n_arms)
        try:
            if preferences.get('favorite_categories', {}):
                bonus += 0.1
            if preferences.get('favorite_brands', {}):
                bonus += 0.1
        except Exception as e:
            print(f"⚠️  Error getting preference bonus: {e}")
        return bonus

    def _top_k(self, scores: np.ndarray) -> np.ndarray:
        """Return indices of the ``n_suggestions`` highest scores, best first.

        Ties are broken uniformly at random, so all-equal cold-start scores do
        not always favour the highest-indexed arms. A non-positive
        ``n_suggestions`` yields an empty selection rather than every arm,
        which is what the ``[-0:]`` slice would return.
        """
        scores = np.asarray(scores, dtype=float)
        k = max(0, min(self.n_suggestions, scores.size))
        if k == 0:
            return np.array([], dtype=int)
        # np.lexsort sorts by its LAST key first: descending score, then random tiebreak.
        order = np.lexsort((np.random.random(scores.size), -scores))
        return order[:k]

    def _basic_arm_selection(self) -> np.ndarray:
        """Plain epsilon-greedy selection that ignores context."""
        # Clamp so sampling without replacement never asks for more arms than exist.
        k = max(0, min(self.n_suggestions, self.n_arms))
        if random.random() < self.epsilon:
            return np.random.choice(self.n_arms, k, replace=False)
        return self._top_k(self.arm_values)


class EnhancedContextualMAB(EnhancedMABBase):
    """Linear contextual bandit: the predicted reward of arm ``a`` is ``theta[a] @ x``.

    ``x`` is an L2-normalised feature vector with a fixed slot layout, built
    from the enhanced context. ``theta`` is learned with a per-arm
    squared-error gradient step. Exploration comes from Gaussian noise added
    to the predicted rewards at selection time.
    """

    # Fixed feature layout: each name occupies the slot at its index, so a
    # given position in the context vector always encodes the same signal
    # regardless of which context keys happen to be present.
    _FEATURE_NAMES = (
        'engagement_high', 'engagement_medium', 'exploration_score',
        'state_cluster_id', 'session_duration', 'session_interactions',
        'user_cluster', 'n_price_ranges', 'hour_of_day', 'day_of_week',
    )

    def __init__(self, n_arms: int, n_suggestions: int,
                 data_manager: DataManager, clustering_manager: ClusteringManager,
                 context_dim: int = 20, learning_rate: float = 0.01):
        """Create the bandit.

        Args:
            n_arms: Number of arms (products).
            n_suggestions: Number of arms returned per call.
            data_manager: Source of per-user state.
            clustering_manager: Source of cluster assignments, insights and
                price ranges; clustering lookups are skipped when it is falsy.
            context_dim: Length of the context vector. Slots beyond the
                built-in features stay zero; features beyond ``context_dim``
                are dropped.
            learning_rate: Step size of the per-arm gradient update.
        """
        super().__init__(n_arms, n_suggestions, data_manager, clustering_manager)
        self.context_dim = context_dim
        self.learning_rate = learning_rate

        # Per-arm linear weights, small random init.
        self.theta = np.random.randn(n_arms, context_dim) * 0.01
        # FIFO of context vectors awaiting feedback. select_multiple_arms()
        # appends one entry per call (None when no context was used) and
        # update() consumes the oldest, so calls must be paired in order.
        self.context_history = deque()
        self.reward_history = []  # not read or written elsewhere in this class

    def select_multiple_arms(self, user_id: Optional[int] = None,
                             context: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """Return up to ``n_suggestions`` arms ranked by noisy predicted reward.

        Each call enqueues exactly one entry for the matching ``update()``.
        """
        if user_id is None:
            # No context is used, but a placeholder keeps the FIFO aligned
            # so the next update() does not consume an unrelated context.
            self.context_history.append(None)
            return self._basic_arm_selection()

        enhanced_context = self._get_enhanced_context(user_id, context)
        context_vector = self._context_to_vector(enhanced_context)

        expected_rewards = self.theta @ context_vector
        expected_rewards = expected_rewards + np.random.normal(0, 0.1, self.n_arms)

        selected_arms = self._top_k(expected_rewards)
        self.context_history.append(context_vector)
        return selected_arms

    def update(self, arms: List[int], rewards: List[float], user_id: Optional[int] = None):
        """Apply feedback for the oldest pending selection.

        Per-arm statistics are always updated. ``theta`` is updated only when
        the pending selection used a context vector.
        """
        # Statistics first: if the inputs are rejected, the pending context stays queued.
        super().update(arms, rewards, user_id)

        context_vector = self.context_history.popleft() if self.context_history else None
        if context_vector is None:
            return

        for arm, reward in zip(arms, rewards):
            # One gradient step on the squared error (reward - theta[arm] @ x)^2.
            prediction = float(self.theta[arm] @ context_vector)
            self.theta[arm] += self.learning_rate * (reward - prediction) * context_vector

    @staticmethod
    def _to_float(value: Any, default: float) -> float:
        """Convert ``value`` to float, returning ``default`` for None or non-numeric values."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _count_items(value: Any) -> int:
        """Return ``len(value)``, or 0 when it is None or has no length."""
        if value is None:
            return 0
        try:
            return len(value)
        except TypeError:
            return 0

    def _extract_features(self, context: Dict[str, Any]) -> List[float]:
        """Return raw feature values in ``_FEATURE_NAMES`` order; absent signals are 0."""
        features = dict.fromkeys(self._FEATURE_NAMES, 0.0)

        user_state = context.get('user_state') or {}
        if user_state:
            engagement = user_state.get('engagement_level', 'low')
            features['engagement_high'] = 1.0 if engagement == 'high' else 0.0
            features['engagement_medium'] = 1.0 if engagement == 'medium' else 0.0
            features['exploration_score'] = self._to_float(user_state.get('exploration_score', 0.5), 0.5)
            # Scaled like 'user_cluster' so a raw id cannot dominate the normalised vector.
            features['state_cluster_id'] = self._to_float(user_state.get('cluster_id', 0), 0.0) / 10

        session_context = context.get('session_context') or {}
        if session_context:
            duration = self._to_float(session_context.get('duration_hours', 0), 0.0)
            interactions = self._to_float(session_context.get('interaction_count', 0), 0.0)
            features['session_duration'] = min(duration, 24) / 24
            features['session_interactions'] = min(interactions, 100) / 100

        user_cluster = self._to_float(context.get('user_cluster', -1), -1.0)
        if user_cluster >= 0:
            features['user_cluster'] = user_cluster / 10

        features['n_price_ranges'] = self._count_items(context.get('optimal_prices')) / 10
        features['hour_of_day'] = self._to_float(context.get('current_hour', 12), 12.0) / 24
        features['day_of_week'] = self._to_float(context.get('current_day', 3), 3.0) / 7

        return [features[name] for name in self._FEATURE_NAMES]

    def _context_to_vector(self, context: Dict[str, Any]) -> np.ndarray:
        """Convert a context dict to an L2-normalised vector of length ``context_dim``.

        On an unexpected error an all-zero vector is returned. It predicts 0
        for every arm and leaves ``theta`` unchanged when used for learning.
        """
        feature_vector = np.zeros(self.context_dim)
        try:
            features = self._extract_features(context)
        except Exception as e:
            print(f"⚠️  Error converting context to vector: {e}")
            return feature_vector

        n_used = min(len(features), self.context_dim)
        feature_vector[:n_used] = features[:n_used]

        norm = np.linalg.norm(feature_vector)
        if norm > 0:
            feature_vector /= norm
        return feature_vector

    def _get_enhanced_context(self, user_id: Optional[int],
                              context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Assemble recommendation context from the data and clustering layers.

        Each lookup is attempted independently, so one failing layer does not
        discard the others. Caller-supplied ``context`` is always merged last,
        and its keys override looked-up ones.

        Args:
            user_id: User to look up; user-specific lookups are skipped when None.
            context: Optional extra context from the caller.

        Returns:
            Dict that may contain 'user_state', 'user_cluster',
            'cluster_insights' and 'optimal_prices', plus any caller keys.
        """
        enhanced_context: Dict[str, Any] = {}

        def lookup(key, fetch):
            # Best effort: report and skip a failing layer instead of aborting.
            try:
                enhanced_context[key] = fetch()
            except Exception as e:
                print(f"⚠️  Error getting enhanced context ({key}): {e}")

        if user_id is not None:
            lookup('user_state', lambda: self.data_manager.get_user_state(user_id))

        if self.clustering_manager:
            if user_id is not None:
                lookup('user_cluster', lambda: self.clustering_manager.get_user_cluster(user_id))
            lookup('cluster_insights', lambda: self.clustering_manager.get_clustering_insights())
            lookup('optimal_prices',
                   lambda: self.clustering_manager.get_optimal_price_range('purchase', min_score=0.3))

        if context:
            enhanced_context.update(context)
        return enhanced_context

    def _top_k(self, scores: np.ndarray) -> np.ndarray:
        """Return indices of the ``n_suggestions`` highest scores, best first.

        Ties are broken at random. A non-positive ``n_suggestions`` yields an
        empty selection rather than every arm.
        """
        scores = np.asarray(scores, dtype=float)
        k = max(0, min(self.n_suggestions, scores.size))
        if k == 0:
            return np.array([], dtype=int)
        # np.lexsort sorts by its LAST key first: descending score, then random tiebreak.
        order = np.lexsort((np.random.random(scores.size), -scores))
        return order[:k]

    def _basic_arm_selection(self) -> np.ndarray:
        """Uniformly random selection, clamped to the number of arms."""
        k = max(0, min(self.n_suggestions, self.n_arms))
        return np.random.choice(self.n_arms, k, replace=False)


class EnhancedEnsembleMAB(EnhancedMABBase):
    """Ensemble recommender that combines several base bandits by weighted voting.

    Each base algorithm proposes arms, and every proposal adds that
    algorithm's weight to the arm's vote. The top-voted arms are returned.
    Weights are a softmax over an exponential moving average of the reward
    that each algorithm's own proposals earned.
    """

    def __init__(self, n_arms: int, n_suggestions: int,
                 data_manager: DataManager, clustering_manager: ClusteringManager,
                 algorithms: List[str] = None):
        """Create the ensemble.

        Args:
            n_arms: Number of arms (products).
            n_suggestions: Number of arms returned per call.
            data_manager: Source of per-user state.
            clustering_manager: Source of cluster assignments, insights and
                price ranges; clustering lookups are skipped when it is falsy.
            algorithms: Base algorithm names from 'epsilon_greedy', 'ucb',
                'thompson' and 'random'. Defaults to the first three; unknown
                names are reported and skipped.
        """
        super().__init__(n_arms, n_suggestions, data_manager, clustering_manager)

        self.algorithms = algorithms or ['epsilon_greedy', 'ucb', 'thompson']
        self.base_algorithms = self._initialize_base_algorithms()

        n_base = len(self.base_algorithms)
        # Start with uniform weights; guard against dividing by zero with no algorithms.
        self.algorithm_weights = np.full(n_base, 1.0 / n_base) if n_base else np.zeros(0)
        self.algorithm_performances = np.zeros(n_base)
        # Per-algorithm proposals from the latest select_multiple_arms() call,
        # used by update() to credit each algorithm with its own arms' rewards.
        self._last_recommendations: Optional[List[Any]] = None

    def _initialize_base_algorithms(self) -> List[Any]:
        """Instantiate the base bandits named in ``self.algorithms``, in order."""
        factories = {
            'epsilon_greedy': lambda: EpsilonGreedyMAB(self.n_arms, self.n_suggestions, epsilon=0.1),
            'ucb': lambda: UpperConfidenceBoundMAB(self.n_arms, self.n_suggestions, c=2.0),
            'thompson': lambda: ThompsonSamplingMAB(self.n_arms, self.n_suggestions),
            'random': lambda: RandomMAB(self.n_arms, self.n_suggestions),
        }
        base_algorithms = []
        for alg_name in self.algorithms:
            factory = factories.get(alg_name)
            if factory is None:
                print(f"⚠️  Unknown algorithm '{alg_name}' ignored")
                continue
            base_algorithms.append(factory())
        return base_algorithms

    def select_multiple_arms(self, user_id: Optional[int] = None,
                             context: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """Return up to ``n_suggestions`` arms chosen by weighted voting."""
        if not self.base_algorithms:
            return self._basic_arm_selection()

        enhanced_context = self._get_enhanced_context(user_id, context)

        all_recommendations = []
        for algorithm in self.base_algorithms:
            try:
                if hasattr(algorithm, 'select_multiple_arms'):
                    recommendations = algorithm.select_multiple_arms()
                else:
                    # Fallback for algorithms without multi-arm selection.
                    recommendations = self._basic_arm_selection()
            except Exception as e:
                print(f"⚠️  Error in algorithm {type(algorithm).__name__}: {e}")
                recommendations = self._basic_arm_selection()
            all_recommendations.append(recommendations)

        self._last_recommendations = all_recommendations
        return self._weighted_voting(all_recommendations, enhanced_context)

    def _weighted_voting(self, all_recommendations: List[np.ndarray], context: Dict[str, Any]) -> np.ndarray:
        """Sum algorithm weights per proposed arm, adjust by context, return the top arms."""
        try:
            arm_votes = np.zeros(self.n_arms)
            for weight, recommendations in zip(self.algorithm_weights, all_recommendations):
                for arm in recommendations:
                    arm = int(arm)  # tolerate float / numpy arm ids from base algorithms
                    if 0 <= arm < self.n_arms:
                        arm_votes[arm] += weight

            arm_votes = self._apply_context_adjustments(arm_votes, context)
            return self._top_k(arm_votes)

        except Exception as e:
            print(f"⚠️  Error in weighted voting: {e}")
            return self._basic_arm_selection()

    def _apply_context_adjustments(self, arm_votes: np.ndarray, context: Dict[str, Any]) -> np.ndarray:
        """Hook for context-based vote adjustments.

        NOTE(review): the preference, cluster and price adjustments are not
        implemented yet, so this returns an unmodified copy of ``arm_votes``.
        """
        # TODO: use context['user_state']['preferences'], context['user_cluster']
        # and context['optimal_prices'] once per-arm product data is available.
        return np.copy(arm_votes)

    def update(self, arms: List[int], rewards: List[float], user_id: Optional[int] = None):
        """Forward feedback to the base algorithms, reweight them, then update own stats."""
        for algorithm in self.base_algorithms:
            try:
                algorithm.update(arms, rewards)
            except Exception as e:
                print(f"⚠️  Error updating algorithm {type(algorithm).__name__}: {e}")

        self._update_ensemble_weights(rewards, arms)
        super().update(arms, rewards, user_id)

    def _update_ensemble_weights(self, rewards: List[float], arms: Optional[List[int]] = None):
        """Update per-algorithm performance (EMA, decay 0.9) and recompute softmax weights.

        When ``arms`` and the latest proposals are known, each algorithm is
        credited with the mean reward of the arms it proposed. An algorithm
        with none of its arms rewarded keeps its previous performance.
        Otherwise every algorithm is credited with the mean of ``rewards``.
        """
        try:
            n_algs = len(self.base_algorithms)
            if n_algs == 0:
                return

            avg_reward = float(np.mean(rewards)) if len(rewards) else 0.0
            last = self._last_recommendations
            per_algorithm = arms is not None and last is not None and len(last) == n_algs
            proposals = last if per_algorithm else [None] * n_algs

            for i, recommended in enumerate(proposals):
                if recommended is None:
                    credit = avg_reward
                else:
                    recommended_set = {int(a) for a in recommended}
                    matched = [float(r) for a, r in zip(arms, rewards) if int(a) in recommended_set]
                    if not matched:
                        continue  # no evidence about this algorithm this round
                    credit = float(np.mean(matched))
                self.algorithm_performances[i] = 0.9 * self.algorithm_performances[i] + 0.1 * credit

            # Numerically stable softmax. It must also run when performances are <= 0,
            # otherwise negative rewards would freeze the weights forever.
            shifted = self.algorithm_performances - np.max(self.algorithm_performances)
            exp_performances = np.exp(shifted)
            self.algorithm_weights = exp_performances / np.sum(exp_performances)

        except Exception as e:
            print(f"⚠️  Error updating ensemble weights: {e}")

    def _get_enhanced_context(self, user_id: Optional[int],
                              context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Assemble recommendation context from the data and clustering layers.

        Each lookup is attempted independently, so one failing layer does not
        discard the others. Caller-supplied ``context`` is always merged last,
        and its keys override looked-up ones.

        Args:
            user_id: User to look up; user-specific lookups are skipped when None.
            context: Optional extra context from the caller.

        Returns:
            Dict that may contain 'user_state', 'user_cluster',
            'cluster_insights' and 'optimal_prices', plus any caller keys.
        """
        enhanced_context: Dict[str, Any] = {}

        def lookup(key, fetch):
            # Best effort: report and skip a failing layer instead of aborting.
            try:
                enhanced_context[key] = fetch()
            except Exception as e:
                print(f"⚠️  Error getting enhanced context ({key}): {e}")

        if user_id is not None:
            lookup('user_state', lambda: self.data_manager.get_user_state(user_id))

        if self.clustering_manager:
            if user_id is not None:
                lookup('user_cluster', lambda: self.clustering_manager.get_user_cluster(user_id))
            lookup('cluster_insights', lambda: self.clustering_manager.get_clustering_insights())
            lookup('optimal_prices',
                   lambda: self.clustering_manager.get_optimal_price_range('purchase', min_score=0.3))

        if context:
            enhanced_context.update(context)
        return enhanced_context

    def _top_k(self, scores: np.ndarray) -> np.ndarray:
        """Return indices of the ``n_suggestions`` highest scores, best first.

        Ties (common among vote totals) are broken at random. A non-positive
        ``n_suggestions`` yields an empty selection rather than every arm.
        """
        scores = np.asarray(scores, dtype=float)
        k = max(0, min(self.n_suggestions, scores.size))
        if k == 0:
            return np.array([], dtype=int)
        # np.lexsort sorts by its LAST key first: descending score, then random tiebreak.
        order = np.lexsort((np.random.random(scores.size), -scores))
        return order[:k]

    def _basic_arm_selection(self) -> np.ndarray:
        """Uniformly random selection, clamped to the number of arms."""
        k = max(0, min(self.n_suggestions, self.n_arms))
        return np.random.choice(self.n_arms, k, replace=False)
