import gymnasium as gym
import numpy as np
import pandas as pd
try:
    import matplotlib.pyplot as plt
    MATPLOTLIB_AVAILABLE = True
except Exception:
    plt = None
    MATPLOTLIB_AVAILABLE = False
import warnings
import time
from collections import Counter
from rl_recommender.mabalgorithms import (
	RandomMAB, 
	EpsilonGreedyMAB, 
	UpperConfidenceBoundMAB, 
	ThompsonSamplingMAB, 
	ContextualMAB,
	VotingEnsembleMAB,
	WeightedEnsembleMAB,
	DynamicEnsembleMAB,
	ExpertEnsembleMAB
)
from .logger import RecommenderLogger, performance_tracker

# ignore warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Register the custom environment
try:
    gym.register(
        id="gymnasium_env/RecommendationMAB",
        entry_point="rl_recommender.UnifiedRecommendationEnv:UnifiedRecommendationEnv",
    )
except Exception:
    # Registration can fail (e.g. the id is already registered); environment
    # creation below will surface any real problem.
    pass

# Example of driving the environment directly (kept for reference):
# env = gym.make("gymnasium_env/RecommendationMAB", data="dataset/filtered_data.csv", max_arms=10, seed=24500)
# obs, info = env.reset()
# print(obs)
# env.render()
# env.unwrapped.set_render_mode("human")
#
# done = False
# while not done:
# 	action = env.action_space.sample()
# 	obs, reward, terminated, truncated, info = env.step(action)
# 	done = terminated or truncated
#
# print("Done")

# Multi-Armed Bandit Algorithms are now imported from mabalgorithms.py

class MABSimulation:
	"""Framework for running MAB experiments."""

	def __init__(self, environment, algorithm_type="epsilon_greedy", algorithm_params=None):
		self.environment = environment
		self.results = {}
		self.logger = RecommenderLogger(name="mab_simulation")
		self.current_algorithm = None
		self.algorithm_type = algorithm_type
		self.algorithm_params = algorithm_params or {}
		
		# Initialize default algorithm
		self._initialize_algorithm()
		
		# Log simulation framework initialization
		self.logger.info("MAB Simulation framework initialized", extra={
			'environment_type': type(environment).__name__,
			'environment_config': {
				'max_arms': getattr(environment, 'max_arms', 'N/A'),
				'n_users': getattr(environment, 'n_users', 'N/A'),
				'n_suggestions': getattr(environment, 'n_suggestions', 'N/A')
			},
			'algorithm_type': algorithm_type,
			'algorithm_params': algorithm_params
		})
	
	def _initialize_algorithm(self):
		"""Initialize the default algorithm based on type."""
		try:
			# Get environment parameters with fallbacks
			n_arms = getattr(self.environment, 'max_arms', 50)
			n_suggestions = getattr(self.environment, 'n_suggestions', 3)
			
			# Ensure we have valid values
			if not isinstance(n_arms, int) or n_arms <= 0:
				n_arms = 50
			if not isinstance(n_suggestions, int) or n_suggestions <= 0:
				n_suggestions = 3
			
			self.logger.info(f"Initializing algorithm with n_arms={n_arms}, n_suggestions={n_suggestions}")
			
			if self.algorithm_type == "random":
				self.current_algorithm = RandomMAB(n_arms, n_suggestions)
			elif self.algorithm_type == "epsilon_greedy":
				epsilon = self.algorithm_params.get('epsilon', 0.1)
				self.current_algorithm = EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=epsilon)
			elif self.algorithm_type == "ucb":
				c = self.algorithm_params.get('c', 2.0)
				self.current_algorithm = UpperConfidenceBoundMAB(n_arms, n_suggestions, c=c)
			elif self.algorithm_type == "thompson":
				self.current_algorithm = ThompsonSamplingMAB(n_arms, n_suggestions)
			elif self.algorithm_type == "contextual":
				epsilon = self.algorithm_params.get('epsilon', 0.2)
				self.current_algorithm = ContextualMAB(n_arms, n_suggestions, epsilon=epsilon)
			else:
				# Default to epsilon greedy
				self.current_algorithm = EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.1)
			
			self.logger.info(f"Initialized {self.algorithm_type} algorithm", extra={
				'algorithm_class': self.current_algorithm.__class__.__name__,
				'n_arms': n_arms,
				'n_suggestions': n_suggestions
			})
			
		except Exception as e:
			self.logger.error(f"Error initializing algorithm: {str(e)}")
			# Fallback to random algorithm with safe defaults
			try:
				self.current_algorithm = RandomMAB(50, 3)
				self.logger.info("Fallback to RandomMAB algorithm successful")
			except Exception as fallback_error:
				self.logger.error(f"Fallback algorithm also failed: {str(fallback_error)}")
				# Last resort: minimal RandomMAB (imported at module level)
				self.current_algorithm = RandomMAB(10, 1)
	
	def set_algorithm(self, algorithm_type, algorithm_params=None):
		"""Set a new algorithm for the simulation."""
		self.algorithm_type = algorithm_type
		self.algorithm_params = algorithm_params or {}
		self._initialize_algorithm()
		
		return self.current_algorithm
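
	# Illustrative: swap strategies between runs without rebuilding the simulation,
	# e.g. sim.set_algorithm("thompson") or
	# sim.set_algorithm("epsilon_greedy", {"epsilon": 0.3}).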

	def run_experiment(self, algorithms, n_users=10, n_steps=20, random_runs=1, verbose=True):
		"""
		Run MAB experiment comparing different algorithms.

		Args:
			algorithms (dict): Dictionary of algorithm_name -> algorithm_instance
			n_users (int): Number of users to simulate
			n_steps (int): Number of recommendations per user
			random_runs (int): Number of repeated runs used to average the RandomMAB baseline
			verbose (bool): Print progress

		Returns:
			dict: Per-algorithm results keyed by algorithm name
		"""
		start_time = time.time()
		
		# Log experiment start
		self.logger.log_simulation_start(
			algorithms=list(algorithms.keys()),
			n_users=n_users,
			n_steps=n_steps,
			random_runs=random_runs,
			verbose=verbose
		)
		
		results = {}

		for alg_name, algorithm in algorithms.items():
			if verbose:
				self.logger.info(f"Running {alg_name} algorithm...")

			total_rewards = []
			cumulative_rewards = []
			arm_selections = []
			user_rewards = []
			event_types_all = []  # Collect all event types

			cumulative_reward = 0

			if isinstance(algorithm, RandomMAB):
				# For RandomMAB, run multiple times and aggregate results
				total_rewards_all_runs = []
				cumulative_rewards_all_runs = []
				arm_selections_all_runs = []
				user_rewards_all_runs = []
				event_types_all_runs = []

				for run in range(random_runs):
					algorithm.reset()
					total_rewards_run = []
					cumulative_rewards_run = []
					arm_selections_run = []
					user_rewards_run = []
					event_types_run = []

					cumulative_reward_run = 0

					for user_idx in range(n_users):
						obs, _ = self.environment.reset()
						user_id = obs.get("user_id")
						user_total_reward = 0

						for step in range(n_steps):
							input_arms_idx = np.random.randint(0, self.environment.n_products, size=10)
							
							arms = algorithm.select_multiple_arms(input_arms_idx)

							# Convert numpy array to list to avoid indexing issues
							arms = arms.tolist() if hasattr(arms, 'tolist') else list(arms)

							obs, reward, _, _, info = self.environment.step(arms)
							event_types = info.get("event_types", [])

							algorithm.update(arms, [reward] * len(arms))

							total_rewards_run.append(reward)
							cumulative_reward_run += reward
							cumulative_rewards_run.append(cumulative_reward_run)
							arm_selections_run.append(arms)
							event_types_run.extend(event_types)
							user_total_reward += reward
					
						user_rewards_run.append(user_total_reward)

					total_rewards_all_runs.append(total_rewards_run)
					cumulative_rewards_all_runs.append(cumulative_rewards_run)
					arm_selections_all_runs.append(arm_selections_run)
					user_rewards_all_runs.append(user_rewards_run)
					event_types_all_runs.append(event_types_run)

				# Aggregate results across runs - always aggregate, even for single run
				# Average rewards across runs
				total_rewards = np.mean(total_rewards_all_runs, axis=0).tolist()
				cumulative_rewards = np.mean(cumulative_rewards_all_runs, axis=0).tolist()
				user_rewards = np.mean(user_rewards_all_runs, axis=0).tolist()
				
				# For arm selections, take the most common selection at each step
				arm_selections = []
				for step in range(len(arm_selections_all_runs[0])):
					step_selections = [run_selections[step] for run_selections in arm_selections_all_runs]
					# Flatten the selections for this step across all runs
					flat_selections = [item for sublist in step_selections for item in sublist]
					# Take the most common selection (mode)
					most_common = Counter(flat_selections).most_common(1)[0][0]
					arm_selections.append(most_common)
				
				# Aggregate event types across runs by sampling from the per-step
				# distribution observed over the runs (runs may differ in length,
				# so only aggregate up to the shortest run)
				event_types_all = []
				min_len = min(len(run_events) for run_events in event_types_all_runs) if event_types_all_runs else 0
				for step in range(min_len):
					step_events = [run_events[step] for run_events in event_types_all_runs]
					event_types_all.append(np.random.choice(step_events))
				
				# Calculate cumulative reward as average of final rewards across runs
				cumulative_reward = np.mean([cumulative_rewards_all_runs[run][-1] for run in range(random_runs) if cumulative_rewards_all_runs[run]])

			else:
				algorithm.reset()
				for user_idx in range(n_users):
					obs, _ = self.environment.reset()
					user_id = obs.get("user_id")
					user_total_reward = 0

					for step in range(n_steps):
						# Select candidate arms, take one environment step, and update
						# the algorithm once with the observed reward
						input_arms_idx = np.random.choice(self.environment.n_products, size=algorithm.n_arms, replace=False)

						if isinstance(algorithm, ContextualMAB):
							arms = algorithm.select_multiple_arms(obs, input_arms_idx)
						else:
							arms = algorithm.select_multiple_arms(input_arms_idx)

						# Convert numpy array to list to avoid indexing issues
						arms = arms.tolist() if hasattr(arms, 'tolist') else list(arms)

						obs, reward, _, _, info = self.environment.step(arms)
						event_types = info.get("event_types", [])

						if isinstance(algorithm, ContextualMAB):
							algorithm.update(arms, [reward] * len(arms), context=obs)
						else:
							algorithm.update(arms, [reward] * len(arms))

						total_rewards.append(reward)
						cumulative_reward += reward
						cumulative_rewards.append(cumulative_reward)
						arm_selections.append(arms)
						event_types_all.extend(event_types)
						user_total_reward += reward

					user_rewards.append(user_total_reward)

					if verbose and (user_idx + 1) % max(1, n_users // 5) == 0:
						self.logger.info(f"Completed {user_idx + 1}/{n_users} users, avg reward: {np.mean(user_rewards):.4f}")

			# Store results with event types
			results[alg_name] = {
				"algorithm": algorithm,
				"total_rewards": total_rewards,
				"cumulative_rewards": cumulative_rewards,
				"arm_selections": arm_selections,
				"user_rewards": user_rewards,
				"event_types": event_types_all,  # Include event types
				"final_total_reward": cumulative_reward,
				"average_reward": cumulative_reward / (n_steps * n_users),
				"arm_counts": algorithm.arm_counts.copy(),
				"arm_avg_rewards": np.divide(algorithm.arm_rewards, algorithm.arm_counts, out=np.zeros_like(algorithm.arm_rewards), where=algorithm.arm_counts != 0)
			}
		
		self.results = results
		
		# Log experiment completion
		duration = time.time() - start_time
		self.logger.log_simulation_complete(
			results=results,
			duration=duration,
			n_algorithms=len(algorithms),
			n_users=n_users,
			n_steps=n_steps
		)
		
		return results
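
	# Illustrative call (constructor signatures as used elsewhere in this module):
	#   algorithms = {
	#       "Epsilon-Greedy": EpsilonGreedyMAB(20, 3, epsilon=0.1),
	#       "UCB": UpperConfidenceBoundMAB(20, 3, c=2.0),
	#   }
	#   results = sim.run_experiment(algorithms, n_users=5, n_steps=100)
	#   best = max(results, key=lambda name: results[name]["average_reward"])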

	def invoke_agent(self, algorithm, event_df=pd.DataFrame(), input_arms_idx=None, obs=None, user_id=None, n_recommendations=None):
		"""
		Invoke the algorithm to get recommendations for a user.
		
		Args:
			algorithm: MAB algorithm instance
			event_df (pd.DataFrame): DataFrame containing recent user feedback events
			input_arms_idx: Available arms to choose from (if None, uses all arms)
			obs: Contextual information for contextual algorithms
			user_id: User identifier for logging
			n_recommendations: Number of recommendations to return (defaults to algorithm.n_suggestions)
		
		Returns:
			list: List of recommended item indices
		"""
		try:
			# Store current algorithm for updates
			self.current_algorithm = algorithm
			
			# Update algorithm with recent events
			if not event_df.empty:
				self.update_algorithm(event_df)
			
			# Determine number of recommendations
			if n_recommendations is None:
				n_recommendations = algorithm.n_suggestions
			
			# Generate input arms if not provided
			if input_arms_idx is None:
				# Use all available arms
				input_arms_idx = np.arange(algorithm.n_arms)
			
			# Ensure we don't request more recommendations than available arms
			n_recommendations = min(n_recommendations, len(input_arms_idx))
			
			# Get recommendations from algorithm
			if isinstance(algorithm, ContextualMAB):
				recommendations = algorithm.select_multiple_arms(obs, input_arms_idx)
			else:
				recommendations = algorithm.select_multiple_arms(input_arms_idx)
			
			# Trim to the requested count, or pad with random unseen arms if the
			# algorithm returned fewer than requested
			if len(recommendations) > n_recommendations:
				recommendations = recommendations[:n_recommendations]
			elif len(recommendations) < n_recommendations:
				remaining_arms = [arm for arm in input_arms_idx if arm not in recommendations]
				if remaining_arms:
					n_extra = min(n_recommendations - len(recommendations), len(remaining_arms))
					additional = np.random.choice(remaining_arms, size=n_extra, replace=False)
					recommendations = np.concatenate([recommendations, additional])
			
			# Convert to plain ints, de-duplicating while preserving ranking order
			recommendations = list(dict.fromkeys(int(arm) for arm in recommendations))[:n_recommendations]
			
			# Log the recommendation
			self.logger.info(f"Generated recommendations for user {user_id}: {recommendations}", extra={
				'user_id': user_id,
				'recommendations': recommendations,
				'algorithm': algorithm.__class__.__name__,
				'n_recommendations': len(recommendations),
				'input_arms_count': len(input_arms_idx)
			})
			
			return recommendations
			
		except Exception as e:
			self.logger.error(f"Error in invoke_agent: {str(e)}", extra={
				'user_id': user_id,
				'algorithm': algorithm.__class__.__name__ if algorithm else 'None',
				'error': str(e)
			})
			# Return random recommendations as fallback
			if input_arms_idx is not None and len(input_arms_idx) > 0:
				fallback = np.random.choice(input_arms_idx, size=min(n_recommendations or 3, len(input_arms_idx)), replace=False)
				return fallback.tolist()
			return []
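
	# Illustrative online-serving call, assuming `recent_events` is a DataFrame of
	# the user's latest feedback (it may be empty):
	#   recs = sim.invoke_agent(sim.current_algorithm, event_df=recent_events,
	#                           user_id=123, n_recommendations=5)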
	
	def update_algorithm(self, event_df):
		"""
		Update the algorithm based on event data.
		
		Args:
			event_df (pd.DataFrame): DataFrame containing user feedback events
				Expected columns: user_id, item_id, event_type, timestamp
		"""
		if event_df.empty:
			self.logger.info("No events to process for algorithm update")
			return
		
		# Process each event and update the algorithm
		for _, event in event_df.iterrows():
			try:
				user_id = event.get('user_id')
				item_id = event.get('item_id')
				event_type = event.get('event_type')
				timestamp = event.get('timestamp', pd.Timestamp.now())
				
				# Map event types to rewards based on environment configuration
				reward_mapping = {
					'view': 0.1,
					'cart': 0.33,
					'purchase': 1.0,
					'remove_from_cart': -0.5
				}
				
				reward = reward_mapping.get(event_type, 0.0)
				
				# Update the algorithm with the new reward
				if hasattr(self, 'current_algorithm') and self.current_algorithm is not None:
					# Convert item_id to an integer arm index
					arm_index = int(item_id)
					
					# Ensure arm_index is within bounds
					if 0 <= arm_index < self.current_algorithm.n_arms:
						self.current_algorithm.update([arm_index], [reward])
						
						self.logger.info(f"Updated algorithm with event: user={user_id}, item={item_id}, event={event_type}, reward={reward}", extra={
							'user_id': user_id,
							'item_id': item_id,
							'event_type': event_type,
							'reward': reward,
							'arm_index': arm_index,
							'timestamp': timestamp.isoformat() if hasattr(timestamp, 'isoformat') else str(timestamp)
						})
					else:
						self.logger.warning(f"Arm index {arm_index} out of bounds for algorithm with {self.current_algorithm.n_arms} arms")
				else:
					self.logger.warning("No current algorithm set for updates")
					
			except Exception as e:
				self.logger.error(f"Error updating algorithm with event {event}: {str(e)}", extra={
					'event': event.to_dict() if hasattr(event, 'to_dict') else str(event),
					'error': str(e)
				})
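
	# Illustrative event_df matching the schema documented above:
	#   event_df = pd.DataFrame([
	#       {"user_id": 1, "item_id": 7, "event_type": "view", "timestamp": pd.Timestamp.now()},
	#       {"user_id": 1, "item_id": 7, "event_type": "purchase", "timestamp": pd.Timestamp.now()},
	#   ])
	#   sim.update_algorithm(event_df)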

	def plot_results(self, figsize=(15, 10), save=False, filename="mab_experiment_results.png"):
		"""
		Plot comparison of MAB algorithms.
		"""
		if not MATPLOTLIB_AVAILABLE:
			self.logger.warning("Matplotlib not available; skipping plot_results")
			return

		if not self.results:
			self.logger.warning("No results to plot. Run an experiment first.")
			return

		fig, axes = plt.subplots(2, 2, figsize=figsize)

		#1. Cumulative Rewards
		ax = axes[0, 0]
		for alg_name, result in self.results.items():
			ax.plot(result["cumulative_rewards"], label=alg_name, alpha=0.8)
		ax.set_xlabel("Time Steps")
		ax.set_ylabel("Cumulative Reward")
		ax.set_title("Cumulative Rewards Over Time")
		ax.legend()
		ax.grid(True, alpha=0.3)

		#2. Average Reward per User
		ax = axes[0, 1]
		alg_names = list(self.results.keys())
		avg_rewards = [self.results[name]["average_reward"] for name in alg_names]
		bars = ax.bar(alg_names, avg_rewards, alpha=0.7)
		ax.set_xlabel("Algorithms")
		ax.set_ylabel("Average Reward")
		ax.tick_params(axis='x', rotation=45)
		ax.set_title("Average Reward per Algorithm")

		# Add value labels on bars
		for bar, value in zip(bars, avg_rewards):
			ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.001, f"{value:.4f}", ha="center", va="bottom")

		#3. Arm Selection Distribution
		ax = axes[1, 0]
		n_arms = len(self.results[alg_names[0]]["arm_counts"])
		x = np.arange(n_arms)
		width = 0.8 / len(alg_names)

		for i, (alg_name, result) in enumerate(self.results.items()):
			ax.bar(x + i * width, result["arm_counts"], width=width, alpha=0.7, label=alg_name)

		ax.set_xlabel("Arm Index")
		ax.set_ylabel("Selection Count")
		ax.set_title("Arm Selection Frequency")
		ax.set_xticks(x + width * (len(alg_names) - 1) / 2)
		ax.set_xticklabels([f"{i}" for i in range(n_arms)])
		ax.legend()

		#4. Reward Distribution
		ax = axes[1, 1]
		user_rewards_data = [result["user_rewards"] for result in self.results.values()]
		ax.boxplot(user_rewards_data, tick_labels=alg_names)
		ax.set_title('Reward Distribution per User')
		ax.set_ylabel('Total Reward per User')
		ax.tick_params(axis='x', rotation=45)
		ax.grid(True, alpha=0.3)

		plt.tight_layout()

		if save:
			plt.savefig(filename)
		else:
			plt.show()

	def print_summary(self):
		"""
		Print summary statistics.
		"""

		if not self.results:
			self.logger.warning("No results to summarize. Run an experiment first.")
			return
		
		self.logger.info("\n" + "="*60)
		self.logger.info("MAB SIMULATION RESULTS SUMMARY")
		self.logger.info("=" * 60)
		
		# Sort by average reward
		sorted_results = sorted(self.results.items(), key=lambda x: x[1]['average_reward'], reverse=True)
		
		for rank, (alg_name, result) in enumerate(sorted_results, 1):
			self.logger.info(f"\n{rank}. {alg_name}")
			self.logger.info(f"   Average Reward: {result['average_reward']:.4f}")
			self.logger.info(f"   Total Reward: {result['final_total_reward']:.2f}")
			self.logger.info(f"   Std Dev (per user): {np.std(result['user_rewards']):.4f}")
			
			# Most selected arms
			top_arms = np.argsort(result['arm_counts'])[-3:][::-1]
			self.logger.info(f"   Top 3 Arms: {top_arms} (counts: {result['arm_counts'][top_arms]})")
		
		self.logger.info("\n" + "=" * 60)


def create_all_algorithms(n_arms, n_suggestions):
	"""Create all available algorithms for testing."""
	
	def make_base_algorithms():
		"""Fresh base learners for each ensemble so ensembles do not share state."""
		return [
			RandomMAB(n_arms, n_suggestions),
			EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.1),
			EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.3),
			UpperConfidenceBoundMAB(n_arms, n_suggestions, c=2.0),
			ThompsonSamplingMAB(n_arms, n_suggestions)
		]

	algorithms = {
		# Individual algorithms
		"Random": RandomMAB(n_arms, n_suggestions),
		"UCB (c=2.0)": UpperConfidenceBoundMAB(n_arms, n_suggestions, c=2.0),
		"Thompson Sampling": ThompsonSamplingMAB(n_arms, n_suggestions),
		"Contextual MAB": ContextualMAB(n_arms, n_suggestions, epsilon=0.2),
		"Epsilon-Greedy (ε=0.1)": EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.1),
		"Epsilon-Greedy (ε=0.3)": EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.3),
		"Epsilon-Greedy (ε=0.5)": EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.5),
		"Epsilon-Greedy (ε=0.7)": EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.7),
		"Epsilon-Greedy (ε=0.9)": EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.9),
		# Ensemble algorithms (each ensemble gets its own set of base learners)
		"Voting Ensemble (Majority)": VotingEnsembleMAB(n_arms, n_suggestions, algorithms=make_base_algorithms(), voting_method='majority'),
		"Voting Ensemble (Weighted)": VotingEnsembleMAB(n_arms, n_suggestions, algorithms=make_base_algorithms(), voting_method='weighted'),
		"Voting Ensemble (Ranked)": VotingEnsembleMAB(n_arms, n_suggestions, algorithms=make_base_algorithms(), voting_method='ranked'),
		"Weighted Ensemble": WeightedEnsembleMAB(n_arms, n_suggestions, algorithms=make_base_algorithms(), weight_update_rate=0.01),
		"Dynamic Ensemble": DynamicEnsembleMAB(n_arms, n_suggestions, algorithms=make_base_algorithms(), window_size=20, switch_threshold=0.05),
		"Expert Ensemble": ExpertEnsembleMAB(n_arms, n_suggestions, algorithms=make_base_algorithms(), learning_rate=0.1)
	}

	return algorithms
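
# Illustrative sketch of the online "update & invoke" cycle built from the classes
# above. The dataset path mirrors the test helpers below and is an assumption about
# the local setup; nothing here is executed on import.
def example_update_and_invoke(n_arms=10, n_suggestions=3):
	"""Run one invoke -> feedback -> update -> invoke round trip (illustrative)."""
	# Build the environment and the simulation wrapper with an epsilon-greedy policy
	mab_env = gym.make("gymnasium_env/RecommendationMAB",
					  data="dataset/filtered_data.csv",
					  max_arms=n_arms,
					  n_suggestions=n_suggestions,
					  seed=42)
	simulation = MABSimulation(mab_env, algorithm_type="epsilon_greedy", algorithm_params={"epsilon": 0.1})

	# 1. Invoke: ask the current algorithm for recommendations for a user
	recommendations = simulation.invoke_agent(simulation.current_algorithm, user_id=1, n_recommendations=n_suggestions)

	# 2. Update: feed observed feedback back into the algorithm as events
	if recommendations:
		feedback = pd.DataFrame([
			{"user_id": 1, "item_id": recommendations[0], "event_type": "view", "timestamp": pd.Timestamp.now()}
		])
		simulation.update_algorithm(feedback)

	# 3. Invoke again; the algorithm now reflects the new feedback
	return simulation.invoke_agent(simulation.current_algorithm, user_id=1, n_recommendations=n_suggestions)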

def run_comprehensive_test():
	"""Run comprehensive test of all algorithms."""
	
	# Configuration
	n_arms = 20
	n_suggestions = 3
	n_users = 5
	n_steps = 100
	
	# Create environment
	mab_env = gym.make("gymnasium_env/RecommendationMAB", 
					  data="dataset/filtered_data.csv", 
					  max_arms=n_arms, 
					  n_suggestions=n_suggestions, 
					  seed=42)
	
	# Create all algorithms
	algorithms = create_all_algorithms(n_arms, n_suggestions)
	
	# Run simulation
	simulation = MABSimulation(mab_env)
	results = simulation.run_experiment(algorithms, n_users=n_users, n_steps=n_steps, verbose=True)
	
	# Print summary
	simulation.print_summary()
	
	# Plot results
	simulation.plot_results(figsize=(16, 12), save=True, filename="comprehensive_mab_results.png")

def run_visualization_test():
	"""Run test with visualization enabled."""
	
	# Configuration
	n_arms = 15
	n_suggestions = 4
	n_users = 1
	n_steps = 20
	
	# Create environment with visualization
	mab_env = gym.make("gymnasium_env/RecommendationMAB", 
					  data="dataset/filtered_data.csv", 
					  max_arms=n_arms, 
					  n_suggestions=n_suggestions, 
					  seed=42)
	
	# Enable visualization
	mab_env.unwrapped.set_render_mode("human")
	
	# Create algorithms (subset for visualization)
	algorithms = {
		"Random": RandomMAB(n_arms, n_suggestions),
		"Epsilon-Greedy (ε=0.2)": EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.2),
		"UCB": UpperConfidenceBoundMAB(n_arms, n_suggestions, c=2.0),
		"Thompson Sampling": ThompsonSamplingMAB(n_arms, n_suggestions),
		"Voting Ensemble": VotingEnsembleMAB(n_arms, n_suggestions, 
										   algorithms=[RandomMAB(n_arms, n_suggestions),
													 EpsilonGreedyMAB(n_arms, n_suggestions, epsilon=0.1),
													 UpperConfidenceBoundMAB(n_arms, n_suggestions, c=2.0)],
										   voting_method='majority')
	}
	
	# Run simulation
	simulation = MABSimulation(mab_env)
	results = simulation.run_experiment(algorithms, n_users=n_users, n_steps=n_steps, verbose=True)
	
	# Print summary
	simulation.print_summary()

def run_quick_test():
	"""Run quick test of all algorithms."""
	
	# Configuration
	n_arms = 10
	n_suggestions = 2
	n_users = 3
	n_steps = 50
	
	# Create environment
	mab_env = gym.make("gymnasium_env/RecommendationMAB", 
					  data="dataset/filtered_data.csv", 
					  max_arms=n_arms, 
					  n_suggestions=n_suggestions, 
					  seed=42)
	
	# Create all algorithms
	algorithms = create_all_algorithms(n_arms, n_suggestions)
	
	# Run simulation
	simulation = MABSimulation(mab_env)
	results = simulation.run_experiment(algorithms, n_users=n_users, n_steps=n_steps, verbose=False)
	
	# Print summary
	simulation.print_summary()
	
	# Plot results
	simulation.plot_results(figsize=(16, 12), save=True, filename="quick_mab_results.png")

if __name__ == "__main__":
	# Run comprehensive test by default
	run_comprehensive_test()
