import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.colors import Normalize
import torch
import random
import colorsys

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs with
    ``seed`` and force cuDNN into deterministic, non-benchmarking mode so that
    repeated runs produce the same results."""
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # cuDNN autotuning picks kernels non-deterministically; disable it.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def create_graph_from_connections(connections, num_nodes=12):
    """
    Build an undirected graph from an edge list.

    Args:
        connections: Iterable of ``(u, v)`` node-index pairs; each pair is
            recorded in both directions of the adjacency matrix.
        num_nodes: Side length of the (num_nodes x num_nodes) adjacency matrix.

    Returns:
        G: NetworkX graph created from the adjacency matrix
        adjacency_matrix: Symmetric integer NumPy array with 1 for every edge
    """
    adjacency = np.zeros((num_nodes, num_nodes), dtype=int)
    for u, v in connections:
        # Fancy indexing sets [u, v] and [v, u] in one assignment.
        adjacency[(u, v), (v, u)] = 1

    graph = nx.from_numpy_array(adjacency)
    return graph, adjacency

def get_layout_positions(layout_type='custom'):
    """
    Return node positions for drawing the graph.

    Args:
        layout_type: Layout name. Only ``'custom'`` (a hand-placed layout for
            nodes 0-11) is implemented; any other value yields ``None``.

    Returns:
        pos: New dict mapping node index -> (x, y), or ``None`` for an
            unsupported layout type.
    """
    if layout_type != 'custom':
        # Placeholder for future layout types.
        return None

    # Nodes 0-5 and nodes 6-11 use the same shape; the second group is the
    # first shifted 3 units to the left.
    coordinates = [
        (0, 0), (1, 1), (2, 0), (2, -1), (1, -2), (0, -1),
        (-1, 0), (-2, 1), (-3, 0), (-3, -1), (-2, -2), (-1, -1),
    ]
    return dict(enumerate(coordinates))

# def normalize_embeddings(embeddings):
#     """
#     Normalize embeddings to [0,1] range for RGB values.
    
#     Args:
#         embeddings: Node embeddings array
        
#     Returns:
#         rgb_normalized: Normalized embeddings for RGB colors
#         node_colors: List of RGB tuples for each node
#     """
#     rgb_normalized = np.zeros_like(embeddings)
#     embeddings[:,0] += 0.02 # Add a tinge of blue to avoid boring colors
#     for i in range(min(3, embeddings.shape[1])):
#         norm = Normalize(embeddings[:, i].min(), embeddings[:, i].max())
#         rgb_normalized[:, i] = norm(embeddings[:, i])
    
#     # Create RGB colors for each node
#     node_colors = []
#     for i in range(embeddings.shape[0]):
#         node_colors.append((rgb_normalized[i, 0], rgb_normalized[i, 1], rgb_normalized[i, 2]))
    
#     return rgb_normalized, node_colors
def _minmax_or_none(column):
    """Min-max scale a 1-D float array to [0, 1]; return None if it is constant."""
    lo = column.min()
    hi = column.max()
    if hi == lo:
        return None
    return (column - lo) / (hi - lo)


def normalize_embeddings(embeddings):
    """
    Create distinct node colors from embeddings.

    With 3+ embedding dimensions, the first three are min-max scaled,
    contrast-boosted and mapped to HSV hue / saturation / value. With fewer
    dimensions, nodes get evenly spaced hues; dimension 0 (if present)
    modulates saturation and dimension 1 (if present) modulates value.

    Args:
        embeddings: Node embeddings of shape (num_nodes, dim) or (num_nodes,).
            Anything ``np.asarray`` accepts works; values are converted to float.

    Returns:
        rgb_normalized: (num_nodes, 3) float array holding the first (up to)
            three dimensions, each min-max scaled to [0, 1] independently
            (0.5 for a constant dimension, 0 for a missing one).
        node_colors: List of RGB tuples (floats in [0, 1]), one per node.

    Raises:
        ValueError: If ``embeddings`` has more than two dimensions.
    """
    # Work in float: integer buffers would truncate the fractional
    # saturation/value terms below to 0 and render every node black.
    emb = np.asarray(embeddings, dtype=float)
    if emb.ndim == 1:
        emb = emb[:, np.newaxis]
    if emb.ndim != 2:
        raise ValueError(f"embeddings must be 1-D or 2-D, got shape {emb.shape}")
    num_nodes, num_dims = emb.shape

    rgb_normalized = np.zeros((num_nodes, 3))
    if num_nodes == 0:
        return rgb_normalized, []

    # Plain min-max scaling of the first (up to) three dimensions, computed once.
    scaled = [_minmax_or_none(emb[:, d]) for d in range(min(3, num_dims))]
    for d, col in enumerate(scaled):
        rgb_normalized[:, d] = 0.5 if col is None else col

    if num_dims >= 3:
        channels = []
        for col in scaled:
            if col is None:
                channels.append(np.full(num_nodes, 0.5))
            else:
                # Boost contrast: stretch around the midpoint by 1.2x, then clip.
                channels.append(np.clip((col - 0.5) * 1.2 + 0.5, 0.0, 1.0))
        h_in, s_in, v_in = channels
        # Hue is cyclic (h=0 and h=1 are the same red). Cap it at (n-1)/n, the
        # same spacing the evenly-spaced branch uses. This stops the two
        # extremes of dimension 0 from sharing a color.
        hue = h_in * (num_nodes - 1) / num_nodes
        sat = 0.6 + 0.4 * s_in  # saturation in [0.6, 1.0]: avoid grayish colors
        val = 0.7 + 0.3 * v_in  # value in [0.7, 1.0]: avoid dark colors
        node_colors = [colorsys.hsv_to_rgb(h, s, v) for h, s, v in zip(hue, sat, val)]
    else:
        # Fewer than 3 dims: evenly spaced hues around the color wheel.
        sat_col = scaled[0] if num_dims >= 1 else None
        val_col = scaled[1] if num_dims >= 2 else None
        node_colors = []
        for i in range(num_nodes):
            sat = 0.8 if sat_col is None else 0.7 + 0.3 * sat_col[i]
            val = 0.9 if val_col is None else 0.7 + 0.3 * val_col[i]
            node_colors.append(colorsys.hsv_to_rgb(i / num_nodes, sat, val))

    return rgb_normalized, node_colors


