#!/usr/bin/env python3
"""
Comprehensive logging system for the MAB Recommender System.
Provides structured logging with different levels, file rotation, and performance tracking.
"""

import logging
import logging.handlers
import os
import sys
import time
import json
from datetime import datetime
from pathlib import Path
from functools import wraps
import traceback
import psutil
import threading
from typing import Optional, Dict, Any, Callable

class PerformanceLogger:
    """Tracks per-operation timings and samples process resource usage.

    Timings are keyed by operation name: ``start_timer`` records a start
    mark, ``end_timer`` turns it into a duration and appends it to
    ``metrics[operation_name]``. Both dictionaries are guarded by a lock.
    """

    def __init__(self):
        # operation name -> list of measured durations in seconds
        self.metrics = {}
        # operation name -> start mark of the currently running measurement
        self.start_times = {}
        self._lock = threading.Lock()
        # psutil process handle, created lazily and reused between calls
        self._process = None

    def start_timer(self, operation_name: str):
        """Start timing an operation.

        Starting an operation that is already running overwrites its
        previous start mark.
        """
        with self._lock:
            # perf_counter() is monotonic, so a wall-clock adjustment cannot
            # produce negative or inflated durations.
            self.start_times[operation_name] = time.perf_counter()

    def end_timer(self, operation_name: str) -> float:
        """End timing an operation and return its duration in seconds.

        Returns 0.0 (and records nothing) when no timer was started for
        ``operation_name``.
        """
        with self._lock:
            started = self.start_times.pop(operation_name, None)
            if started is None:
                return 0.0
            duration = time.perf_counter() - started
            self.metrics.setdefault(operation_name, []).append(duration)
            return duration

    def get_system_stats(self) -> Dict[str, Any]:
        """Get current resource usage of this process.

        Returns:
            A dict with ``cpu_percent``, ``memory_percent``, ``memory_rss``
            (in MB), ``threads``, ``open_files`` and ``connections``, or an
            empty dict when the stats cannot be collected.
        """
        try:
            process = self._process
            # Reuse one Process handle: cpu_percent() measures usage since the
            # previous call on the same object, so a fresh handle per call
            # would always report 0.0. Re-create it after a fork so the child
            # does not report its parent's numbers.
            if process is None or process.pid != os.getpid():
                process = self._process = psutil.Process()
            # psutil 6.0 renamed connections() to net_connections(); prefer the
            # new name and fall back for older releases.
            list_connections = (
                getattr(process, 'net_connections', None) or process.connections
            )
            return {
                'cpu_percent': process.cpu_percent(),
                'memory_percent': process.memory_percent(),
                'memory_rss': process.memory_info().rss / 1024 / 1024,  # MB
                'threads': process.num_threads(),
                'open_files': len(process.open_files()),
                'connections': len(list_connections())
            }
        except Exception:
            # Deliberately best-effort: resource stats are optional context
            # and must never make a logging call fail.
            return {}

    def get_performance_summary(self) -> Dict[str, Any]:
        """Get count/total/avg/min/max for every operation with samples."""
        # Copy under the lock so a concurrent end_timer() cannot change the
        # dict or its lists while the summary is being computed.
        with self._lock:
            snapshot = {
                operation: list(times)
                for operation, times in self.metrics.items()
            }

        summary = {}
        for operation, times in snapshot.items():
            if not times:
                continue
            total = sum(times)
            summary[operation] = {
                'count': len(times),
                'total_time': total,
                'avg_time': total / len(times),
                'min_time': min(times),
                'max_time': max(times)
            }
        return summary

