"""
Unified Multi-Armed Bandit Recommendation Environment.
Supports both CSV and database data sources with optional clustering integration.
"""

import gymnasium as gym
import numpy as np
import pandas as pd
import random
import os
try:
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    MATPLOTLIB_AVAILABLE = True
except Exception:
    plt = None
    Rectangle = None
    MATPLOTLIB_AVAILABLE = False
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple, Union

from .logger import RecommenderLogger, performance_tracker

# Optional imports for enhanced features
try:
    from db import DataController, DataManager
    # from clustering import ClusteringManager
    ENHANCED_FEATURES_AVAILABLE = True
except ImportError:
    ENHANCED_FEATURES_AVAILABLE = False
    DataController = None
    DataManager = None
    # ClusteringManager = None


class UnifiedRecommendationEnv(gym.Env):
    """
    Unified Multi-Armed Bandit environment supporting both CSV and database data sources.
    Each 'arm' represents a product that can be recommended to users.
    """

    def __init__(self, 
                 data_source: Union[str, Dict[str, Any]],
                 max_arms: int = 50, 
                 max_steps: int = 100, 
                 n_suggestions: int = 1, 
                 seed: Optional[int] = None,
                 # Enhanced features (optional)
                 use_clustering: bool = False,
                 use_user_context: bool = False,
                 use_price_optimization: bool = False,
                 # CSV-specific parameters
                 product_col: str = "product_id",
                 user_col: str = "user_id", 
                 brand_col: str = "brand",
                 category_col: str = "category_id",
                 category_code_col: str = "category_code",
                 event_col: str = "event_type",
                 time_col: str = "event_time",
                 base_rewards: Optional[Dict[str, float]] = None,
                 repeat_penalty: float = -0.2,
                 brand_bonus: float = 0.2,
                 category_bonus: float = 0.3,
                 relevance_window: int = 5):
        """
        Initialize the Unified Multi-Armed Bandit Recommendation Environment.

        Parameters:
        - data_source: Either a CSV file path (str) or database config (dict)
        - max_arms: Maximum number of arms (products) to consider
        - max_steps: Maximum number of steps in an episode
        - n_suggestions: Number of products to recommend simultaneously
        - seed: Random seed for reproducibility (seeds both `random` and `numpy.random`)
        - use_clustering: Whether to use clustering insights (requires database)
        - use_user_context: Whether to use enhanced user context (requires database)
        - use_price_optimization: Whether to use price optimization (requires database)
        - product_col: Column name for product IDs (CSV only)
        - user_col: Column name for user IDs (CSV only)
        - brand_col: Column name for brand information (CSV only)
        - category_col: Column name for category IDs (CSV only)
        - category_code_col: Column name for category code information (CSV only)
        - event_col: Column name for event type information (CSV only)
        - time_col: Column name for event time information (CSV only)
        - base_rewards: Base rewards for each event type (CSV only)
        - repeat_penalty: Penalty for repeated products (CSV only)
        - brand_bonus: Bonus for same brand (CSV only)
        - category_bonus: Bonus for same category (CSV only)
        - relevance_window: Window for considering past interactions (CSV only)

        Raises:
        - AssertionError: If max_arms, n_suggestions or seed fail validation.
        - ValueError: If data_source is neither a str nor a dict.
        - ImportError: If a database source is requested but the db module is missing.
        """
        
        # Validate parameters
        assert isinstance(max_arms, int) and max_arms > 0, "Max arms must be a positive integer."
        assert isinstance(n_suggestions, int) and n_suggestions > 0, "n_suggestions must be a positive integer."
        assert n_suggestions <= max_arms, "n_suggestions cannot exceed max_arms."
        assert seed is None or isinstance(seed, int), "Seed must be an integer."

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        # Initialize components
        self.max_arms = max_arms
        self.max_steps = max_steps
        self.n_suggestions = n_suggestions

        # CSV configuration. These must be stored before the data source is
        # initialised: the CSV initialiser reads them from the instance.
        self.product_col = product_col
        self.user_col = user_col
        self.brand_col = brand_col
        self.category_col = category_col
        self.category_code_col = category_code_col
        self.event_col = event_col
        self.time_col = time_col
        # Copy so later changes to the caller's dict do not leak into the env;
        # None is kept so the CSV initialiser can install its defaults.
        self.base_rewards = dict(base_rewards) if base_rewards is not None else None
        self.repeat_penalty = repeat_penalty
        self.brand_bonus = brand_bonus
        self.category_bonus = category_bonus
        self.relevance_window = relevance_window

        # Database-layer handles default to None so that attribute checks made
        # elsewhere in the class also work for CSV-backed environments.
        self.data_controller = None
        self.data_manager = None
        self.clustering_manager = None

        # Environment state
        self.current_step = 0
        self.current_user_id = None
        self.user_history = {}
        self.product_history = {}
        self.current_session = None
        
        # Render-related attributes
        self.render_mode = None
        self.fig = None
        
        # Performance tracking
        self.episode_rewards = []
        self.episode_actions = []
        self.episode_contexts = []
        
        # Feature flags. Remember what was *requested* before masking with
        # availability, otherwise the fallback warning below can never fire.
        enhanced_requested = bool(use_clustering or use_user_context or use_price_optimization)
        self.use_clustering = use_clustering and ENHANCED_FEATURES_AVAILABLE
        self.use_user_context = use_user_context and ENHANCED_FEATURES_AVAILABLE
        self.use_price_optimization = use_price_optimization and ENHANCED_FEATURES_AVAILABLE
        
        # Determine data source type
        self.data_source_type = self._determine_data_source_type(data_source)
        
        # Initialize data source
        print(f"🔧 Initializing Unified MAB Environment with {self.data_source_type} data source...")
        self._initialize_data_source(data_source)
        
        # Initialize enhanced features if requested and available
        if enhanced_requested:
            if not ENHANCED_FEATURES_AVAILABLE:
                print("⚠️  Enhanced features requested but not available. Falling back to basic mode.")
                self.use_clustering = False
                self.use_user_context = False
                self.use_price_optimization = False
            else:
                self._initialize_enhanced_features()
        
        # Initialize logger
        self.logger = RecommenderLogger(name=f"unified_recommendation_env_{self.max_arms}arms")
        
        print("✅ Unified MAB Environment initialized successfully!")

    def _determine_data_source_type(self, data_source: Union[str, Dict[str, Any]]) -> str:
        """Determine the type of data source."""
        if isinstance(data_source, str):
            if data_source.endswith('.csv'):
                return "csv"
            else:
                return "database"
        elif isinstance(data_source, dict):
            return "database"
        else:
            raise ValueError("data_source must be a CSV file path (str) or database config (dict)")

    def _initialize_data_source(self, data_source: Union[str, Dict[str, Any]]):
        """Initialize the data source based on type."""
        if self.data_source_type == "csv":
            self._initialize_csv_data_source(data_source)
        else:
            self._initialize_database_data_source(data_source)

    def _initialize_csv_data_source(self, csv_path: str):
        """
        Load interaction data from a CSV file and build the action/observation spaces.

        Reads the file into `self.data`, counts distinct users, products,
        brands and categories, and defines a MultiDiscrete action space over
        products plus a Dict observation space.

        Column names and the relevance window are read from the instance. If
        they were never assigned, the constructor's documented defaults are
        installed here so this method (and `reset`, which reads
        `self.user_col`) does not fail with AttributeError.

        Raises:
        - AssertionError: If csv_path is not a string or the file does not exist.
        - KeyError: If a configured column is missing from the CSV.
        """
        assert isinstance(csv_path, str), "CSV path must be a string."
        assert os.path.isfile(csv_path), f"CSV file {csv_path} does not exist."

        # Fall back to the constructor defaults for any setting not yet on the instance.
        fallback_settings = {
            'product_col': "product_id",
            'user_col': "user_id",
            'brand_col': "brand",
            'category_col': "category_id",
            'relevance_window': 5,
        }
        for attr_name, default_value in fallback_settings.items():
            if not hasattr(self, attr_name):
                setattr(self, attr_name, default_value)
        
        # Load CSV data
        self.data = pd.read_csv(csv_path)
        
        # Set default rewards if not provided
        if getattr(self, 'base_rewards', None) is None:
            self.base_rewards = {
                'view': 0.1, 
                'cart': 0.33, 
                'purchase': 1.0, 
                'remove_from_cart': -0.5
            }
        
        # Extract unique values
        self.n_users = self.data[self.user_col].nunique()
        self.n_products = self.data[self.product_col].nunique()
        self.n_brands = self.data[self.brand_col].nunique()
        self.n_categories = self.data[self.category_col].nunique()
        
        # Set up action and observation spaces
        self.action_space = gym.spaces.MultiDiscrete([self.n_products] * self.n_suggestions)
        self.observation_space = gym.spaces.Dict({
            'user_id': gym.spaces.Box(low=0, high=self.n_users, shape=(), dtype=np.int32),
            'session_steps': gym.spaces.Box(low=0, high=self.max_steps, shape=(), dtype=np.int32),
            'previous_brands': gym.spaces.Box(low=0, high=self.n_brands, shape=(self.relevance_window,), dtype=np.int32),
            'previous_categories': gym.spaces.Box(low=0, high=self.n_categories, shape=(self.relevance_window,), dtype=np.int32),
            'previous_products': gym.spaces.Box(low=0, high=self.n_products, shape=(self.relevance_window,), dtype=np.int32)
        })
        
        print(f"📊 Loaded CSV data: {self.data.shape[0]} interactions, {self.n_users} users, {self.n_products} products")

    def _initialize_database_data_source(self, db_config: Union[str, Dict[str, Any]]):
        """
        Initialize the database-backed data layer and the action/observation spaces.

        Parameters:
        - db_config: Either a config dict (its 'db_path' key is used, defaulting
          to 'recommender_system.db') or a plain database path string. Strings
          are accepted because any non-".csv" string data source is routed here.

        Raises:
        - ImportError: If the optional `db` module could not be imported.
        """
        if not ENHANCED_FEATURES_AVAILABLE:
            raise ImportError("Database features require db and clustering modules")
        
        # Extract database path
        if isinstance(db_config, str):
            db_path = db_config
        else:
            db_path = db_config.get('db_path', 'recommender_system.db')
        
        # Initialize data layers
        self.data_controller = DataController(db_path)
        self.data_manager = DataManager(db_path)
        
        # No clustering manager is constructed here, so clustering-dependent
        # branches elsewhere in the class see a falsy manager.
        self.clustering_manager = None
        
        # Set up action and observation spaces (will be updated based on available products)
        self.action_space = gym.spaces.MultiDiscrete([self.max_arms] * self.n_suggestions)
        self.observation_space = gym.spaces.Dict({
            'user_id': gym.spaces.Box(low=0, high=10000, shape=(), dtype=np.int32),
            'session_steps': gym.spaces.Box(low=0, high=self.max_steps, shape=(), dtype=np.int32),
            'user_context': gym.spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32),
            'available_products': gym.spaces.Box(low=0, high=self.max_arms, shape=(self.max_arms,), dtype=np.int32)
        })
        
        print(f"🗄️  Initialized database data source: {db_path}")

    def _initialize_enhanced_features(self):
        """
        Initialize enhanced features (clustering, user context, price optimization).

        Only clustering has setup work here: when clustering is enabled and a
        clustering manager exists, products, users and feedback are fetched
        from the data controller and passed to the manager's full clustering
        run. Clustering is disabled (not raised) if data is missing or the run
        fails, so the environment can still be used in basic mode.
        """
        if not ENHANCED_FEATURES_AVAILABLE:
            return
        
        print("🔧 Initializing enhanced features...")
        
        # The manager attribute is only assigned by the database initialiser;
        # read it defensively so CSV-backed environments do not crash here.
        clustering_manager = getattr(self, 'clustering_manager', None)

        # Initialize clustering if requested
        if self.use_clustering and clustering_manager:
            try:
                # Run clustering pipeline
                products_data = self.data_controller.get_all_products()
                users_data = self.data_controller.get_all_users()
                feedback_data = self.data_controller.get_all_feedback()
                
                if len(products_data) > 0 and len(users_data) > 0 and len(feedback_data) > 0:
                    clustering_manager.run_full_clustering(
                        products_data, users_data, feedback_data
                    )
                    print("✅ Clustering pipeline completed")
                else:
                    print("⚠️  Insufficient data for clustering, disabling clustering")
                    self.use_clustering = False
            except Exception as e:
                # Deliberate best-effort: clustering is optional, so degrade instead of failing.
                print(f"⚠️  Clustering initialization failed: {e}, disabling clustering")
                self.use_clustering = False
        
        print("✅ Enhanced features initialized")

    def reset(self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Reset the environment to an initial state."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        # Reset environment state
        self.current_step = 0
        self.episode_rewards = []
        self.episode_actions = []
        self.episode_contexts = []
        
        # Reset render state
        if self.fig is not None and plt is not None:
            plt.close(self.fig)
            self.fig = None

        # Set current user
        if options and options.get("user_id") is not None:
            self.current_user_id = options["user_id"]
        else:
            if self.data_source_type == "csv":
                self.current_user_id = random.choice(self.data[self.user_col])
            else:
                # Get random user from database
                users = self.data_controller.get_all_users()
                if len(users) > 0:
                    self.current_user_id = random.choice(users)['user_id']
                else:
                    self.current_user_id = 1  # Default user

        # Initialize user session
        if self.data_source_type == "database" and self.use_user_context:
            session_id = self.data_manager.start_user_session(self.current_user_id)
            self.current_session = session_id
        else:
            self.current_session = None

        # Get initial observation
        observation = self._get_observation()
        
        info = {
            'user_id': self.current_user_id,
            'session_id': self.current_session,
            'data_source_type': self.data_source_type,
            'enhanced_features': {
                'clustering': self.use_clustering,
                'user_context': self.use_user_context,
                'price_optimization': self.use_price_optimization
            }
        }
        
        # Add user state if available
        if self.data_source_type == "database" and self.use_user_context:
            try:
                user_state = self.data_manager.get_user_state(self.current_user_id)
                info['user_state'] = user_state
            except Exception as e:
                print(f"⚠️  Could not get user state: {e}")

        return observation, info

    def step(self, action: Union[np.ndarray, List[int]]) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Execute one time step within the environment."""
        if self.current_step >= self.max_steps:
            return self._get_observation(), 0.0, True, False, {}

        # Convert action to list if needed
        if hasattr(action, 'tolist'):
            action = action.tolist()
        elif not isinstance(action, list):
            action = list(action)
        
        # Ensure all action values are integers
        action = [int(arm_idx) for arm_idx in action]
        
        # Validate action
        if len(action) != self.n_suggestions:
            raise ValueError(f"Action must contain exactly {self.n_suggestions} arm selections")
        
        # Execute step based on data source type
        if self.data_source_type == "csv":
            return self._step_csv(action)
        else:
            return self._step_database(action)

    def _step_csv(self, action: List[int]) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Execute step for CSV data source."""
        # This implements the original RecommendationEnv step logic
        # (Implementation would be similar to the original step method)
        
        # For now, return a placeholder implementation
        total_reward = 0.0
        interactions_occurred = []
        
        for arm_idx in action:
            if arm_idx < self.n_products:
                # Simulate interaction
                reward = np.random.uniform(0, 1)  # Placeholder reward
                total_reward += reward
                interactions_occurred.append({
                    'product_id': arm_idx,
                    'reward': reward,
                    'event_type': 'view'
                })
        
        # Update environment state
        self.current_step += 1
        self.episode_rewards.append(total_reward)
        self.episode_actions.append(action)
        
        # Check if episode is done
        done = self.current_step >= self.max_steps
        
        # Get next observation
        observation = self._get_observation()
        
        info = {
            'interactions_occurred': interactions_occurred,
            'total_reward': total_reward,
            'step': self.current_step
        }
        
        return observation, total_reward, done, False, info

    def _step_database(self, action: List[int]) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Execute step for database data source."""
        # Get available products
        available_products = self._get_available_products()
        if len(available_products) < self.n_suggestions:
            return self._get_observation(), -1.0, True, False, {'error': 'Not enough products available'}
        
        # Simulate user interactions with recommended products
        rewards = []
        interactions = []
        
        for product_idx in action:
            if product_idx < len(available_products):
                product = available_products[product_idx]
                reward, interaction = self._simulate_user_interaction(product)
                rewards.append(reward)
                interactions.append(interaction)
            else:
                rewards.append(-1.0)  # Invalid product index
                interactions.append({'type': 'invalid', 'product_id': None})
        
        # Calculate total reward
        total_reward = sum(rewards)
        
        # Update environment state
        self.current_step += 1
        self.episode_rewards.append(total_reward)
        self.episode_actions.append(action)
        
        # Update user state with new interactions
        if self.use_user_context and self.data_manager:
            for interaction in interactions:
                if interaction.get('product_id') is not None:
                    self.data_manager.update_user_state(
                        self.current_user_id, 
                        interaction
                    )
        
        # Check if episode is done
        done = self.current_step >= self.max_steps
        
        # Get next observation
        observation = self._get_observation()
        
        info = {
            'interactions': interactions,
            'total_reward': total_reward,
            'step': self.current_step,
            'available_products': len(available_products)
        }
        
        return observation, total_reward, done, False, info

    def _get_observation(self) -> np.ndarray:
        """Get the current observation of the environment."""
        if self.data_source_type == "csv":
            return self._get_observation_csv()
        else:
            return self._get_observation_database()

    def _get_observation_csv(self) -> np.ndarray:
        """Get observation for CSV data source."""
        # This would implement the original CSV observation logic
        # For now, return a placeholder
        return np.array([self.current_user_id, self.current_step, 0, 0, 0])

    def _get_observation_database(self) -> np.ndarray:
        """Get observation for database data source."""
        # This would implement the enhanced observation logic
        # For now, return a placeholder
        return np.array([self.current_user_id, self.current_step, 0, 0, 0])

    def _get_available_products(self) -> List[Dict[str, Any]]:
        """Get available products for recommendation."""
        if self.data_source_type == "csv":
            # Return product indices for CSV
            return list(range(min(self.n_products, self.max_arms)))
        else:
            # Get products from database with optional filtering
            try:
                products = self.data_controller.get_all_products(limit=self.max_arms)
                
                # Apply clustering filter if enabled
                if self.use_clustering and self.clustering_manager:
                    products = self._apply_clustering_filter(products)
                
                # Apply price optimization if enabled
                if self.use_price_optimization and self.clustering_manager:
                    products = self._apply_price_optimization(products)
                
                return products
            except Exception as e:
                print(f"⚠️  Error getting available products: {e}")
                return []

    def _apply_clustering_filter(self, products: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Apply clustering-based filtering to products."""
        try:
            if not self.clustering_manager:
                return products
            
            # Get user cluster
            user_cluster = self.clustering_manager.get_user_cluster(self.current_user_id)
            
            # Apply cluster-based filtering logic here
            # This is a simplified implementation
            return products[:self.max_arms]
            
        except Exception as e:
            print(f"⚠️  Error applying clustering filter: {e}")
            return products

    def _apply_price_optimization(self, products: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Apply price-based filtering to products."""
        try:
            if not self.clustering_manager:
                return products
            
            # Apply price optimization logic here
            # This is a simplified implementation
            return products[:self.max_arms]
            
        except Exception as e:
            print(f"⚠️  Error applying price optimization: {e}")
            return products

    def _simulate_user_interaction(self, product: Dict[str, Any]) -> Tuple[float, Dict[str, Any]]:
        """Simulate user interaction with a product."""
        # This is a simplified interaction simulation
        # In a real implementation, this would use more sophisticated user modeling
        
        # Simulate different interaction types
        interaction_types = ['view', 'cart', 'purchase', 'remove_from_cart']
        interaction_type = np.random.choice(interaction_types, p=[0.6, 0.2, 0.15, 0.05])
        
        # Calculate reward based on interaction type
        base_rewards = {'view': 0.1, 'cart': 0.33, 'purchase': 1.0, 'remove_from_cart': -0.5}
        reward = base_rewards.get(interaction_type, 0.0)
        
        # Add some randomness
        reward += np.random.normal(0, 0.1)
        
        interaction = {
            'type': interaction_type,
            'product_id': product.get('product_id'),
            'user_id': self.current_user_id,
            'timestamp': datetime.now().isoformat(),
            'reward': reward
        }
        
        return reward, interaction

    def render(self, mode: str = "human"):
        """Render the environment."""
        if mode == "human":
            # Implement human-readable rendering
            print(f"Step: {self.current_step}/{self.max_steps}")
            print(f"User: {self.current_user_id}")
            print(f"Total Reward: {sum(self.episode_rewards):.2f}")
        elif mode == "rgb_array":
            # Implement RGB array rendering for visualization
            pass

    def close(self):
        """Close the environment and clean up resources."""
        if self.fig is not None:
            plt.close(self.fig)
            self.fig = None
        
        # Close database connections if applicable
        if hasattr(self, 'data_controller') and self.data_controller:
            # Close database connections
            pass

    def get_algorithm_info(self) -> Dict[str, Any]:
        """Get information about the current algorithm state."""
        return {
            'data_source_type': self.data_source_type,
            'max_arms': self.max_arms,
            'n_suggestions': self.n_suggestions,
            'current_step': self.current_step,
            'current_user_id': self.current_user_id,
            'enhanced_features': {
                'clustering': self.use_clustering,
                'user_context': self.use_user_context,
                'price_optimization': self.use_price_optimization
            },
            'episode_stats': {
                'total_rewards': sum(self.episode_rewards),
                'avg_reward': np.mean(self.episode_rewards) if self.episode_rewards else 0,
                'total_steps': len(self.episode_actions)
            }
        }