def plot_graph_with_embeddings(G, embeddings, pos=None, title="Graph with Node Embeddings", 
                              node_size=3000, font_size=22, save_path=None):
    """
    Plot graph with nodes colored by their embeddings.

    Args:
        G: NetworkX graph
        embeddings: Node embeddings array (passed to ``normalize_embeddings``)
        pos: Dictionary of node positions. Defaults to the custom layout from
            ``get_layout_positions``. If that layout does not cover every node
            of ``G``, a seeded ``nx.spring_layout`` is used instead.
        title: Plot title
        node_size: Size of nodes
        font_size: Size of node labels
        save_path: Path to save the figure (not saved when falsy)

    Returns:
        rgb_normalized: First return value of ``normalize_embeddings``.
    """
    if pos is None:
        pos = get_layout_positions('custom')
        # The hand-placed layout only knows nodes 0-11; any other graph would
        # fail on a missing position, so fall back to a reproducible layout.
        if pos is None or not all(node in pos for node in G.nodes):
            pos = nx.spring_layout(G, seed=0)

    rgb_normalized, node_colors = normalize_embeddings(embeddings)

    fig, ax = plt.subplots(figsize=(12, 10))

    nx.draw_networkx_edges(G, pos, width=2.5, alpha=0.8, edge_color='gray', ax=ax)
    nx.draw_networkx_nodes(G, pos,
                           node_color=node_colors,
                           node_size=node_size,
                           alpha=1.0,
                           ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=font_size, font_weight='bold', ax=ax)

    ax.set_title(title, fontsize=18)
    ax.set_axis_off()
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    return rgb_normalized

def plot_embedding_dimensions(G, embeddings, pos=None, node_size=1200, font_size=18, save_path=None):
    """
    Create one subplot per embedding dimension (first three at most), coloring
    nodes by that dimension's min-max normalized value.

    Args:
        G: NetworkX graph
        embeddings: Node embeddings, shape (num_nodes, dim) or (num_nodes,)
        pos: Dictionary of node positions. Defaults to the custom layout from
            ``get_layout_positions``. If that layout does not cover every node
            of ``G``, a seeded ``nx.spring_layout`` is used instead.
        node_size: Size of nodes
        font_size: Size of node labels
        save_path: Path to save the figure (not saved when falsy)

    Raises:
        ValueError: If ``embeddings`` has no dimensions to plot.
    """
    if pos is None:
        pos = get_layout_positions('custom')
        # The hand-placed layout only knows nodes 0-11; fall back for other graphs.
        if pos is None or not all(node in pos for node in G.nodes):
            pos = nx.spring_layout(G, seed=0)

    emb = np.asarray(embeddings)
    if emb.ndim == 1:
        emb = emb[:, np.newaxis]
    n_dims = min(3, emb.shape[1])
    if n_dims == 0:
        raise ValueError("embeddings has no dimensions to plot")

    rgb_normalized, _ = normalize_embeddings(emb)

    fig, axes = plt.subplots(1, n_dims, figsize=(15, 5), squeeze=False)

    cmap = plt.cm.viridis
    for i, ax in enumerate(axes[0]):
        nx.draw_networkx_edges(G, pos, width=1.5, alpha=0.6, edge_color='gray', ax=ax)
        # Values are normalized to [0, 1]; pin the color range so a constant
        # dimension (all 0.5) still gets a meaningful colorbar.
        nodes = nx.draw_networkx_nodes(G, pos,
                                       node_color=rgb_normalized[:, i],
                                       node_size=node_size,
                                       cmap=cmap,
                                       vmin=0.0,
                                       vmax=1.0,
                                       ax=ax)
        nx.draw_networkx_labels(G, pos, font_size=font_size, font_weight='bold', ax=ax)
        ax.set_title(f"Embedding Dimension {i+1}", fontsize=14)
        ax.set_axis_off()
        colorbar = fig.colorbar(nodes, ax=ax, shrink=0.7)
        colorbar.set_label("normalized value")

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()