class RecommenderLogger:
    """
    Main logging class for the recommender system.

    Wraps a standard ``logging.Logger`` with optional console output,
    rotating file output and per-operation performance tracking. Keyword
    arguments given to the logging methods are attached to the emitted
    record through ``extra``; the formatters configured in ``_setup_logger``
    do not render them, so they are only visible to handlers that inspect
    the record.
    """

    # Attribute names ``logging`` refuses to accept through ``extra`` (it
    # raises KeyError on a clash). Derived from a throwaway record so the set
    # matches whatever the running Python version defines.
    _RESERVED_RECORD_KEYS = frozenset(
        vars(logging.LogRecord("", 0, "", 0, "", (), None))
    ) | {"message", "asctime"}

    def __init__(self,
                 name: str = "recommender_system",
                 log_dir: str = "/tmp/logs",
                 log_level: str = "INFO",
                 max_file_size: int = 10 * 1024 * 1024,  # 10MB
                 backup_count: int = 5,
                 enable_console: bool = True,
                 enable_file: bool = False,
                 enable_performance: bool = True):
        """
        Initialize the logger.

        Args:
            name: Logger name
            log_dir: Directory to store log files
            log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
            max_file_size: Maximum size of log file before rotation
            backup_count: Number of backup files to keep
            enable_console: Enable console logging
            enable_file: Enable file logging
            enable_performance: Enable performance tracking

        Raises:
            AttributeError: If ``log_level`` does not name an attribute of
                the ``logging`` module.
        """
        self.name = name
        self.log_dir = Path(log_dir)
        self.log_level = getattr(logging, log_level.upper())
        self.max_file_size = max_file_size
        self.backup_count = backup_count
        self.enable_console = enable_console
        self.enable_file = enable_file
        self.enable_performance = enable_performance

        # Create log directory only if file logging is enabled
        if self.enable_file:
            self.log_dir.mkdir(parents=True, exist_ok=True)

        # Initialize performance logger
        if self.enable_performance:
            self.performance_logger = PerformanceLogger()
        else:
            self.performance_logger = None

        # Setup logger
        self.logger = self._setup_logger()

        # Log initialization. The context is passed as plain keyword
        # arguments so each value lands on the record directly instead of
        # being nested under a single 'extra' key.
        self.info(
            f"Logger initialized: {name}",
            log_dir=str(self.log_dir),
            log_level=logging.getLevelName(self.log_level),
            enable_console=enable_console,
            enable_file=enable_file,
            enable_performance=enable_performance,
        )

    def _setup_logger(self) -> logging.Logger:
        """Setup the logger with handlers and formatters.

        Any handlers already attached to the named logger are removed and
        closed first, so building a second instance with the same name
        neither duplicates output nor leaks open log files.
        """
        logger = logging.getLogger(self.name)
        logger.setLevel(self.log_level)

        # Remove AND close existing handlers; clearing the list alone would
        # leave the rotating file handlers' streams open.
        for stale_handler in list(logger.handlers):
            logger.removeHandler(stale_handler)
            stale_handler.close()

        # Create formatters
        # NOTE: records are emitted from _log_with_context, so funcName and
        # lineno in the detailed format point at that helper, not the caller.
        detailed_formatter = logging.Formatter(
            '%(asctime)s | %(name)s | %(levelname)s | %(funcName)s:%(lineno)d | %(message)s'
        )

        simple_formatter = logging.Formatter(
            '%(asctime)s | %(levelname)s | %(message)s'
        )

        # Console handler
        if self.enable_console:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(self.log_level)
            console_handler.setFormatter(simple_formatter)
            logger.addHandler(console_handler)

        # File handlers
        if self.enable_file:
            # Main log file with rotation
            main_handler = logging.handlers.RotatingFileHandler(
                self.log_dir / f"{self.name}.log",
                maxBytes=self.max_file_size,
                backupCount=self.backup_count
            )
            main_handler.setLevel(self.log_level)
            main_handler.setFormatter(detailed_formatter)
            logger.addHandler(main_handler)

            # Error log file (ERROR and above only)
            error_handler = logging.handlers.RotatingFileHandler(
                self.log_dir / f"{self.name}_errors.log",
                maxBytes=self.max_file_size,
                backupCount=self.backup_count
            )
            error_handler.setLevel(logging.ERROR)
            error_handler.setFormatter(detailed_formatter)
            logger.addHandler(error_handler)

            # Performance log file
            # NOTE(review): this handler has no filter, so it receives every
            # record at INFO and above, not only the "Performance:" ones.
            if self.enable_performance:
                perf_handler = logging.handlers.RotatingFileHandler(
                    self.log_dir / f"{self.name}_performance.log",
                    maxBytes=self.max_file_size,
                    backupCount=self.backup_count
                )
                perf_handler.setLevel(logging.INFO)
                perf_handler.setFormatter(detailed_formatter)
                logger.addHandler(perf_handler)

        return logger

    def _log_with_context(self, level: int, message: str, **kwargs):
        """Log message with additional context.

        Args:
            level: Numeric logging level.
            message: Message text.
            **kwargs: Context values attached to the record via ``extra``.
                ``exc_info`` is forwarded to the underlying logger instead.
                A key that clashes with a built-in ``LogRecord`` attribute is
                stored as ``ctx_<key>`` rather than raising ``KeyError``.
        """
        # Bail out before collecting system stats when the record would be
        # dropped anyway.
        if not self.logger.isEnabledFor(level):
            return

        exc_info = kwargs.pop('exc_info', None)

        extra = {}
        for key, value in kwargs.items():
            if key in self._RESERVED_RECORD_KEYS:
                key = f"ctx_{key}"
            extra[key] = value

        # Add system stats for performance logging
        if self.performance_logger:
            system_stats = self.performance_logger.get_system_stats()
            if system_stats:
                extra['system_stats'] = system_stats

        self.logger.log(level, message, exc_info=exc_info, extra=extra)

    def debug(self, message: str, **kwargs):
        """Log debug message."""
        self._log_with_context(logging.DEBUG, message, **kwargs)

    def info(self, message: str, **kwargs):
        """Log info message."""
        self._log_with_context(logging.INFO, message, **kwargs)

    def warning(self, message: str, **kwargs):
        """Log warning message."""
        self._log_with_context(logging.WARNING, message, **kwargs)

    def error(self, message: str, **kwargs):
        """Log error message."""
        self._log_with_context(logging.ERROR, message, **kwargs)

    def critical(self, message: str, **kwargs):
        """Log critical message."""
        self._log_with_context(logging.CRITICAL, message, **kwargs)

    def exception(self, message: str, **kwargs):
        """Log an ERROR message together with the active exception's traceback.

        The formatted traceback is stored in the ``traceback`` context key
        and, when an exception is being handled, the exception is also
        passed as ``exc_info`` so handlers append the traceback to their
        output.
        """
        extra = kwargs.copy()
        extra['traceback'] = traceback.format_exc()
        # Only request exc_info when an exception is actually active;
        # otherwise handlers would append a bogus "NoneType: None" line.
        if sys.exc_info()[0] is not None:
            extra.setdefault('exc_info', True)
        self._log_with_context(logging.ERROR, message, **extra)

    def performance(self, operation_name: str, **kwargs):
        """Start timing ``operation_name`` and log that it started.

        Does nothing when performance tracking is disabled.
        """
        if self.performance_logger:
            self.performance_logger.start_timer(operation_name)
            self.info(f"Performance: Started {operation_name}", **kwargs)

    def performance_end(self, operation_name: str, **kwargs):
        """End performance tracking and log duration.

        Does nothing when performance tracking is disabled.
        """
        if self.performance_logger:
            duration = self.performance_logger.end_timer(operation_name)
            self.info(f"Performance: Completed {operation_name} in {duration:.4f}s",
                     duration=duration, **kwargs)

    def log_algorithm_selection(self, algorithm_name: str, arms_selected: list,
                               context: Optional[Dict] = None, **kwargs):
        """Log algorithm arm selection."""
        extra = {
            'algorithm': algorithm_name,
            'arms_selected': arms_selected,
            'context': context
        }
        extra.update(kwargs)
        self.info(f"Algorithm {algorithm_name} selected arms: {arms_selected}", **extra)

    def log_reward(self, algorithm_name: str, arms: list, rewards: list,
                   cumulative_reward: float, step: int, **kwargs):
        """Log reward information."""
        extra = {
            'algorithm': algorithm_name,
            'arms': arms,
            'rewards': rewards,
            'cumulative_reward': cumulative_reward,
            'step': step
        }
        extra.update(kwargs)
        self.info(f"Reward received: {rewards} for arms {arms}, cumulative: {cumulative_reward:.4f}", **extra)

    def log_environment_reset(self, user_id: str, **kwargs):
        """Log environment reset."""
        extra = {'user_id': user_id}
        extra.update(kwargs)
        self.info(f"Environment reset for user: {user_id}", **extra)

    def log_simulation_start(self, algorithms: list, n_users: int, n_steps: int, **kwargs):
        """Log simulation start."""
        extra = {
            'algorithms': algorithms,
            'n_users': n_users,
            'n_steps': n_steps
        }
        extra.update(kwargs)
        self.info(f"Starting simulation with {len(algorithms)} algorithms, {n_users} users, {n_steps} steps", **extra)

    def log_simulation_complete(self, results: Dict, duration: float, **kwargs):
        """Log simulation completion.

        Only a compact summary of ``results`` is attached: the length of each
        list value, or the type name of any other value.
        """
        extra = {
            'results_summary': {k: len(v) if isinstance(v, list) else str(type(v)) for k, v in results.items()},
            'duration': duration
        }
        extra.update(kwargs)
        self.info(f"Simulation completed in {duration:.4f}s", **extra)

    def log_config_change(self, config_name: str, old_value: Any, new_value: Any, **kwargs):
        """Log configuration changes."""
        extra = {
            'config_name': config_name,
            'old_value': old_value,
            'new_value': new_value
        }
        extra.update(kwargs)
        self.info(f"Configuration changed: {config_name} = {new_value} (was: {old_value})", **extra)

    def get_performance_summary(self) -> Dict[str, Any]:
        """Get performance summary (empty when tracking is disabled)."""
        if self.performance_logger:
            return self.performance_logger.get_performance_summary()
        return {}

    def export_logs(self, output_file: Optional[str] = None) -> str:
        """Export a JSON overview of the logs to a file.

        The export contains the performance summary and the name, size and
        modification time of every log file found in ``log_dir``; it does not
        copy the log contents.

        Args:
            output_file: Destination path. Defaults to a timestamped
                ``logs_export_*.json`` inside ``log_dir``.

        Returns:
            The path of the written file as a string.
        """
        if output_file is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = self.log_dir / f"logs_export_{timestamp}.json"

        export_data = {
            'export_timestamp': datetime.now().isoformat(),
            'logger_name': self.name,
            'performance_summary': self.get_performance_summary(),
            'log_files': []
        }

        # List log files; log_dir is only created when file logging is
        # enabled, so it may legitimately be missing.
        if self.log_dir.is_dir():
            for log_file in self.log_dir.glob(f"{self.name}*.log*"):
                file_stat = log_file.stat()
                export_data['log_files'].append({
                    'filename': log_file.name,
                    'size_bytes': file_stat.st_size,
                    'modified': datetime.fromtimestamp(file_stat.st_mtime).isoformat()
                })

        # Make sure the destination directory exists; with the default
        # enable_file=False nothing has created log_dir yet.
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, indent=2)

        self.info(f"Logs exported to: {output_file}")
        return str(output_file)

