import numpy as np
import random
from collections import defaultdict
from .logger import RecommenderLogger, performance_tracker

class MABAlgorithm:
    """Base class for MAB algorithms."""
    
    def __init__(self, n_arms, n_suggestions=1):
        self.n_arms = n_arms
        self.n_suggestions = n_suggestions
        self.logger = RecommenderLogger(name=f"{self.__class__.__name__}_{n_arms}arms")
        self.reset()
        
        # Log algorithm initialization
        self.logger.info(f"Algorithm {self.__class__.__name__} initialized", extra={
            'n_arms': n_arms,
            'n_suggestions': n_suggestions
        })

    def reset(self):
        """Reset the algorithm's state."""
        self.arm_counts = np.zeros(self.n_arms)
        self.arm_rewards = np.zeros(self.n_arms)
        self.total_steps = 0
        self.total_pulls = 0

    def select_arm(self):
        """Select an arm to pull."""
        raise NotImplementedError("Subclasses should implement this method.")

    def select_multiple_arms(self, input_arms_idx=None):
        """Select multiple arms to pull simultaneously."""
        raise NotImplementedError("Subclasses should implement this method.")
    
    def _log_arm_selection(self, arms_selected, context=None):
        """Log arm selection for tracking."""
        self.logger.log_algorithm_selection(
            algorithm_name=self.__class__.__name__,
            arms_selected=arms_selected.tolist() if hasattr(arms_selected, 'tolist') else arms_selected,
            context=context,
            n_arms=self.n_arms,
            n_suggestions=self.n_suggestions
        )

    def update(self, arms, rewards):
        """Update the algorithm's state after pulling multiple arms."""
        if isinstance(arms, (int, np.integer)):
            # Single arm update (backward compatibility)
            self.arm_counts[arms] += 1
            self.arm_rewards[arms] += rewards
            self.total_steps += 1
            self.total_pulls += 1
            
            # Log single arm update
            self.logger.log_reward(
                algorithm_name=self.__class__.__name__,
                arms=[arms],
                rewards=[rewards],
                cumulative_reward=self.arm_rewards[arms],
                step=self.total_steps,
                arm_counts=self.arm_counts[arms]
            )
        else:
            # Multiple arms update
            for arm, reward in zip(arms, rewards):
                arm = int(arm)  # Ensure arm is an integer
                self.arm_counts[arm] += 1
                self.arm_rewards[arm] += reward
            
            self.total_steps += 1
            self.total_pulls += len(arms)
            
            # Log multiple arms update
            self.logger.log_reward(
                algorithm_name=self.__class__.__name__,
                arms=arms.tolist() if hasattr(arms, 'tolist') else arms,
                rewards=rewards,
                cumulative_reward=sum(self.arm_rewards),
                step=self.total_steps,
                arm_counts=[self.arm_counts[arm] for arm in arms]
            )

class RandomMAB(MABAlgorithm):
    """
    Random baseline that selects arms uniformly at random (optionally seeded).
    """
    def __init__(self, n_arms, n_suggestions=1, seed=None):
        super().__init__(n_arms, n_suggestions)
        self.seed = seed

    def select_arm(self, input_arms_idx, seed=None):
        # Choose a random arm, optionally with a fixed seed for reproducibility
        effective_seed = seed if seed is not None else self.seed
        if effective_seed is None:
            return np.random.choice(input_arms_idx)
        return np.random.RandomState(effective_seed).choice(input_arms_idx)

    def select_multiple_arms(self, input_arms_idx=None, seed=None):
        # Choose random arms without replacement, optionally with a fixed seed
        effective_seed = seed if seed is not None else self.seed
        rng = np.random if effective_seed is None else np.random.RandomState(effective_seed)
        candidates = self.n_arms if input_arms_idx is None else input_arms_idx
        arms = rng.choice(candidates, size=self.n_suggestions, replace=False)

        # Ensure arms are integers
        arms = arms.astype(np.int64)

        # Log arm selection
        self._log_arm_selection(arms, context={'seed': seed})
        return arms

class EpsilonGreedyMAB(MABAlgorithm):
    """Epsilon-Greedy algorithm for MAB."""
    
    def __init__(self, n_arms, n_suggestions=1, epsilon=0.1):
        super().__init__(n_arms, n_suggestions)
        self.epsilon = epsilon

    def select_arm(self, input_arms_idx=None):
        if np.random.rand() < self.epsilon:
            if input_arms_idx is None:
                return np.random.choice(self.n_arms)  # Explore
            else:
                return np.random.choice(input_arms_idx)  # Explore
        else:
            avg_rewards = np.divide(self.arm_rewards, self.arm_counts, out=np.zeros_like(self.arm_rewards), where=self.arm_counts != 0)
            if input_arms_idx is None:
                return np.argmax(avg_rewards)  # Exploit
            else:
                return input_arms_idx[np.argmax(avg_rewards[input_arms_idx])]  # Exploit

    def select_multiple_arms(self, input_arms_idx=None):
        avg_rewards = None
        if np.random.rand() < self.epsilon:
            # Explore: select random arms without replacement
            if input_arms_idx is None:
                arms = np.random.choice(self.n_arms, size=self.n_suggestions, replace=False)
            else:
                arms = np.random.choice(input_arms_idx, size=self.n_suggestions, replace=False)
            strategy = "explore"
        else:
            # Exploit: select the top n_suggestions arms by average reward
            avg_rewards = np.divide(self.arm_rewards, self.arm_counts, out=np.zeros_like(self.arm_rewards), where=self.arm_counts != 0)
            if input_arms_idx is None:
                # Indices of the top arms, in descending order of average reward
                arms = np.argsort(avg_rewards)[-self.n_suggestions:][::-1]
            else:
                # argsort returns positions within input_arms_idx; map them back to actual arm indices
                input_arms_idx = np.asarray(input_arms_idx)
                top_positions = np.argsort(avg_rewards[input_arms_idx])[-self.n_suggestions:][::-1]
                arms = input_arms_idx[top_positions]
            strategy = "exploit"

        # Ensure arms are integers
        arms = arms.astype(np.int64)

        # Log arm selection with strategy
        self._log_arm_selection(arms, context={
            'strategy': strategy,
            'epsilon': self.epsilon,
            'avg_rewards': avg_rewards.tolist() if avg_rewards is not None else None
        })
        return arms
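
# Illustrative usage sketch (not part of the library API): it assumes a
# hypothetical 10-arm problem with Bernoulli(0.5) rewards purely to show the
# select/update loop of EpsilonGreedyMAB; the arm count and reward model are
# assumptions made only for this example.
def _example_epsilon_greedy_loop(n_rounds=100):
    """Run EpsilonGreedyMAB against random binary rewards (illustration only)."""
    bandit = EpsilonGreedyMAB(n_arms=10, n_suggestions=3, epsilon=0.1)
    for _ in range(n_rounds):
        arms = bandit.select_multiple_arms()
        # Hypothetical environment: each selected arm yields a Bernoulli(0.5) reward
        rewards = np.random.binomial(1, 0.5, size=len(arms)).tolist()
        bandit.update(arms, rewards)
    # Average observed reward per arm (zero counts guarded by np.maximum)
    return bandit.arm_rewards / np.maximum(bandit.arm_counts, 1)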

class UpperConfidenceBoundMAB(MABAlgorithm):
    """Upper Confidence Bound algorithm for MAB."""

    def __init__(self, n_arms, n_suggestions=1, c=2.0):
        super().__init__(n_arms, n_suggestions)
        self.c = c

    def select_arm(self, input_arms_idx=None):
        if input_arms_idx is None:
            for arm in range(self.n_arms):
                if self.arm_counts[arm] == 0:
                    return arm
        else:
            for arm in input_arms_idx:
                if self.arm_counts[arm] == 0:
                    return arm
        
        avg_rewards = np.divide(self.arm_rewards, self.arm_counts, out=np.zeros_like(self.arm_rewards), where=self.arm_counts != 0)
        # Ensure total_pulls is at least 1 to avoid log(0) issues
        safe_total_pulls = max(self.total_pulls, 1)
        confidence = self.c * np.sqrt(np.divide(np.log(safe_total_pulls), self.arm_counts, out=np.zeros_like(self.arm_counts), where=self.arm_counts != 0))
        ucb_values = avg_rewards + confidence

        if input_arms_idx is None:
            return np.argmax(ucb_values)
        else:
            return input_arms_idx[np.argmax(ucb_values[input_arms_idx])]

    def select_multiple_arms(self, input_arms_idx=None):
        candidate_arms = np.arange(self.n_arms) if input_arms_idx is None else np.asarray(input_arms_idx)

        # First, collect unexplored candidate arms (never pulled yet)
        unexplored_arms = [int(arm) for arm in candidate_arms if self.arm_counts[arm] == 0]
        if len(unexplored_arms) >= self.n_suggestions:
            return np.array(unexplored_arms[:self.n_suggestions], dtype=np.int64)

        # Fill the remaining slots with the explored arms that have the highest UCB values
        remaining_slots = self.n_suggestions - len(unexplored_arms)
        avg_rewards = np.divide(self.arm_rewards, self.arm_counts, out=np.zeros_like(self.arm_rewards), where=self.arm_counts != 0)
        # Ensure total_pulls is at least 1 to avoid log(0) issues
        safe_total_pulls = max(self.total_pulls, 1)
        confidence = self.c * np.sqrt(np.divide(np.log(safe_total_pulls), self.arm_counts, out=np.zeros_like(self.arm_counts), where=self.arm_counts != 0))
        ucb_values = avg_rewards + confidence

        explored_arms = np.array([int(arm) for arm in candidate_arms if self.arm_counts[arm] > 0], dtype=np.int64)
        # Top remaining_slots explored arms by UCB value, in descending order
        top_positions = np.argsort(ucb_values[explored_arms])[-remaining_slots:][::-1]
        top_arms = explored_arms[top_positions]

        arms = np.concatenate([np.array(unexplored_arms, dtype=np.int64), top_arms])
        return arms.astype(np.int64)
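
    # Note: the score used above is the standard UCB1-style index,
    #     UCB_i = mean_i + c * sqrt(ln(N) / n_i)
    # where mean_i is the empirical mean reward of arm i, n_i its pull count,
    # and N the total number of pulls. For example, with mean_i = 0.5, n_i = 10,
    # N = 100 and c = 2.0 the bonus is 2.0 * sqrt(ln(100) / 10) ≈ 1.36, so the
    # index is roughly 1.86; the bonus shrinks as the arm is pulled more often.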

class ThompsonSamplingMAB(MABAlgorithm):
    """Thompson Sampling algorithm for MAB."""

    def __init__(self, n_arms, n_suggestions=1, alpha=1.0, beta=1.0):
        super().__init__(n_arms, n_suggestions)
        self.alpha = alpha
        self.beta = beta
        self.successes = np.ones(n_arms) * alpha
        self.failures = np.ones(n_arms) * beta

    def select_arm(self, input_arms_idx=None):
        sampled_theta = np.random.beta(self.successes, self.failures)
        if input_arms_idx is None:
            return np.argmax(sampled_theta)
        else:
            return input_arms_idx[np.argmax(sampled_theta[input_arms_idx])]

    def select_multiple_arms(self, input_arms_idx=None):
        sampled_theta = np.random.beta(self.successes, self.failures)
        # Get the top n_suggestions arms by sampled theta, in descending order
        if input_arms_idx is None:
            arms = np.argsort(sampled_theta)[-self.n_suggestions:][::-1]
        else:
            # argsort returns positions within input_arms_idx; map them back to actual arm indices
            input_arms_idx = np.asarray(input_arms_idx)
            top_positions = np.argsort(sampled_theta[input_arms_idx])[-self.n_suggestions:][::-1]
            arms = input_arms_idx[top_positions]
        # Ensure arms are integers
        return arms.astype(np.int64)

    def update(self, arms, rewards):
        super().update(arms, rewards)
        
        if isinstance(arms, (int, np.integer)):
            # Single arm update (backward compatibility)
            if rewards > 0:
                self.successes[arms] += 1
            else:
                self.failures[arms] += 1
        else:
            # Multiple arms update
            for arm, reward in zip(arms, rewards):
                if reward > 0:
                    self.successes[arm] += 1
                else:
                    self.failures[arm] += 1
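
# Illustrative usage sketch (not part of the library API): Thompson Sampling
# maintains Beta(successes, failures) posteriors, so rewards are treated as
# binary outcomes. The 5-arm setup and success probabilities below are
# assumptions made only for this example.
def _example_thompson_sampling(n_rounds=200):
    """Run ThompsonSamplingMAB on a simulated binary-reward problem (illustration only)."""
    true_probs = np.array([0.1, 0.3, 0.5, 0.7, 0.2])  # hypothetical per-arm success rates
    bandit = ThompsonSamplingMAB(n_arms=len(true_probs), n_suggestions=2)
    for _ in range(n_rounds):
        arms = bandit.select_multiple_arms()
        rewards = [int(np.random.rand() < true_probs[arm]) for arm in arms]
        bandit.update(arms, rewards)
    # Posterior means should concentrate around the true success rates
    return bandit.successes / (bandit.successes + bandit.failures)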

class ContextualMAB(MABAlgorithm):
    """Contextual Bandit algorithm for MAB."""

    def __init__(self, n_arms, n_suggestions=1, epsilon=0.1):
        super().__init__(n_arms, n_suggestions)
        self.epsilon = epsilon
        self.context_rewards = defaultdict(lambda: np.zeros(n_arms))
        self.context_counts = defaultdict(lambda: np.zeros(n_arms))

    def select_arm(self, context=None, input_arms_idx=None):
        if context is None:
            context = "default"

        context_key = self._create_context_key(context)

        if random.random() < self.epsilon:
            if input_arms_idx is None:
                return random.randrange(self.n_arms)
            else:
                return input_arms_idx[random.randrange(len(input_arms_idx))]
        else:
            context_avg = np.divide(
                self.context_rewards[context_key],
                self.context_counts[context_key],
                out=np.zeros(self.n_arms),
                where=self.context_counts[context_key] != 0
            )

            if np.sum(context_avg) == 0:
                global_avg = np.divide(self.arm_rewards, self.arm_counts, out=np.zeros_like(self.arm_rewards), where=self.arm_counts != 0)
                if input_arms_idx is None:
                    return np.argmax(global_avg)
                else:
                    return input_arms_idx[np.argmax(global_avg[input_arms_idx])]
            if input_arms_idx is None:
                return np.argmax(context_avg)
            else:
                return input_arms_idx[np.argmax(context_avg[input_arms_idx])]

    def select_multiple_arms(self, context=None, input_arms_idx=None):
        if context is None:
            context = "default"

        context_key = self._create_context_key(context)

        if random.random() < self.epsilon:
            # Explore: select random arms without replacement
            if input_arms_idx is None:
                arms = np.random.choice(self.n_arms, size=self.n_suggestions, replace=False)
            else:
                arms = np.random.choice(input_arms_idx, size=self.n_suggestions, replace=False)
            return arms.astype(np.int64)
        else:
            context_avg = np.divide(
                self.context_rewards[context_key],  
                self.context_counts[context_key],
                out=np.zeros(self.n_arms),
                where=self.context_counts[context_key] != 0
            )

            if np.sum(context_avg) == 0:
                # No data for this context yet: fall back to the global average rewards
                global_avg = np.divide(self.arm_rewards, self.arm_counts, out=np.zeros_like(self.arm_rewards), where=self.arm_counts != 0)
                scores = global_avg
            else:
                scores = context_avg

            # Get the top n_suggestions arms by score, in descending order
            if input_arms_idx is None:
                arms = np.argsort(scores)[-self.n_suggestions:][::-1]
            else:
                # argsort returns positions within input_arms_idx; map them back to actual arm indices
                input_arms_idx = np.asarray(input_arms_idx)
                top_positions = np.argsort(scores[input_arms_idx])[-self.n_suggestions:][::-1]
                arms = input_arms_idx[top_positions]
            return arms.astype(np.int64)

    def update(self, arms, rewards, context=None):
        """Update both global and per-context statistics."""
        super().update(arms, rewards)
        if context is None:
            context = "default"

        context_key = self._create_context_key(context)

        if isinstance(arms, (int, np.integer)):
            # Single arm update (backward compatibility)
            self.context_rewards[context_key][arms] += rewards
            self.context_counts[context_key][arms] += 1
        else:
            # Multiple arms update
            for arm, reward in zip(arms, rewards):
                self.context_rewards[context_key][arm] += reward
                self.context_counts[context_key][arm] += 1

    def _create_context_key(self, context):
        """Create a simple context key from the context."""
        
        if isinstance(context, dict):
            return f"step_{min(context.get('session_step', 0), 5)}"
        return str(context)
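
# Illustrative usage sketch (not part of the library API): ContextualMAB keys
# its statistics on a coarse context derived from 'session_step' (capped at 5).
# The context dictionary and reward values below are assumptions made only for
# this example.
def _example_contextual_bandit():
    """Select and update ContextualMAB with a simple session context (illustration only)."""
    bandit = ContextualMAB(n_arms=8, n_suggestions=2, epsilon=0.1)
    context = {'session_step': 3}          # hypothetical session feature
    arms = bandit.select_multiple_arms(context=context)
    rewards = [1, 0]                       # hypothetical feedback for the two suggestions
    bandit.update(arms, rewards, context=context)
    return bandit.context_rewards[bandit._create_context_key(context)]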

# ============================================================================
# ENSEMBLE MAB ALGORITHMS
# ============================================================================

class EnsembleMAB(MABAlgorithm):
    """Base class for ensemble MAB algorithms."""
    
    def __init__(self, n_arms, n_suggestions=1, algorithms=None):
        # Initialize algorithms first before calling super().__init__()
        self.algorithms = algorithms or []
        self.algorithm_weights = np.ones(len(self.algorithms)) / len(self.algorithms) if self.algorithms else []
        self.algorithm_performances = np.zeros(len(self.algorithms)) if self.algorithms else []
        
        # Now call super().__init__() which will call reset()
        super().__init__(n_arms, n_suggestions)
        
    def add_algorithm(self, algorithm):
        """Add an algorithm to the ensemble."""
        self.algorithms.append(algorithm)
        if self.algorithms:
            self.algorithm_weights = np.ones(len(self.algorithms)) / len(self.algorithms)
            self.algorithm_performances = np.zeros(len(self.algorithms))
        
    def reset(self):
        """Reset the ensemble and all algorithms."""
        super().reset()
        if self.algorithms:
            for algorithm in self.algorithms:
                algorithm.reset()
            self.algorithm_weights = np.ones(len(self.algorithms)) / len(self.algorithms)
            self.algorithm_performances = np.zeros(len(self.algorithms))

class VotingEnsembleMAB(EnsembleMAB):
    """
    Voting-based ensemble MAB algorithm.
    Each algorithm votes for arms, and the most voted arms are selected.
    """
    
    def __init__(self, n_arms, n_suggestions=1, algorithms=None, voting_method='majority'):
        super().__init__(n_arms, n_suggestions, algorithms)
        self.voting_method = voting_method  # 'majority', 'weighted', 'ranked'
        
    def select_multiple_arms(self, context=None, input_arms_idx=None):
        if not self.algorithms:
            # No base algorithms: fall back to a uniform random selection
            # (np.random.choice already returns actual arm indices, so no re-indexing is needed)
            if input_arms_idx is None:
                arms = np.random.choice(self.n_arms, size=self.n_suggestions, replace=False)
            else:
                arms = np.random.choice(input_arms_idx, size=self.n_suggestions, replace=False)
            return arms.astype(np.int64)
            
        # Get votes from all algorithms
        all_votes = []
        for algorithm in self.algorithms:
            if hasattr(algorithm, 'select_multiple_arms'):
                # Pass the context only to algorithms whose signature accepts it (e.g. ContextualMAB)
                try:
                    votes = algorithm.select_multiple_arms(context=context, input_arms_idx=input_arms_idx)
                except TypeError:
                    votes = algorithm.select_multiple_arms(input_arms_idx=input_arms_idx)
            else:
                # Fallback to single-arm selection
                votes = [algorithm.select_arm()]
            all_votes.append(votes)
        
        # Count votes for each arm
        arm_votes = np.zeros(self.n_arms)
        for votes in all_votes:
            for vote in votes:
                arm_votes[int(vote)] += 1
        
        # Select arms based on voting method
        if self.voting_method == 'majority':
            # Select arms with most votes
            selected_arms = np.argsort(arm_votes)[-self.n_suggestions:][::-1]
        elif self.voting_method == 'weighted':
            # Weight votes by algorithm performance
            weighted_votes = np.zeros(self.n_arms)
            for i, votes in enumerate(all_votes):
                weight = self.algorithm_weights[i]
                for vote in votes:
                    weighted_votes[int(vote)] += weight
            selected_arms = np.argsort(weighted_votes)[-self.n_suggestions:][::-1]
        else:  # ranked
            # Use ranked voting
            selected_arms = self._ranked_voting(all_votes)
            
        # Ensure arms are integers
        selected_arms = selected_arms.astype(np.int64)
        return selected_arms
    
    def _ranked_voting(self, all_votes):
        """Implement ranked voting system."""
        arm_scores = np.zeros(self.n_arms)
        
        for votes in all_votes:
            for rank, vote in enumerate(votes):
                # Higher rank (lower index) gets higher score
                arm_scores[int(vote)] += 1.0 / (rank + 1)
        
        return np.argsort(arm_scores)[-self.n_suggestions:][::-1]
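
# Illustrative usage sketch (not part of the library API): build a voting
# ensemble from three base bandits. Note that VotingEnsembleMAB.update() does
# not propagate rewards to its base algorithms, so they are updated explicitly
# here; the arm count and binary rewards are assumptions made only for this example.
def _example_voting_ensemble(n_rounds=50):
    """Combine three base bandits with majority voting (illustration only)."""
    n_arms, n_suggestions = 20, 3
    base_algorithms = [
        EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.1),
        UpperConfidenceBoundMAB(n_arms, n_suggestions, c=2.0),
        ThompsonSamplingMAB(n_arms, n_suggestions),
    ]
    ensemble = VotingEnsembleMAB(n_arms, n_suggestions, algorithms=base_algorithms,
                                 voting_method='majority')
    for _ in range(n_rounds):
        arms = ensemble.select_multiple_arms()
        rewards = np.random.binomial(1, 0.5, size=len(arms)).tolist()
        ensemble.update(arms, rewards)
        for algorithm in base_algorithms:
            algorithm.update(arms, rewards)
    return ensemble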

class WeightedEnsembleMAB(EnsembleMAB):
    """
    Weighted ensemble MAB algorithm.
    Combines the arm selections of all algorithms into weighted scores, with
    weights adapted to each algorithm's average reward.
    """
    
    def __init__(self, n_arms, n_suggestions=1, algorithms=None, weight_update_rate=0.01):
        super().__init__(n_arms, n_suggestions, algorithms)
        self.weight_update_rate = weight_update_rate
        
    def select_multiple_arms(self, context=None, input_arms_idx=None):
        if not self.algorithms:
            if input_arms_idx is None:
                arms = np.random.choice(self.n_arms, size=self.n_suggestions, replace=False)
            else:
                arms = np.random.choice(input_arms_idx, size=self.n_suggestions, replace=False)
            arms = arms.astype(np.int64)
            return arms
        
        # Calculate combined arm scores
        arm_scores = np.zeros(self.n_arms)
        
        for i, algorithm in enumerate(self.algorithms):
            weight = self.algorithm_weights[i]
            
            # Get the algorithm's arm selection (pass context only if the signature accepts it)
            if hasattr(algorithm, 'select_multiple_arms'):
                try:
                    selected_arms = algorithm.select_multiple_arms(context=context, input_arms_idx=input_arms_idx)
                except TypeError:
                    selected_arms = algorithm.select_multiple_arms(input_arms_idx=input_arms_idx)
            else:
                selected_arms = [algorithm.select_arm(input_arms_idx)]
            
            # Add weighted scores
            for arm in selected_arms:
                arm_scores[int(arm)] += weight
        
        # Select the top-scoring arms, restricted to the candidate set if one was given
        # (arm_scores is indexed by actual arm id, so the result needs no re-indexing)
        if input_arms_idx is None:
            selected_arms = np.argsort(arm_scores)[-self.n_suggestions:][::-1]
        else:
            input_arms_idx = np.asarray(input_arms_idx)
            top_positions = np.argsort(arm_scores[input_arms_idx])[-self.n_suggestions:][::-1]
            selected_arms = input_arms_idx[top_positions]
        # Ensure arms are integers
        return selected_arms.astype(np.int64)
    
    def update(self, arms, rewards, context=None):
        """Update ensemble and adjust algorithm weights based on performance."""
        super().update(arms, rewards)
        
        # Update all base algorithms (pass context only if the signature accepts it)
        for algorithm in self.algorithms:
            if hasattr(algorithm, 'update'):
                try:
                    algorithm.update(arms, rewards, context=context)
                except TypeError:
                    algorithm.update(arms, rewards)
        
        # Update algorithm weights based on performance
        self._update_weights(rewards)
    
    def _update_weights(self, rewards):
        """Update algorithm weights based on recent performance."""
        if not self.algorithms:
            return
            
        # Calculate performance for each algorithm
        for i, algorithm in enumerate(self.algorithms):
            if hasattr(algorithm, 'arm_rewards') and hasattr(algorithm, 'arm_counts'):
                # Use average reward as performance metric
                avg_reward = np.mean(algorithm.arm_rewards / np.maximum(algorithm.arm_counts, 1))
                self.algorithm_performances[i] = avg_reward
        
        # Update weights using softmax
        if np.sum(self.algorithm_performances) > 0:
            exp_performances = np.exp(self.algorithm_performances)
            new_weights = exp_performances / np.sum(exp_performances)
            
            # Smooth weight update
            self.algorithm_weights = (1 - self.weight_update_rate) * self.algorithm_weights + \
                                   self.weight_update_rate * new_weights
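
    # Note: the weight update above is a smoothed softmax over per-algorithm
    # average rewards,
    #     w_i <- (1 - eta) * w_i + eta * exp(p_i) / sum_j exp(p_j)
    # with eta = weight_update_rate. For example, with performances (0.2, 0.4)
    # the softmax target is roughly (0.45, 0.55), and with eta = 0.01 each call
    # moves the weights only 1% of the way toward that target.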

class DynamicEnsembleMAB(EnsembleMAB):
    """
    Dynamic ensemble MAB algorithm.
    Dynamically selects the best performing algorithm based on recent performance.
    """
    
    def __init__(self, n_arms, n_suggestions=1, algorithms=None, window_size=50, switch_threshold=0.1):
        super().__init__(n_arms, n_suggestions, algorithms)
        self.window_size = window_size
        self.switch_threshold = switch_threshold
        self.algorithm_rewards = [[] for _ in range(len(self.algorithms))] if self.algorithms else []
        self.current_best_algorithm = 0 if self.algorithms else None
        
    def select_multiple_arms(self, context=None, input_arms_idx=None):
        if not self.algorithms:
            # No base algorithms: fall back to a uniform random selection
            if input_arms_idx is None:
                arms = np.random.choice(self.n_arms, size=self.n_suggestions, replace=False)
            else:
                arms = np.random.choice(input_arms_idx, size=self.n_suggestions, replace=False)
            return arms.astype(np.int64)
        
        # Select arms using the best performing algorithm
        best_algorithm = self.algorithms[self.current_best_algorithm]
        
        if hasattr(best_algorithm, 'select_multiple_arms'):
            # Pass the context only if the algorithm's signature accepts it
            try:
                return best_algorithm.select_multiple_arms(context=context, input_arms_idx=input_arms_idx)
            except TypeError:
                return best_algorithm.select_multiple_arms(input_arms_idx=input_arms_idx)
        else:
            return [best_algorithm.select_arm(input_arms_idx)]
    
    def update(self, arms, rewards, context=None):
        """Update ensemble and check if algorithm switching is needed."""
        super().update(arms, rewards)
        
        # Update all base algorithms (pass context only if the signature accepts it)
        for i, algorithm in enumerate(self.algorithms):
            if hasattr(algorithm, 'update'):
                try:
                    algorithm.update(arms, rewards, context=context)
                except TypeError:
                    algorithm.update(arms, rewards)
            
            # Store the mean reward of this step for performance tracking
            # (a scalar per step keeps np.mean well-defined even if the number of pulled arms varies)
            self.algorithm_rewards[i].append(float(np.mean(rewards)))
            if len(self.algorithm_rewards[i]) > self.window_size:
                self.algorithm_rewards[i].pop(0)
        
        # Check if we should switch algorithms
        self._check_algorithm_switch()
    
    def _check_algorithm_switch(self):
        """Check if we should switch to a better performing algorithm."""
        if len(self.algorithms) < 2:
            return
            
        # Calculate recent performance for each algorithm
        performances = []
        for rewards in self.algorithm_rewards:
            if rewards:
                avg_performance = np.mean(rewards)
                performances.append(avg_performance)
            else:
                performances.append(0.0)
        
        # Find best performing algorithm
        best_idx = np.argmax(performances)
        best_performance = performances[best_idx]
        current_performance = performances[self.current_best_algorithm]
        
        # Switch if the best algorithm is significantly better
        if best_performance > current_performance + self.switch_threshold:
            self.current_best_algorithm = best_idx

class ExpertEnsembleMAB(EnsembleMAB):
    """
    Expert ensemble MAB algorithm.
    Uses exponential weighting to combine expert algorithms.
    """
    
    def __init__(self, n_arms, n_suggestions=1, algorithms=None, learning_rate=0.1):
        super().__init__(n_arms, n_suggestions, algorithms)
        self.learning_rate = learning_rate
        self.algorithm_losses = np.zeros(len(self.algorithms)) if self.algorithms else []
        
    def select_multiple_arms(self, context=None, input_arms_idx=None):
        if not self.algorithms:
            # No expert algorithms: fall back to a uniform random selection
            if input_arms_idx is None:
                arms = np.random.choice(self.n_arms, size=self.n_suggestions, replace=False)
            else:
                arms = np.random.choice(input_arms_idx, size=self.n_suggestions, replace=False)
            return arms.astype(np.int64)
        
        # Calculate expert weights
        weights = np.exp(-self.learning_rate * self.algorithm_losses)
        weights = weights / np.sum(weights)
        
        # Get arm predictions from all experts
        arm_scores = np.zeros(self.n_arms)
        
        for i, algorithm in enumerate(self.algorithms):
            weight = weights[i]
            
            if hasattr(algorithm, 'select_multiple_arms'):
                # Pass the context only to algorithms whose signature accepts it
                try:
                    selected_arms = algorithm.select_multiple_arms(context=context, input_arms_idx=input_arms_idx)
                except TypeError:
                    selected_arms = algorithm.select_multiple_arms(input_arms_idx=input_arms_idx)
            else:
                selected_arms = [algorithm.select_arm(input_arms_idx)]
            
            # Add weighted scores
            for arm in selected_arms:
                arm_scores[int(arm)] += weight
        
        # Select the top-scoring arms, restricted to the candidate set if one was given
        # (arm_scores is indexed by actual arm id, so the result needs no re-indexing)
        if input_arms_idx is None:
            selected_arms = np.argsort(arm_scores)[-self.n_suggestions:][::-1]
        else:
            input_arms_idx = np.asarray(input_arms_idx)
            top_positions = np.argsort(arm_scores[input_arms_idx])[-self.n_suggestions:][::-1]
            selected_arms = input_arms_idx[top_positions]
        # Ensure arms are integers
        return selected_arms.astype(np.int64)
    
    def update(self, arms, rewards, context=None):
        """Update ensemble and expert losses."""
        super().update(arms, rewards)
        
        # Update all expert algorithms (pass context only if the signature accepts it)
        for algorithm in self.algorithms:
            if hasattr(algorithm, 'update'):
                try:
                    algorithm.update(arms, rewards, context=context)
                except TypeError:
                    algorithm.update(arms, rewards)
        
        # Update expert losses
        self._update_expert_losses(arms, rewards)
    
    def _update_expert_losses(self, arms, rewards):
        """Update expert losses based on their predictions."""
        if not self.algorithms:
            return
            
        # Calculate the loss for each expert
        for i, algorithm in enumerate(self.algorithms):
            # Get the expert's prediction
            if hasattr(algorithm, 'select_multiple_arms'):
                try:
                    predicted_arms = algorithm.select_multiple_arms()
                except TypeError:
                    predicted_arms = [algorithm.select_arm()]
            else:
                predicted_arms = [algorithm.select_arm()]

            # Calculate loss: a penalty of 1 for each pulled arm the expert did not predict
            loss = 0
            for arm in np.atleast_1d(arms):
                if arm not in predicted_arms:
                    loss += 1

            self.algorithm_losses[i] += loss
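
# Minimal end-to-end sketch, runnable as a script. It assumes a hypothetical
# 10-arm Bernoulli environment (the success probabilities below are invented
# for illustration) and simply compares the average reward collected by a few
# of the algorithms defined above.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    true_probs = rng.uniform(0.1, 0.9, size=10)   # hypothetical per-arm success rates
    candidates = {
        'epsilon_greedy': EpsilonGreedyMAB(10, n_suggestions=3, epsilon=0.1),
        'ucb': UpperConfidenceBoundMAB(10, n_suggestions=3, c=2.0),
        'thompson': ThompsonSamplingMAB(10, n_suggestions=3),
    }
    for name, bandit in candidates.items():
        total_reward = 0
        for _ in range(200):
            arms = bandit.select_multiple_arms()
            rewards = [int(rng.rand() < true_probs[arm]) for arm in arms]
            bandit.update(arms, rewards)
            total_reward += sum(rewards)
        print(f"{name}: average reward per pull = {total_reward / bandit.total_pulls:.3f}")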