def performance_tracker(operation_name: str = None):
    """
    Decorator to automatically track performance of functions.

    The wrapped call is bracketed by ``logger.performance(op_name)`` and
    ``logger.performance_end(op_name)``. The logger is the ``logger``
    attribute of the first positional argument that offers both of those
    methods; if none does, ``get_default_logger()`` is used. When the
    wrapped function raises, the error is reported through
    ``logger.exception`` and re-raised; ``performance_end`` is not called in
    that case.

    Args:
        operation_name: Name to record the timing under. Defaults to the
            wrapped function's ``__name__``.

    Usage:
        @performance_tracker("my_operation")
        def my_function():
            pass

        @performance_tracker
        def other_function():
            pass
    """
    # Bare usage (@performance_tracker without parentheses) hands the
    # function itself in as ``operation_name``.
    bare_func = operation_name if callable(operation_name) else None
    label = None if bare_func is not None else operation_name

    def _find_tracking_logger(args):
        """Return the first ``arg.logger`` that supports performance tracking."""
        for arg in args:
            candidate = getattr(arg, 'logger', None)
            # A plain logging.Logger (or any unrelated ``logger`` attribute)
            # has no performance API and must be skipped, not used.
            if (callable(getattr(candidate, 'performance', None))
                    and callable(getattr(candidate, 'performance_end', None))):
                return candidate
        return None

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            logger = _find_tracking_logger(args)
            if logger is None:
                # Use default logger
                logger = get_default_logger()

            op_name = label or func.__name__
            logger.performance(op_name)

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                logger.exception(f"Error in {op_name}: {str(e)}")
                raise
            else:
                # Kept outside the try body so a failure while logging is not
                # misreported as an error of the wrapped function.
                logger.performance_end(op_name)
                return result

        return wrapper

    if bare_func is not None:
        return decorator(bare_func)
    return decorator

def get_default_logger() -> RecommenderLogger:
    """Get the default logger instance.

    The instance is cached on the function object. On first use the
    module-level ``default_logger`` is adopted when it exists, so there is a
    single default logger (and a single set of performance metrics) instead
    of a second ``RecommenderLogger`` that would also re-create the handlers
    of the same named logger. A fresh instance is built only as a fallback.
    """
    instance = getattr(get_default_logger, '_instance', None)
    if instance is None:
        # Looked up at call time: default_logger is assigned further down the
        # module, after this function is defined.
        instance = globals().get('default_logger')
        if not isinstance(instance, RecommenderLogger):
            instance = RecommenderLogger()
        get_default_logger._instance = instance
    return instance

def setup_logging(config: Dict[str, Any] = None) -> RecommenderLogger:
    """Build a RecommenderLogger from built-in defaults plus overrides.

    Args:
        config: Optional settings that replace the defaults below; the keys
            are passed to ``RecommenderLogger`` as keyword arguments.

    Returns:
        The newly constructed logger.
    """
    settings = dict(
        name='recommender_system',
        log_dir='/tmp/logs',
        log_level='INFO',
        max_file_size=10 * 1024 * 1024,
        backup_count=5,
        enable_console=True,
        enable_file=False,
        enable_performance=True,
    )

    # Caller-supplied values win over the defaults.
    if config is not None:
        settings.update(config)

    return RecommenderLogger(**settings)

# Create the module-level default logger at import time, using the built-in
# defaults of setup_logging() (it is called without a config).
default_logger = setup_logging()
