Empathic-Insight-Voice-Small

18
license:cc-by-4.0
by
laion
Other
OTHER
New
0 downloads
Early-stage
Edge AI:
Mobile
Laptop
Server
Unknown
Mobile
Laptop
Server
Quick Summary

Empathic-Insight-Voice-Small [Open in Colab](https://colab.research.google.com/) <!-- TODO: the Colab notebook URL was truncated during extraction; restore the full link -->

Code Examples

The following Python example should match Cell 2 of the Colab (requires `torch`, `transformers`, `librosa`, and `huggingface_hub`):
import gc  # For memory management
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Whisper front-end expects 16 kHz mono audio
MAX_AUDIO_SECONDS = 30.0  # longer clips are truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry and MLP head hyperparameters; must match how the
# checkpoints were trained (see get_whisper_embedding / FullEmbeddingMLP).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys.
# Checkpoints are named model_<part>_best.pth; <part> is looked up here.
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions softmaxed in process_audio_file();
# non-emotion dimensions (Age, Gender, recording quality, ...) are excluded.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it to
    `projection_dim`, then scores it with a small MLP ending in one output
    unit. `mlp_dropout_rates` must hold one rate for the projection output
    plus one per hidden layer (len(mlp_hidden_dims) + 1 entries).
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, width in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(rate)]
            in_features = width
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set by initialize_models(); WhisperForConditionalGeneration
whisper_processor_global = None  # set by initialize_models(); WhisperProcessor
all_mlp_model_paths_dict = {} # To be populated: {dimension key -> Path to .pth checkpoint}
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper backbone (once) and fetch/map the per-dimension MLP
    checkpoints from the Hub.

    Safe to call repeatedly: work already done is skipped. Raises RuntimeError
    when no checkpoint filename could be mapped to a dimension key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper backbone: load only on the first call.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = (
            WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
            .to(WHISPER_DEVICE)
            .eval()
        )
        print("Whisper model loaded.")

    # MLP heads: download the checkpoints once and remember their paths only;
    # the models themselves are instantiated on demand in load_single_mlp().
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Checkpoints are named model_<part>_best.pth; map each <part> to its
        # human-readable dimension key.
        for ckpt in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                raw_part = ckpt.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {ckpt.name}")
                continue
            mapped = FILENAME_PART_TO_TARGET_KEY_MAP.get(raw_part)
            if mapped is not None:
                all_mlp_model_paths_dict[mapped] = ckpt
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Embed a 16 kHz mono waveform with the Whisper encoder.

    Returns a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) hidden-state tensor,
    zero-padded or truncated along the time axis to a fixed length.
    Raises RuntimeError when initialize_models() has not been called.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: right-pad with zeros so every clip yields a fixed shape.
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at `model_path`.

    Returns the model on MLP_DEVICE in eval mode. `target_key` is used only
    for progress logging.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint is a plain state dict, so refuse to
    # unpickle arbitrary objects (torch.load is otherwise a code-execution
    # risk on untrusted files). Requires torch >= 1.13.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left behind if the model was
    # torch.compile'd during training.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score one Whisper embedding with one MLP head; returns a Python float."""
    # Match the head's device and dtype before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    output = mlp_model(embedding.to(MLP_DEVICE).to(target_dtype))
    return output.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with
    Whisper, then loads each MLP head in turn, scores the embedding and frees
    the head again so only one head is resident at a time. Prints a top-3
    softmax report over the primary emotion subset.

    Returns {dimension key: raw score}; an empty dict if the audio could not
    be read.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, _sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Hoisted out of the loop: the original re-scanned .values() (an O(n)
    # linear search) on every iteration.
    mapped_keys = frozenset(FILENAME_PART_TO_TARGET_KEY_MAP.values())

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in mapped_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to keep memory flat
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary emotions that actually got scores.
    # (Computed once; the original built this filtered key list twice.)
    scored_emotions = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_emotions:
        raw = torch.tensor([all_scores[k] for k in scored_emotions], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = {k: p.item() for k, p in zip(scored_emotions, softmax_probs)}
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
The following Python example should match Cell 2 of the Colab (requires `torch`, `transformers`, `librosa`, and `huggingface_hub`):
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Whisper front-end expects 16 kHz mono audio
MAX_AUDIO_SECONDS = 30.0  # longer clips are truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry and MLP head hyperparameters; must match how the
# checkpoints were trained (see get_whisper_embedding / FullEmbeddingMLP).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys.
# Checkpoints are named model_<part>_best.pth; <part> is looked up here.
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions softmaxed in process_audio_file();
# non-emotion dimensions (Age, Gender, recording quality, ...) are excluded.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it to
    `projection_dim`, then scores it with a small MLP ending in one output
    unit. `mlp_dropout_rates` must hold one rate for the projection output
    plus one per hidden layer (len(mlp_hidden_dims) + 1 entries).
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, width in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(rate)]
            in_features = width
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set by initialize_models(); WhisperForConditionalGeneration
whisper_processor_global = None  # set by initialize_models(); WhisperProcessor
all_mlp_model_paths_dict = {} # To be populated: {dimension key -> Path to .pth checkpoint}
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper backbone (once) and fetch/map the per-dimension MLP
    checkpoints from the Hub.

    Safe to call repeatedly: work already done is skipped. Raises RuntimeError
    when no checkpoint filename could be mapped to a dimension key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper backbone: load only on the first call.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = (
            WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
            .to(WHISPER_DEVICE)
            .eval()
        )
        print("Whisper model loaded.")

    # MLP heads: download the checkpoints once and remember their paths only;
    # the models themselves are instantiated on demand in load_single_mlp().
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Checkpoints are named model_<part>_best.pth; map each <part> to its
        # human-readable dimension key.
        for ckpt in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                raw_part = ckpt.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {ckpt.name}")
                continue
            mapped = FILENAME_PART_TO_TARGET_KEY_MAP.get(raw_part)
            if mapped is not None:
                all_mlp_model_paths_dict[mapped] = ckpt
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Embed a 16 kHz mono waveform with the Whisper encoder.

    Returns a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) hidden-state tensor,
    zero-padded or truncated along the time axis to a fixed length.
    Raises RuntimeError when initialize_models() has not been called.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: right-pad with zeros so every clip yields a fixed shape.
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at `model_path`.

    Returns the model on MLP_DEVICE in eval mode. `target_key` is used only
    for progress logging.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint is a plain state dict, so refuse to
    # unpickle arbitrary objects (torch.load is otherwise a code-execution
    # risk on untrusted files). Requires torch >= 1.13.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left behind if the model was
    # torch.compile'd during training.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score one Whisper embedding with one MLP head; returns a Python float."""
    # Match the head's device and dtype before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    output = mlp_model(embedding.to(MLP_DEVICE).to(target_dtype))
    return output.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with
    Whisper, then loads each MLP head in turn, scores the embedding and frees
    the head again so only one head is resident at a time. Prints a top-3
    softmax report over the primary emotion subset.

    Returns {dimension key: raw score}; an empty dict if the audio could not
    be read.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, _sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Hoisted out of the loop: the original re-scanned .values() (an O(n)
    # linear search) on every iteration.
    mapped_keys = frozenset(FILENAME_PART_TO_TARGET_KEY_MAP.values())

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in mapped_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to keep memory flat
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary emotions that actually got scores.
    # (Computed once; the original built this filtered key list twice.)
    scored_emotions = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_emotions:
        raw = torch.tensor([all_scores[k] for k in scored_emotions], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = {k: p.item() for k, p in zip(scored_emotions, softmax_probs)}
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
The following Python example should match Cell 2 of the Colab (requires `torch`, `transformers`, `librosa`, and `huggingface_hub`):
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Whisper front-end expects 16 kHz mono audio
MAX_AUDIO_SECONDS = 30.0  # longer clips are truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry and MLP head hyperparameters; must match how the
# checkpoints were trained (see get_whisper_embedding / FullEmbeddingMLP).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys.
# Checkpoints are named model_<part>_best.pth; <part> is looked up here.
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions softmaxed in process_audio_file();
# non-emotion dimensions (Age, Gender, recording quality, ...) are excluded.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it to
    `projection_dim`, then scores it with a small MLP ending in one output
    unit. `mlp_dropout_rates` must hold one rate for the projection output
    plus one per hidden layer (len(mlp_hidden_dims) + 1 entries).
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, width in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(rate)]
            in_features = width
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set by initialize_models(); WhisperForConditionalGeneration
whisper_processor_global = None  # set by initialize_models(); WhisperProcessor
all_mlp_model_paths_dict = {} # To be populated: {dimension key -> Path to .pth checkpoint}
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper backbone (once) and fetch/map the per-dimension MLP
    checkpoints from the Hub.

    Safe to call repeatedly: work already done is skipped. Raises RuntimeError
    when no checkpoint filename could be mapped to a dimension key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper backbone: load only on the first call.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = (
            WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
            .to(WHISPER_DEVICE)
            .eval()
        )
        print("Whisper model loaded.")

    # MLP heads: download the checkpoints once and remember their paths only;
    # the models themselves are instantiated on demand in load_single_mlp().
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Checkpoints are named model_<part>_best.pth; map each <part> to its
        # human-readable dimension key.
        for ckpt in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                raw_part = ckpt.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {ckpt.name}")
                continue
            mapped = FILENAME_PART_TO_TARGET_KEY_MAP.get(raw_part)
            if mapped is not None:
                all_mlp_model_paths_dict[mapped] = ckpt
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE.  Shorter encoder outputs are zero-padded and longer
        ones truncated so every clip yields the exact shape the MLPs expect.

    Raises:
        RuntimeError: if initialize_models() has not been called first.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Encoder only: the decoder is not needed to obtain hidden states.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence length to WHISPER_SEQ_LEN.  Generalized over the
    # original, which hard-coded a batch size of 1 in the padding tensor.
    batch_size, current_seq_len = embedding.shape[0], embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((batch_size, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=embedding.device, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint onto MLP_DEVICE.

    Args:
        model_path: path to a "model_<Part>_best.pth" state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The MLP in eval() mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: checkpoints are plain tensor state dicts, so refuse to
    # unpickle arbitrary objects (safer, and avoids torch's FutureWarning).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left by torch.compile during training.
    # Prefix-strip only (the original's replace() could also hit mid-key matches).
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {
            (k[len("_orig_mod."):] if k.startswith("_orig_mod.") else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head over a Whisper embedding and return its scalar score."""
    # Move the embedding to the MLP's device, then match its parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then for each dimension loads its MLP, predicts, and
    frees it again (keeps peak memory low with CPU offloading).  Finally
    prints a softmax summary over the primary emotion dimensions.

    Args:
        audio_file_path_str: path to any librosa-readable audio file.

    Returns:
        Mapping of dimension name -> raw score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE, so `sr` is informational only.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load one MLP at a time and release it right after scoring.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List  # needed for the Dict/List annotations below

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is resampled to this rate before Whisper
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # encoder frames the embedding is padded/truncated to
WHISPER_EMBED_DIM = 768  # embedding width expected by the MLP heads
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys: sanitized fragments found in "model_<Part>_best.pth" checkpoint names.
# Values: display keys used for scores and reporting throughout this script.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions summarized via softmax in
# process_audio_file().  Subset of FILENAME_PART_TO_TARGET_KEY_MAP values:
# non-emotion attributes (Age, Gender, Valence, Arousal, recording quality,
# the *_vs._* style scales, etc.) are deliberately excluded.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """MLP regression head over a flattened Whisper encoder embedding.

    The (seq_len, embed_dim) embedding is flattened, linearly projected down
    to ``projection_dim``, then pushed through ReLU/Dropout and the hidden
    Linear layers to a single scalar output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        # Attribute names (flatten/proj/mlp) and layer order must stay stable
        # so existing state_dict checkpoints keep loading.
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for layer_idx, out_features in enumerate(mlp_hidden_dims, start=1):
            stack.append(nn.Linear(in_features, out_features))
            stack.append(nn.ReLU())
            stack.append(nn.Dropout(mlp_dropout_rates[layer_idx]))
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Accept (batch, seq, dim) or (batch, 1, seq, dim); return (batch, 1)."""
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); kept module-level so repeated
# process_audio_file() calls reuse the loaded Whisper model and path map.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated: target key -> Path to its .pth checkpoint
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab: keeps GPU memory free for Whisper

def initialize_models():
    """Load the Whisper encoder/processor and map per-dimension MLP checkpoints.

    Idempotent: Whisper is loaded only while ``whisper_model_global`` is None,
    and checkpoints are downloaded/mapped only while
    ``all_mlp_model_paths_dict`` is empty.  Mutates the three module-level
    globals; the MLP weights themselves are loaded lazily, one at a time, by
    ``load_single_mlp()`` during prediction.

    Raises:
        RuntimeError: if no downloaded ``.pth`` file could be mapped to a key
            in FILENAME_PART_TO_TARGET_KEY_MAP.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,  # NOTE(review): deprecated/ignored in recent huggingface_hub versions
            allow_patterns=["*.pth"],  # fetch only the MLP checkpoints, skip other repo files
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Filenames look like "model_<Part>_best.pth"; <Part> is the map key.
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                # Filename did not contain the expected "model_..._best.pth" shape.
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE.  Shorter encoder outputs are zero-padded and longer
        ones truncated so every clip yields the exact shape the MLPs expect.

    Raises:
        RuntimeError: if initialize_models() has not been called first.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Encoder only: the decoder is not needed to obtain hidden states.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence length to WHISPER_SEQ_LEN.  Generalized over the
    # original, which hard-coded a batch size of 1 in the padding tensor.
    batch_size, current_seq_len = embedding.shape[0], embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((batch_size, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=embedding.device, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint onto MLP_DEVICE.

    Args:
        model_path: path to a "model_<Part>_best.pth" state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The MLP in eval() mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: checkpoints are plain tensor state dicts, so refuse to
    # unpickle arbitrary objects (safer, and avoids torch's FutureWarning).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left by torch.compile during training.
    # Prefix-strip only (the original's replace() could also hit mid-key matches).
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {
            (k[len("_orig_mod."):] if k.startswith("_orig_mod.") else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head over a Whisper embedding and return its scalar score."""
    # Move the embedding to the MLP's device, then match its parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then for each dimension loads its MLP, predicts, and
    frees it again (keeps peak memory low with CPU offloading).  Finally
    prints a softmax summary over the primary emotion dimensions.

    Args:
        audio_file_path_str: path to any librosa-readable audio file.

    Returns:
        Mapping of dimension name -> raw score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE, so `sr` is informational only.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load one MLP at a time and release it right after scoring.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is resampled to this rate before Whisper
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # encoder frames the embedding is padded/truncated to
WHISPER_EMBED_DIM = 768  # embedding width expected by the MLP heads
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys: sanitized fragments found in "model_<Part>_best.pth" checkpoint names.
# Values: display keys used for scores and reporting throughout this script.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions summarized via softmax in
# process_audio_file().  Subset of FILENAME_PART_TO_TARGET_KEY_MAP values:
# non-emotion attributes (Age, Gender, Valence, Arousal, recording quality,
# the *_vs._* style scales, etc.) are deliberately excluded.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """MLP regression head over a flattened Whisper encoder embedding.

    The (seq_len, embed_dim) embedding is flattened, linearly projected down
    to ``projection_dim``, then pushed through ReLU/Dropout and the hidden
    Linear layers to a single scalar output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        # Attribute names (flatten/proj/mlp) and layer order must stay stable
        # so existing state_dict checkpoints keep loading.
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for layer_idx, out_features in enumerate(mlp_hidden_dims, start=1):
            stack.append(nn.Linear(in_features, out_features))
            stack.append(nn.ReLU())
            stack.append(nn.Dropout(mlp_dropout_rates[layer_idx]))
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Accept (batch, seq, dim) or (batch, 1, seq, dim); return (batch, 1)."""
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); kept module-level so repeated
# process_audio_file() calls reuse the loaded Whisper model and path map.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated: target key -> Path to its .pth checkpoint
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab: keeps GPU memory free for Whisper

def initialize_models():
    """Load the Whisper encoder/processor and map per-dimension MLP checkpoints.

    Idempotent: Whisper is loaded only while ``whisper_model_global`` is None,
    and checkpoints are downloaded/mapped only while
    ``all_mlp_model_paths_dict`` is empty.  Mutates the three module-level
    globals; the MLP weights themselves are loaded lazily, one at a time, by
    ``load_single_mlp()`` during prediction.

    Raises:
        RuntimeError: if no downloaded ``.pth`` file could be mapped to a key
            in FILENAME_PART_TO_TARGET_KEY_MAP.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,  # NOTE(review): deprecated/ignored in recent huggingface_hub versions
            allow_patterns=["*.pth"],  # fetch only the MLP checkpoints, skip other repo files
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Filenames look like "model_<Part>_best.pth"; <Part> is the map key.
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                # Filename did not contain the expected "model_..._best.pth" shape.
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE.  Shorter encoder outputs are zero-padded and longer
        ones truncated so every clip yields the exact shape the MLPs expect.

    Raises:
        RuntimeError: if initialize_models() has not been called first.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Encoder only: the decoder is not needed to obtain hidden states.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence length to WHISPER_SEQ_LEN.  Generalized over the
    # original, which hard-coded a batch size of 1 in the padding tensor.
    batch_size, current_seq_len = embedding.shape[0], embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((batch_size, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=embedding.device, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint onto MLP_DEVICE.

    Args:
        model_path: path to a "model_<Part>_best.pth" state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The MLP in eval() mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: checkpoints are plain tensor state dicts, so refuse to
    # unpickle arbitrary objects (safer, and avoids torch's FutureWarning).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left by torch.compile during training.
    # Prefix-strip only (the original's replace() could also hit mid-key matches).
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {
            (k[len("_orig_mod."):] if k.startswith("_orig_mod.") else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head over a Whisper embedding and return its scalar score."""
    # Move the embedding to the MLP's device, then match its parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then for each dimension loads its MLP, predicts, and
    frees it again (keeps peak memory low with CPU offloading).  Finally
    prints a softmax summary over the primary emotion dimensions.

    Args:
        audio_file_path_str: path to any librosa-readable audio file.

    Returns:
        Mapping of dimension name -> raw score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE, so `sr` is informational only.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load one MLP at a time and release it right after scoring.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000          # Whisper expects 16 kHz mono audio
MAX_AUDIO_SECONDS = 30.0       # Whisper's maximum context window
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500      # encoder output frames for a 30 s window
WHISPER_EMBED_DIM = 768     # Whisper-small hidden size
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)
# FIX: annotate with builtin generics (PEP 585, Python 3.9+). The original
# used typing.Dict / typing.List without ever importing them, which raises
# NameError at module level.

FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions reported via softmax (a subset of the
# map above: attribute dimensions like Age/Gender/Valence are excluded).
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it to a small
    bottleneck, then scores it with an MLP ending in one linear output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection, plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, out_features in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(rate)]
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration, set by initialize_models()
whisper_processor_global = None  # WhisperProcessor, set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: {target_key: Path to .pth checkpoint}
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and map downloaded MLP checkpoint paths.

    Idempotent: Whisper is loaded only while the module globals are still
    None, and checkpoints are downloaded/mapped only while the path map is
    empty. MLP weights themselves are loaded lazily by load_single_mlp(),
    one model at a time, to keep memory usage low.

    Raises:
        RuntimeError: if no downloaded .pth file could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper (skipped when already cached in the module globals)
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],  # fetch only checkpoints, not README etc.
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Expected filename shape: model_<DimensionName>_best.pth
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor.

    Runs only the Whisper encoder; the decoder is never needed for
    embedding extraction. The time axis is padded with zeros or truncated
    so it is always exactly WHISPER_SEQ_LEN frames.
    """
    if whisper_processor_global is None or whisper_model_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    processed = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    features = processed.input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    # Normalize the sequence dimension to exactly WHISPER_SEQ_LEN frames.
    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at `model_path`.

    Simplified loading for this example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: path to a .pth state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The loaded model on MLP_DEVICE in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoints are plain tensor state dicts, so
    # refuse to unpickle arbitrary objects from a downloaded file
    # (safer torch.load; requires torch >= 1.13).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left behind when the model was saved
    # after torch.compile() during training.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression MLP on a Whisper embedding and return its scalar score."""
    # Move/cast the embedding to wherever the MLP lives before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    x = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(x).item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every available emotion/attribute dimension.

    Loads the audio (mono, 16 kHz, truncated to MAX_AUDIO_SECONDS), extracts
    a single Whisper encoder embedding, then runs each downloaded MLP head on
    it in turn — loading and unloading one head at a time to bound memory.

    FIX: the return annotation uses builtin `dict` (PEP 585); the original
    `Dict[str, float]` raised NameError at def time because typing.Dict was
    never imported.

    Returns:
        {dimension_key: raw MLP score}; empty dict if the audio fails to load.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, _sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload immediately to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional report: softmax over the primary emotion dimensions only.
    # Build the key list once so keys and scores stay aligned by construction
    # (the original built the same filtered list twice).
    reported_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if reported_keys:
        raw = torch.tensor([all_scores[k] for k in reported_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = dict(zip(reported_keys, (p.item() for p in softmax_probs)))
        ranked = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(ranked[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
### Python code example (uses `transformers`; mirrors Cell 2 of the Colab)
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000          # Whisper expects 16 kHz mono audio
MAX_AUDIO_SECONDS = 30.0       # Whisper's maximum context window
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500      # encoder output frames for a 30 s window
WHISPER_EMBED_DIM = 768     # Whisper-small hidden size
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)
# FIX: annotate with builtin generics (PEP 585, Python 3.9+). The original
# used typing.Dict / typing.List without ever importing them, which raises
# NameError at module level.

FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions reported via softmax (a subset of the
# map above: attribute dimensions like Age/Gender/Valence are excluded).
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it to a small
    bottleneck, then scores it with an MLP ending in one linear output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection, plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, out_features in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(rate)]
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration, set by initialize_models()
whisper_processor_global = None  # WhisperProcessor, set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: {target_key: Path to .pth checkpoint}
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and map downloaded MLP checkpoint paths.

    Idempotent: Whisper is loaded only while the module globals are still
    None, and checkpoints are downloaded/mapped only while the path map is
    empty. MLP weights themselves are loaded lazily by load_single_mlp(),
    one model at a time, to keep memory usage low.

    Raises:
        RuntimeError: if no downloaded .pth file could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper (skipped when already cached in the module globals)
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],  # fetch only checkpoints, not README etc.
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Expected filename shape: model_<DimensionName>_best.pth
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor.

    Runs only the Whisper encoder; the decoder is never needed for
    embedding extraction. The time axis is padded with zeros or truncated
    so it is always exactly WHISPER_SEQ_LEN frames.
    """
    if whisper_processor_global is None or whisper_model_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    processed = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    features = processed.input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    # Normalize the sequence dimension to exactly WHISPER_SEQ_LEN frames.
    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at `model_path`.

    Simplified loading for this example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: path to a .pth state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The loaded model on MLP_DEVICE in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoints are plain tensor state dicts, so
    # refuse to unpickle arbitrary objects from a downloaded file
    # (safer torch.load; requires torch >= 1.13).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left behind when the model was saved
    # after torch.compile() during training.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression MLP on a Whisper embedding and return its scalar score."""
    # Move/cast the embedding to wherever the MLP lives before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    x = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(x).item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every available emotion/attribute dimension.

    Loads the audio (mono, 16 kHz, truncated to MAX_AUDIO_SECONDS), extracts
    a single Whisper encoder embedding, then runs each downloaded MLP head on
    it in turn — loading and unloading one head at a time to bound memory.

    FIX: the return annotation uses builtin `dict` (PEP 585); the original
    `Dict[str, float]` raised NameError at def time because typing.Dict was
    never imported.

    Returns:
        {dimension_key: raw MLP score}; empty dict if the audio fails to load.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, _sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload immediately to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional report: softmax over the primary emotion dimensions only.
    # Build the key list once so keys and scores stay aligned by construction
    # (the original built the same filtered list twice).
    reported_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if reported_keys:
        raw = torch.tensor([all_scores[k] for k in reported_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = dict(zip(reported_keys, (p.item() for p in softmax_probs)))
        ranked = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(ranked[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
### Python code example (uses `transformers`; mirrors Cell 2 of the Colab)
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000          # Whisper expects 16 kHz mono audio
MAX_AUDIO_SECONDS = 30.0       # Whisper's maximum context window
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500      # encoder output frames for a 30 s window
WHISPER_EMBED_DIM = 768     # Whisper-small hidden size
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)
# FIX: annotate with builtin generics (PEP 585, Python 3.9+). The original
# used typing.Dict / typing.List without ever importing them, which raises
# NameError at module level.

FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions reported via softmax (a subset of the
# map above: attribute dimensions like Age/Gender/Valence are excluded).
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it to a small
    bottleneck, then scores it with an MLP ending in one linear output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection, plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, out_features in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(rate)]
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration, set by initialize_models()
whisper_processor_global = None  # WhisperProcessor, set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: {target_key: Path to .pth checkpoint}
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and map downloaded MLP checkpoint paths.

    Idempotent: Whisper is loaded only while the module globals are still
    None, and checkpoints are downloaded/mapped only while the path map is
    empty. MLP weights themselves are loaded lazily by load_single_mlp(),
    one model at a time, to keep memory usage low.

    Raises:
        RuntimeError: if no downloaded .pth file could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper (skipped when already cached in the module globals)
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],  # fetch only checkpoints, not README etc.
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Expected filename shape: model_<DimensionName>_best.pth
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor.

    The encoder output is zero-padded or truncated along the time axis so the
    downstream MLP heads always see a fixed-size embedding.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: append zero frames up to the fixed length.
        pad = torch.zeros((1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
                          device=WHISPER_DEVICE, dtype=hidden.dtype)
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint onto MLP_DEVICE.

    Simplified loading for example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.
    Returns the model in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint was downloaded from a remote hub repo,
    # so restrict torch.load to plain tensors instead of arbitrary pickled code.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix torch.compile adds if the model was compiled
    # during training; only the leading occurrence is a compile artifact.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression head on a Whisper embedding and return its scalar score."""
    # Move/cast the embedding to where the MLP lives before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    batch = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(batch).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with the
    Whisper encoder, then runs each per-dimension MLP head sequentially,
    loading and unloading heads to keep peak memory low. Returns a dict of
    {dimension_key: raw_score}; empty dict when the audio cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # Resampled to SAMPLING_RATE and downmixed to mono; sample rate return
        # value is unused because librosa already resampled for us.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary-emotion subset for a quick top-3 report.
    # Compute the list of scored emotion keys once (was computed twice before).
    scored_emotion_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_emotion_keys:
        raw = torch.tensor([all_scores[k] for k in scored_emotion_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}
        emotion_softmax_dict = {k: p.item() for k, p in zip(scored_emotion_keys, softmax_probs)}
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List # Required by the Dict/List annotations below

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000           # Hz; audio is resampled to this rate before encoding
MAX_AUDIO_SECONDS = 30.0        # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Fixed encoder output geometry: each clip yields (WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The primary emotion dimensions reported with a softmax in process_audio_file().
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head: flatten a (seq_len, embed_dim) embedding, project it
    down to projection_dim, then run a small ReLU MLP ending in one scalar."""

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection stage plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stages = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for rate, width in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stages += [nn.Linear(in_dim, width), nn.ReLU(), nn.Dropout(rate)]
            in_dim = width
        stages.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); kept at module scope so the heavy
# Whisper model is loaded once and reused across calls.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper backbone and map downloaded MLP checkpoints to dimension keys.

    Populates the module-level globals; repeat calls are cheap no-ops once
    everything is in place.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper is loaded once and cached in the module globals.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Checkpoints are only downloaded/mapped when the path map is empty; the
    # MLPs themselves are loaded lazily, one at a time, at prediction time.
    if all_mlp_model_paths_dict:
        return

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # File names look like "model_<DimensionName>_best.pth"; translate the
    # middle part into the human-readable dimension key.
    for ckpt_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            name_part = ckpt_path.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {ckpt_path.name}")
            continue
        dim_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(name_part)
        if dim_key is not None:
            all_mlp_model_paths_dict[dim_key] = ckpt_path
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor.

    The encoder output is zero-padded or truncated along the time axis so the
    downstream MLP heads always see a fixed-size embedding.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: append zero frames up to the fixed length.
        pad = torch.zeros((1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
                          device=WHISPER_DEVICE, dtype=hidden.dtype)
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint onto MLP_DEVICE.

    Simplified loading for example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.
    Returns the model in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint was downloaded from a remote hub repo,
    # so restrict torch.load to plain tensors instead of arbitrary pickled code.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix torch.compile adds if the model was compiled
    # during training; only the leading occurrence is a compile artifact.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression head on a Whisper embedding and return its scalar score."""
    # Move/cast the embedding to where the MLP lives before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    batch = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(batch).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with the
    Whisper encoder, then runs each per-dimension MLP head sequentially,
    loading and unloading heads to keep peak memory low. Returns a dict of
    {dimension_key: raw_score}; empty dict when the audio cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # Resampled to SAMPLING_RATE and downmixed to mono; sample rate return
        # value is unused because librosa already resampled for us.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary-emotion subset for a quick top-3 report.
    # Compute the list of scored emotion keys once (was computed twice before).
    scored_emotion_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_emotion_keys:
        raw = torch.tensor([all_scores[k] for k in scored_emotion_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}
        emotion_softmax_dict = {k: p.item() for k, p in zip(scored_emotion_keys, softmax_probs)}
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000           # Hz; audio is resampled to this rate before encoding
MAX_AUDIO_SECONDS = 30.0        # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Fixed encoder output geometry: each clip yields (WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The primary emotion dimensions reported with a softmax in process_audio_file().
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head: flatten a (seq_len, embed_dim) embedding, project it
    down to projection_dim, then run a small ReLU MLP ending in one scalar."""

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection stage plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stages = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for rate, width in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stages += [nn.Linear(in_dim, width), nn.ReLU(), nn.Dropout(rate)]
            in_dim = width
        stages.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); kept at module scope so the heavy
# Whisper model is loaded once and reused across calls.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper backbone and map downloaded MLP checkpoints to dimension keys.

    Populates the module-level globals; repeat calls are cheap no-ops once
    everything is in place.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper is loaded once and cached in the module globals.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Checkpoints are only downloaded/mapped when the path map is empty; the
    # MLPs themselves are loaded lazily, one at a time, at prediction time.
    if all_mlp_model_paths_dict:
        return

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # File names look like "model_<DimensionName>_best.pth"; translate the
    # middle part into the human-readable dimension key.
    for ckpt_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            name_part = ckpt_path.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {ckpt_path.name}")
            continue
        dim_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(name_part)
        if dim_key is not None:
            all_mlp_model_paths_dict[dim_key] = ckpt_path
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor.

    The encoder output is zero-padded or truncated along the time axis so the
    downstream MLP heads always see a fixed-size embedding.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: append zero frames up to the fixed length.
        pad = torch.zeros((1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
                          device=WHISPER_DEVICE, dtype=hidden.dtype)
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint onto MLP_DEVICE.

    Simplified loading for example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.
    Returns the model in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint was downloaded from a remote hub repo,
    # so restrict torch.load to plain tensors instead of arbitrary pickled code.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix torch.compile adds if the model was compiled
    # during training; only the leading occurrence is a compile artifact.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression head on a Whisper embedding and return its scalar score."""
    # Move/cast the embedding to where the MLP lives before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    batch = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(batch).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with the
    Whisper encoder, then runs each per-dimension MLP head sequentially,
    loading and unloading heads to keep peak memory low. Returns a dict of
    {dimension_key: raw_score}; empty dict when the audio cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # Resampled to SAMPLING_RATE and downmixed to mono; sample rate return
        # value is unused because librosa already resampled for us.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary-emotion subset for a quick top-3 report.
    # Compute the list of scored emotion keys once (was computed twice before).
    scored_emotion_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_emotion_keys:
        raw = torch.tensor([all_scores[k] for k in scored_emotion_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}
        emotion_softmax_dict = {k: p.item() for k, p in zip(scored_emotion_keys, softmax_probs)}
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import gc  # For memory management
from pathlib import Path
from typing import Dict, List  # needed by the annotations below

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download  # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; the rate passed to librosa and the Whisper processor
MAX_AUDIO_SECONDS = 30.0  # longer audio is truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry and MLP-head hyperparameters (consumed by FullEmbeddingMLP).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions that process_audio_file softmax-normalizes
# for its "Top 3 Emotions" report (values of the map above, minus the
# non-emotion dimensions such as Age, Gender, Recording_Quality, etc.).
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar regression head over a full (seq_len, embed_dim) embedding.

    The embedding is flattened, linearly projected down to ``projection_dim``,
    then passed through a small ReLU/Dropout MLP ending in a single output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for hidden_size, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_features, hidden_size), nn.ReLU(), nn.Dropout(rate)]
            in_features = hidden_size
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Tolerate an extra singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set lazily by initialize_models()
whisper_processor_global = None  # set lazily by initialize_models()
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and download + index the MLP checkpoints.

    Idempotent: already-initialized globals are left untouched, so it is safe
    to call more than once.

    Raises:
        RuntimeError: if no checkpoint file could be mapped to a target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper is loaded once and kept resident.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # MLP checkpoints: download once, then record only the file paths so the
    # heads themselves can be loaded (and unloaded) on demand.
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model",
        )
        print("MLP checkpoints downloaded.")

        # Derive the dimension name from "model_<part>_best.pth" and map it
        # to its human-readable target key.
        for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                part = checkpoint.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {checkpoint.name}")
                continue
            if part in FILENAME_PART_TO_TARGET_KEY_MAP:
                all_mlp_model_paths_dict[FILENAME_PART_TO_TARGET_KEY_MAP[part]] = checkpoint
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
            raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Embed a mono waveform (numpy, SAMPLING_RATE Hz) with the Whisper encoder.

    Returns a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) hidden-state tensor,
    zero-padded or truncated along the time axis to the fixed length.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: right-pad with zeros so the MLP heads see a fixed shape.
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP with the global 'Small' hyperparameters and
    load the checkpoint at ``model_path`` onto MLP_DEVICE in eval mode.

    Args:
        model_path: Path to a .pth state-dict checkpoint.
        target_key: Human-readable dimension name (used for logging only).

    Returns:
        The loaded model, in eval() mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True restricts unpickling to tensors/primitives. A plain
    # state_dict needs nothing more, and this avoids arbitrary-code execution
    # from a tampered checkpoint (available since torch 1.13).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix torch.compile adds during training.
    # removeprefix only touches a leading occurrence, unlike str.replace,
    # which would also corrupt any key containing the substring elsewhere.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score one embedding with one MLP head and return the scalar prediction."""
    # Move the embedding to the MLP's device and match its parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    batch = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(batch).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with the
    Whisper encoder, then runs each MLP head over that shared embedding,
    loading and unloading heads one at a time to bound memory use. Also prints
    a softmax-normalized top-3 over the primary emotion dimensions.

    Args:
        audio_file_path_str: Path to any audio format librosa can decode.

    Returns:
        Mapping of dimension name -> raw MLP score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report the failure and return an empty result rather
        # than raising, so batch callers can continue.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    # One encoder pass; the embedding is reused by every MLP head below.
    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Heads are loaded/unloaded sequentially to keep peak memory low.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # NOTE(review): load_single_mlp never returns a falsy value as
            # written, so this branch looks unreachable — kept for safety.
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) --- (Python, transformers)
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; the rate passed to librosa and the Whisper processor
MAX_AUDIO_SECONDS = 30.0  # longer audio is truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry and MLP-head hyperparameters (consumed by FullEmbeddingMLP).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions that process_audio_file softmax-normalizes
# for its "Top 3 Emotions" report (values of the map above, minus the
# non-emotion dimensions such as Age, Gender, Recording_Quality, etc.).
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar regression head over a full (seq_len, embed_dim) embedding.

    The embedding is flattened, linearly projected down to ``projection_dim``,
    then passed through a small ReLU/Dropout MLP ending in a single output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for hidden_size, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_features, hidden_size), nn.ReLU(), nn.Dropout(rate)]
            in_features = hidden_size
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Tolerate an extra singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set lazily by initialize_models()
whisper_processor_global = None  # set lazily by initialize_models()
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and download + index the MLP checkpoints.

    Idempotent: already-initialized globals are left untouched, so it is safe
    to call more than once.

    Raises:
        RuntimeError: if no checkpoint file could be mapped to a target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper is loaded once and kept resident.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # MLP checkpoints: download once, then record only the file paths so the
    # heads themselves can be loaded (and unloaded) on demand.
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model",
        )
        print("MLP checkpoints downloaded.")

        # Derive the dimension name from "model_<part>_best.pth" and map it
        # to its human-readable target key.
        for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                part = checkpoint.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {checkpoint.name}")
                continue
            if part in FILENAME_PART_TO_TARGET_KEY_MAP:
                all_mlp_model_paths_dict[FILENAME_PART_TO_TARGET_KEY_MAP[part]] = checkpoint
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
            raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Embed a mono waveform (numpy, SAMPLING_RATE Hz) with the Whisper encoder.

    Returns a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) hidden-state tensor,
    zero-padded or truncated along the time axis to the fixed length.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: right-pad with zeros so the MLP heads see a fixed shape.
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP with the global 'Small' hyperparameters and
    load the checkpoint at ``model_path`` onto MLP_DEVICE in eval mode.

    Args:
        model_path: Path to a .pth state-dict checkpoint.
        target_key: Human-readable dimension name (used for logging only).

    Returns:
        The loaded model, in eval() mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True restricts unpickling to tensors/primitives. A plain
    # state_dict needs nothing more, and this avoids arbitrary-code execution
    # from a tampered checkpoint (available since torch 1.13).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix torch.compile adds during training.
    # removeprefix only touches a leading occurrence, unlike str.replace,
    # which would also corrupt any key containing the substring elsewhere.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score one embedding with one MLP head and return the scalar prediction."""
    # Move the embedding to the MLP's device and match its parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    batch = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(batch).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with the
    Whisper encoder, then runs each MLP head over that shared embedding,
    loading and unloading heads one at a time to bound memory use. Also prints
    a softmax-normalized top-3 over the primary emotion dimensions.

    Args:
        audio_file_path_str: Path to any audio format librosa can decode.

    Returns:
        Mapping of dimension name -> raw MLP score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report the failure and return an empty result rather
        # than raising, so batch callers can continue.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    # One encoder pass; the embedding is reused by every MLP head below.
    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Heads are loaded/unloaded sequentially to keep peak memory low.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # NOTE(review): load_single_mlp never returns a falsy value as
            # written, so this branch looks unreachable — kept for safety.
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) --- (Python, transformers)
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; the rate passed to librosa and the Whisper processor
MAX_AUDIO_SECONDS = 30.0  # longer audio is truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry and MLP-head hyperparameters (consumed by FullEmbeddingMLP).
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions that process_audio_file softmax-normalizes
# for its "Top 3 Emotions" report (values of the map above, minus the
# non-emotion dimensions such as Age, Gender, Recording_Quality, etc.).
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar regression head over a full (seq_len, embed_dim) embedding.

    The embedding is flattened, linearly projected down to ``projection_dim``,
    then passed through a small ReLU/Dropout MLP ending in a single output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for hidden_size, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_features, hidden_size), nn.ReLU(), nn.Dropout(rate)]
            in_features = hidden_size
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Tolerate an extra singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set lazily by initialize_models()
whisper_processor_global = None  # set lazily by initialize_models()
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and download + index the MLP checkpoints.

    Idempotent: already-initialized globals are left untouched, so it is safe
    to call more than once.

    Raises:
        RuntimeError: if no checkpoint file could be mapped to a target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper is loaded once and kept resident.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # MLP checkpoints: download once, then record only the file paths so the
    # heads themselves can be loaded (and unloaded) on demand.
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model",
        )
        print("MLP checkpoints downloaded.")

        # Derive the dimension name from "model_<part>_best.pth" and map it
        # to its human-readable target key.
        for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                part = checkpoint.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {checkpoint.name}")
                continue
            if part in FILENAME_PART_TO_TARGET_KEY_MAP:
                all_mlp_model_paths_dict[FILENAME_PART_TO_TARGET_KEY_MAP[part]] = checkpoint
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
            raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float numpy waveform at SAMPLING_RATE Hz.

    Returns:
        Tensor of shape (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the time axis.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Log-mel features, moved to the encoder's device and dtype.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Encoder-only forward pass; the decoder is never invoked.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize to exactly WHISPER_SEQ_LEN frames so the MLP heads always
    # see the same flattened size: zero-pad short clips, truncate long ones.
    # NOTE(review): the padding hard-codes batch size 1 — confirm callers never batch.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at *model_path*.

    Simplified loading for example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and
    USE_TORCH_COMPILE_FOR_MLPS=False. Returns the model on MLP_DEVICE in
    eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score the Whisper embedding with one MLP head; return a Python float.

    Moves the embedding to MLP_DEVICE and casts it to the head's parameter
    dtype before the forward pass.
    """
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then loads each MLP head one at a time, scores it,
    and frees it to keep peak memory low. Prints the softmax top-3 over
    the primary emotion dimensions as a side effect.

    Returns:
        Mapping of dimension key -> raw score; empty dict if audio loading failed.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report the failure and return an empty result rather than raise.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # The raw waveform is no longer needed once embedded; free it eagerly.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load -> predict -> unload, one head at a time (CPU-offloading pattern).
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # Defensive branch: load_single_mlp never returns a falsy model today.
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is resampled to this rate before Whisper
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # fixed encoder sequence length the MLP heads expect
WHISPER_EMBED_DIM = 768  # encoder hidden size; MLP input is SEQ_LEN * EMBED_DIM flattened
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)
# Keys are the <Dimension> slug in "model_<Dimension>_best.pth" filenames;
# values are the display keys used in the result dicts and reports.

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions over which process_audio_file() reports
# a softmax; a subset of FILENAME_PART_TO_TARGET_KEY_MAP's values.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, linearly projects it down to
    projection_dim, then scores it with a small ReLU MLP ending in a single
    output unit.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, width in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(rate)]
            in_features = width
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set by initialize_models(); Whisper seq2seq model in eval mode
whisper_processor_global = None  # set by initialize_models(); feature extractor/tokenizer pair
all_mlp_model_paths_dict = {} # To be populated: target key -> Path of downloaded .pth checkpoint
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder/processor once and map MLP checkpoint paths.

    Idempotent: anything already initialized is left untouched. Only the
    checkpoint *paths* are recorded here; per-dimension weights are loaded
    later by load_single_mlp().

    Raises:
        RuntimeError: if no downloaded .pth file could be mapped to a target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        loaded = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = loaded.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # checkpoint paths were already mapped on a previous call

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model"
    )
    print("MLP checkpoints downloaded.")

    # Associate each "model_<Dimension>_best.pth" file with its display key.
    for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            slug = checkpoint.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {checkpoint.name}")
            continue
        display_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(slug)
        if display_key is not None:
            all_mlp_model_paths_dict[display_key] = checkpoint

    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float numpy waveform at SAMPLING_RATE Hz.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the time axis so the
        downstream MLP heads always see exactly WHISPER_SEQ_LEN frames.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Log-mel features, moved to the encoder's device and dtype.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Encoder-only forward pass; the decoder is never invoked.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Fix: size the padding from the actual batch dimension and the
        # tensor's own device instead of hard-coding batch size 1 and the
        # global device — behavior is unchanged for the batch-1 case.
        padding = torch.zeros((embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=embedding.device, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP for *target_key* and load its checkpoint.

    Simplified loading for example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and
    USE_TORCH_COMPILE_FOR_MLPS=False. Returns the model on MLP_DEVICE in
    eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    mlp = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    weights = torch.load(model_path, map_location='cpu')
    # Checkpoints saved from a torch.compile'd model carry an '_orig_mod.' prefix.
    has_compile_prefix = any(name.startswith("_orig_mod.") for name in weights)
    if has_compile_prefix:
        weights = {name.replace("_orig_mod.", ""): tensor for name, tensor in weights.items()}
    mlp.load_state_dict(weights)
    return mlp.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score the Whisper embedding with one MLP head; return a Python float.

    The embedding is moved to MLP_DEVICE and cast to the head's parameter
    dtype before the forward pass.
    """
    head_dtype = next(mlp_model.parameters()).dtype
    scored = mlp_model(embedding.to(MLP_DEVICE).to(head_dtype))
    return scored.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then loads each MLP head one at a time, scores it,
    and frees it to keep peak memory low. Prints the softmax top-3 over
    the primary emotion dimensions as a side effect.

    Returns:
        Mapping of dimension key -> raw score; empty dict if audio loading failed.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report the failure and return an empty result rather than raise.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # The raw waveform is no longer needed once embedded; free it eagerly.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load -> predict -> unload, one head at a time (CPU-offloading pattern).
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # Defensive branch: load_single_mlp never returns a falsy model today.
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is resampled to this rate before Whisper
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # fixed encoder sequence length the MLP heads expect
WHISPER_EMBED_DIM = 768  # encoder hidden size; MLP input is SEQ_LEN * EMBED_DIM flattened
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)
# Keys are the <Dimension> slug in "model_<Dimension>_best.pth" filenames;
# values are the display keys used in the result dicts and reports.

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions over which process_audio_file() reports
# a softmax; a subset of FILENAME_PART_TO_TARGET_KEY_MAP's values.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, linearly projects it to
    projection_dim, then applies a small ReLU MLP ending in one output unit.
    """
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # set by initialize_models(); Whisper seq2seq model in eval mode
whisper_processor_global = None  # set by initialize_models(); feature extractor/tokenizer pair
all_mlp_model_paths_dict = {} # To be populated: target key -> Path of downloaded .pth checkpoint
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder/processor and map MLP checkpoint paths.

    Populates the module-level globals on first call; anything already
    initialized is left untouched on later calls. Only checkpoint *paths*
    are recorded here — weights are loaded per-dimension by load_single_mlp()
    to keep memory low.

    Raises:
        RuntimeError: if no downloaded .pth file could be mapped to a target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper (only once; the global persists across calls).
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],  # checkpoints only; skip other repo files
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Expected filename shape: model_<Dimension>_best.pth
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                # Filename did not match the expected "model_..._best.pth" shape.
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float numpy waveform at SAMPLING_RATE Hz.

    Returns:
        Tensor of shape (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the time axis.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Log-mel features, moved to the encoder's device and dtype.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Encoder-only forward pass; the decoder is never invoked.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize to exactly WHISPER_SEQ_LEN frames so the MLP heads always
    # see the same flattened size: zero-pad short clips, truncate long ones.
    # NOTE(review): the padding hard-codes batch size 1 — confirm callers never batch.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at *model_path*.

    Simplified loading for example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and
    USE_TORCH_COMPILE_FOR_MLPS=False. Returns the model on MLP_DEVICE in
    eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score the Whisper embedding with one MLP head; return a Python float.

    Moves the embedding to MLP_DEVICE and casts it to the head's parameter
    dtype before the forward pass.
    """
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then loads each MLP head one at a time, scores it,
    and frees it to keep peak memory low. Prints the softmax top-3 over
    the primary emotion dimensions as a side effect.

    Returns:
        Mapping of dimension key -> raw score; empty dict if audio loading failed.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report the failure and return an empty result rather than raise.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # The raw waveform is no longer needed once embedded; free it eagerly.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load -> predict -> unload, one head at a time (CPU-offloading pattern).
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # Defensive branch: load_single_mlp never returns a falsy model today.
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List  # Needed by the Dict/List annotations below

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; the rate the Whisper processor expects
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated before encoding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # encoder output frames; embeddings are padded/truncated to this
WHISPER_EMBED_DIM = 768  # encoder hidden size used when building zero padding
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# NOTE(review): the Dict/List annotations below require
# `from typing import Dict, List` at the top of the file.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 "primary emotion" keys that process_audio_file() reports via softmax.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head mapping a full Whisper encoder embedding to one scalar.

    The (seq_len, embed_dim) embedding is flattened, linearly projected down
    to ``projection_dim``, then passed through a small MLP ending in a
    1-unit linear layer.

    Args:
        seq_len: Encoder sequence length (e.g. 1500).
        embed_dim: Encoder hidden size (e.g. 768).
        projection_dim: Output size of the initial projection layer.
        mlp_hidden_dims: Hidden layer sizes of the MLP trunk.
        mlp_dropout_rates: Dropout rates; must contain exactly
            ``len(mlp_hidden_dims) + 1`` entries (one after the projection,
            one after each hidden layer).

    Raises:
        ValueError: If ``mlp_dropout_rates`` has the wrong length.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            # Diagnostic message states expected vs. actual count instead of
            # the original opaque "Dropout rates length error."
            raise ValueError(
                f"Expected {len(mlp_hidden_dims) + 1} dropout rates "
                f"(one per hidden layer plus one after the projection), "
                f"got {len(mlp_dropout_rates)}."
            )
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i + 1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return (batch, 1) scores; accepts (batch, 1, seq, dim) inputs too."""
        # Drop a singleton channel dim so flatten sees (batch, seq, dim).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# All of these are populated by initialize_models(); the functions below read
# them instead of taking the models as parameters.
whisper_model_global = None  # WhisperForConditionalGeneration once loaded
whisper_processor_global = None  # WhisperProcessor once loaded
# Maps target dimension key -> local Path of its .pth checkpoint.
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load Whisper and map MLP checkpoint paths into the module globals.

    Idempotent: Whisper is loaded only while ``whisper_model_global`` is
    still ``None``, and checkpoints are downloaded/mapped only while
    ``all_mlp_model_paths_dict`` is empty.

    Raises:
        RuntimeError: If no checkpoint file could be matched to a target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper (processor + model) once.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # Checkpoint paths already mapped; nothing left to do.

    # Fetch only the .pth checkpoints; models themselves are loaded on demand.
    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # Checkpoints are named "model_<part>_best.pth"; <part> indexes the key map.
    for ckpt in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            part = ckpt.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {ckpt.name}")
            continue
        target = FILENAME_PART_TO_TARGET_KEY_MAP.get(part)
        if target is not None:
            all_mlp_model_paths_dict[target] = ckpt
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor.

    The encoder output is zero-padded or truncated along the time axis so
    every clip yields a fixed-size embedding for the MLP heads.

    Raises:
        RuntimeError: If initialize_models() has not been called yet.
    """
    if whisper_processor_global is None or whisper_model_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    emb = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = emb.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        emb = emb[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=emb.dtype,
        )
        emb = torch.cat((emb, pad), dim=1)
    return emb

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP and load its weights from ``model_path``.

    Simplified loading for this example (the Colab has more robust handling);
    assumes no half precision and no torch.compile at inference time.
    """
    print(f"  Loading MLP for '{target_key}'...")
    mlp = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state = torch.load(model_path, map_location='cpu')
    # Checkpoints saved from a torch.compile'd model carry an '_orig_mod.' prefix.
    if any(name.startswith("_orig_mod.") for name in state):
        state = {name.replace("_orig_mod.", ""): tensor for name, tensor in state.items()}
    mlp.load_state_dict(state)
    return mlp.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on an embedding and return the scalar score as a float."""
    x = embedding.to(MLP_DEVICE)
    # Cast the input to the MLP's parameter dtype (simplified handling).
    target_dtype = next(mlp_model.parameters()).dtype
    return mlp_model(x.to(target_dtype)).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then runs each MLP head in turn — loading and
    unloading each head to keep memory bounded. Finally prints the top-3
    primary emotions by softmax over the raw scores.

    Args:
        audio_file_path_str: Path to any librosa-readable audio file.

    Returns:
        Mapping of dimension key -> raw score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # Resample to the model rate; the returned sample rate is unused.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Hoisted out of the loop: the membership test was an O(len(map)) scan of
    # FILENAME_PART_TO_TARGET_KEY_MAP.values() on every iteration.
    mapped_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in mapped_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to keep memory bounded
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary-emotion raw scores for a quick report.
    scored_emotions = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_emotions:
        raw = torch.tensor([all_scores[k] for k in scored_emotions], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}, aligned with scored_emotions order.
        emotion_softmax_dict = {
            key: prob.item() for key, prob in zip(scored_emotions, softmax_probs)
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# ==== Duplicate copy of the example script below (page-extraction artifact; originally a repeated "python" code-fence header) ====
# NOTE(review): everything in this section is an exact duplicate of the script
# defined earlier in this file (page-extraction artifact). Each statement
# merely re-binds the same name to an identical object, so the duplication is
# redundant but harmless. Consider deleting this entire copy.
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Flatten -> linear projection -> small MLP -> one scalar score."""
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        # Drop a singleton channel dim so flatten sees (batch, seq, dim).
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load Whisper and map MLP checkpoint paths into the module globals (idempotent)."""
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a waveform to a fixed (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor."""
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP and load its checkpoint from model_path."""
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on an embedding and return the scalar score."""
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension; returns key -> raw score."""
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# ==== Second duplicate copy of the example script below (page-extraction artifact; originally a repeated "python" code-fence header) ====
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-length Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, embed_dim) on WHISPER_DEVICE,
        zero-padded or truncated along the sequence axis to WHISPER_SEQ_LEN.

    Raises:
        RuntimeError: if initialize_models() has not been called first.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence axis to WHISPER_SEQ_LEN. Pad shapes are derived
    # from the embedding itself (batch and embed dims were previously
    # hard-coded to 1 and WHISPER_EMBED_DIM, which broke batched inputs).
    batch, current_seq_len, embed_dim = embedding.shape
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros(
            (batch, WHISPER_SEQ_LEN - current_seq_len, embed_dim),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load weights for one target dimension.

    Args:
        model_path: Path to the .pth state-dict checkpoint.
        target_key: Human-readable dimension name (used for logging only).

    Returns:
        The loaded MLP in eval mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading).
    # Assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint is a plain state dict downloaded from
    # the Hub, so avoid full (code-executing) unpickling. Requires torch>=1.13.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left by torch.compile during training.
    # Only a leading prefix is removed; a bare replace() would also mangle any
    # key that happened to contain the substring mid-name.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head over a Whisper embedding and return its scalar score."""
    # Match the MLP's device and dtype before the forward pass (simplified).
    head_dtype = next(mlp_model.parameters()).dtype
    x = embedding.to(MLP_DEVICE).to(head_dtype)
    return mlp_model(x).item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every mapped empathic dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it with Whisper,
    then runs each per-dimension MLP sequentially, loading and unloading each
    head to bound peak memory. Also prints the top-3 primary emotions by
    softmax over their raw scores.

    Args:
        audio_file_path_str: Path to an audio file readable by librosa.

    Returns:
        Mapping of dimension key -> raw MLP score (NaN for heads that failed
        to load); empty dict if the audio could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound peak memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary-emotion raw scores. NaN scores (from
    # failed MLP loads) are excluded -- a single NaN would otherwise turn the
    # entire softmax output into NaN.
    scored_keys = [
        k for k in TARGET_EMOTION_KEYS_FOR_REPORT
        if k in all_scores and not math.isnan(all_scores[k])
    ]
    if scored_keys:
        raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}, keys kept in lockstep with the tensor.
        emotion_softmax_dict = dict(zip(scored_keys, (p.item() for p in softmax_probs)))
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
import math
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
# Audio preprocessing settings.
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
# Model sources: the EmoWhisper backbone and the per-dimension MLP heads repo.
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Whisper encoder output geometry and the MLP head architecture.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys are the sanitized name fragments embedded in the checkpoint filenames
# (model_<part>_best.pth); values are the human-readable dimension labels.
FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions reported via softmax in process_audio_file.
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """MLP regressor over a flattened (seq_len x embed_dim) Whisper embedding.

    The flattened embedding is first projected down to `projection_dim`, then
    passed through ReLU/Dropout-separated hidden layers to a single scalar.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for out_dim, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(rate)]
            in_dim = out_dim
        stack.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Lazily populated by initialize_models(); kept global so repeated calls to
# process_audio_file() reuse the loaded Whisper model and the checkpoint index.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper backbone and index the per-dimension MLP checkpoints.

    Idempotent: Whisper loads only while the global handle is None, and the
    checkpoint download/mapping runs only while the path index is empty.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper once; later calls reuse the cached globals.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Fetch the MLP checkpoints (paths only; weights are loaded on demand).
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model",
        )
        print("MLP checkpoints downloaded.")

        # Index each checkpoint file by its human-readable dimension key.
        for ckpt_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                part = ckpt_path.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {ckpt_path.name}")
                continue
            key = FILENAME_PART_TO_TARGET_KEY_MAP.get(part)
            if key is not None:
                all_mlp_model_paths_dict[key] = ckpt_path
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-length Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, embed_dim) on WHISPER_DEVICE,
        zero-padded or truncated along the sequence axis to WHISPER_SEQ_LEN.

    Raises:
        RuntimeError: if initialize_models() has not been called first.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence axis to WHISPER_SEQ_LEN. Pad shapes are derived
    # from the embedding itself (batch and embed dims were previously
    # hard-coded to 1 and WHISPER_EMBED_DIM, which broke batched inputs).
    batch, current_seq_len, embed_dim = embedding.shape
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros(
            (batch, WHISPER_SEQ_LEN - current_seq_len, embed_dim),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load weights for one target dimension.

    Args:
        model_path: Path to the .pth state-dict checkpoint.
        target_key: Human-readable dimension name (used for logging only).

    Returns:
        The loaded MLP in eval mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading).
    # Assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint is a plain state dict downloaded from
    # the Hub, so avoid full (code-executing) unpickling. Requires torch>=1.13.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left by torch.compile during training.
    # Only a leading prefix is removed; a bare replace() would also mangle any
    # key that happened to contain the substring mid-name.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head over a Whisper embedding and return its scalar score."""
    # Match the MLP's device and dtype before the forward pass (simplified).
    head_dtype = next(mlp_model.parameters()).dtype
    x = embedding.to(MLP_DEVICE).to(head_dtype)
    return mlp_model(x).item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every mapped empathic dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it with Whisper,
    then runs each per-dimension MLP sequentially, loading and unloading each
    head to bound peak memory. Also prints the top-3 primary emotions by
    softmax over their raw scores.

    Args:
        audio_file_path_str: Path to an audio file readable by librosa.

    Returns:
        Mapping of dimension key -> raw MLP score (NaN for heads that failed
        to load); empty dict if the audio could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound peak memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary-emotion raw scores. NaN scores (from
    # failed MLP loads) are excluded -- a single NaN would otherwise turn the
    # entire softmax output into NaN.
    scored_keys = [
        k for k in TARGET_EMOTION_KEYS_FOR_REPORT
        if k in all_scores and not math.isnan(all_scores[k])
    ]
    if scored_keys:
        raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}, keys kept in lockstep with the tensor.
        emotion_softmax_dict = dict(zip(scored_keys, (p.item() for p in softmax_probs)))
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
# Audio preprocessing settings.
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
# Model sources: the EmoWhisper backbone and the per-dimension MLP heads repo.
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Whisper encoder output geometry and the MLP head architecture.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys are the sanitized name fragments embedded in the checkpoint filenames
# (model_<part>_best.pth); values are the human-readable dimension labels.
FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions reported via softmax in process_audio_file.
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """MLP regressor over a flattened (seq_len x embed_dim) Whisper embedding.

    The flattened embedding is first projected down to `projection_dim`, then
    passed through ReLU/Dropout-separated hidden layers to a single scalar.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for out_dim, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(rate)]
            in_dim = out_dim
        stack.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Lazily populated by initialize_models(); kept global so repeated calls to
# process_audio_file() reuse the loaded Whisper model and the checkpoint index.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper backbone and index the per-dimension MLP checkpoints.

    Idempotent: Whisper loads only while the global handle is None, and the
    checkpoint download/mapping runs only while the path index is empty.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper once; later calls reuse the cached globals.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Fetch the MLP checkpoints (paths only; weights are loaded on demand).
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model",
        )
        print("MLP checkpoints downloaded.")

        # Index each checkpoint file by its human-readable dimension key.
        for ckpt_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                part = ckpt_path.name.split("model_")[1].split("_best.pth")[0]
            except IndexError:
                print(f"Warning: Could not parse filename part from {ckpt_path.name}")
                continue
            key = FILENAME_PART_TO_TARGET_KEY_MAP.get(part)
            if key is not None:
                all_mlp_model_paths_dict[key] = ckpt_path
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-length Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, embed_dim) on WHISPER_DEVICE,
        zero-padded or truncated along the sequence axis to WHISPER_SEQ_LEN.

    Raises:
        RuntimeError: if initialize_models() has not been called first.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence axis to WHISPER_SEQ_LEN. Pad shapes are derived
    # from the embedding itself (batch and embed dims were previously
    # hard-coded to 1 and WHISPER_EMBED_DIM, which broke batched inputs).
    batch, current_seq_len, embed_dim = embedding.shape
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros(
            (batch, WHISPER_SEQ_LEN - current_seq_len, embed_dim),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load weights for one target dimension.

    Args:
        model_path: Path to the .pth state-dict checkpoint.
        target_key: Human-readable dimension name (used for logging only).

    Returns:
        The loaded MLP in eval mode on MLP_DEVICE.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading).
    # Assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint is a plain state dict downloaded from
    # the Hub, so avoid full (code-executing) unpickling. Requires torch>=1.13.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left by torch.compile during training.
    # Only a leading prefix is removed; a bare replace() would also mangle any
    # key that happened to contain the substring mid-name.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head over a Whisper embedding and return its scalar score."""
    # Match the MLP's device and dtype before the forward pass (simplified).
    head_dtype = next(mlp_model.parameters()).dtype
    x = embedding.to(MLP_DEVICE).to(head_dtype)
    return mlp_model(x).item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every mapped empathic dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it with Whisper,
    then runs each per-dimension MLP sequentially, loading and unloading each
    head to bound peak memory. Also prints the top-3 primary emotions by
    softmax over their raw scores.

    Args:
        audio_file_path_str: Path to an audio file readable by librosa.

    Returns:
        Mapping of dimension key -> raw MLP score (NaN for heads that failed
        to load); empty dict if the audio could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound peak memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary-emotion raw scores. NaN scores (from
    # failed MLP loads) are excluded -- a single NaN would otherwise turn the
    # entire softmax output into NaN.
    scored_keys = [
        k for k in TARGET_EMOTION_KEYS_FOR_REPORT
        if k in all_scores and not math.isnan(all_scores[k])
    ]
    if scored_keys:
        raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}, keys kept in lockstep with the tensor.
        emotion_softmax_dict = dict(zip(scored_keys, (p.item() for p in softmax_probs)))
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is resampled to this rate before feature extraction
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated in process_audio_file()
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry the MLP heads expect; embeddings are padded/truncated
# to WHISPER_SEQ_LEN frames in get_whisper_embedding().
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Maps the token embedded in each checkpoint filename (model_<token>_best.pth)
# to the human-readable dimension key used throughout this script.
# NOTE(review): this annotation is evaluated at import time, so
# `from typing import Dict, List` must be present at the top of the file.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 emotion keys summarized with softmax probabilities in
# process_audio_file(); a subset of the map values above.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar regression head over a full Whisper encoder embedding.

    The (seq_len, embed_dim) embedding is flattened, projected down to
    ``projection_dim``, then fed through a small ReLU MLP that ends in a
    single output unit. ``mlp_dropout_rates`` needs one rate for the
    projection output plus one per hidden layer (len(mlp_hidden_dims) + 1).
    The module layout (``proj`` + ``mlp`` Sequential indices) matches the
    published checkpoints' state-dict keys.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        # Activation + dropout for the projection, then Linear/ReLU/Dropout
        # per hidden width, finishing with the scalar output layer.
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        widths = [projection_dim] + list(mlp_hidden_dims)
        for fan_in, fan_out, p_drop in zip(widths[:-1], widths[1:], mlp_dropout_rates[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(p_drop)]
        stack.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None        # WhisperForConditionalGeneration; set by initialize_models()
whisper_processor_global = None    # WhisperProcessor; set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: dimension key -> Path of .pth checkpoint
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder and download/map the per-dimension MLP checkpoints.

    Idempotent: Whisper is loaded only once, and checkpoints are downloaded
    and mapped only while ``all_mlp_model_paths_dict`` is empty. MLP weights
    themselves are loaded lazily, one at a time, by load_single_mlp().

    Raises:
        RuntimeError: if no downloaded .pth file could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    if whisper_model_global is None:
        # One-time Whisper load.
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # checkpoint paths already mapped

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model"
    )
    print("MLP checkpoints downloaded.")

    # Map each "model_<part>_best.pth" file to its human-readable dimension key.
    for ckpt_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            name_part = ckpt_path.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {ckpt_path.name}")
            continue
        target_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(name_part)
        if target_key is not None:
            all_mlp_model_paths_dict[target_key] = ckpt_path
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D numpy waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the time axis to
        exactly WHISPER_SEQ_LEN frames (batch is 1 for a single waveform).

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the time axis to the length the MLP heads were trained on.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Fix: size the padding from the actual batch dimension instead of
        # hard-coding 1, so batched feature inputs pad correctly as well.
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one downloaded checkpoint into it.

    Args:
        model_path: path to a model_<dimension>_best.pth state-dict file.
        target_key: human-readable dimension name (used only for logging).

    Returns:
        The loaded model on MLP_DEVICE in eval mode.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # Security: torch.load is pickle-based and these files come from a remote
    # repo; restrict deserialization to plain tensors where the installed
    # torch supports the weights_only keyword.
    try:
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    except TypeError:
        # Older torch without the weights_only keyword; falls back to the
        # original (unrestricted) pickle load.
        state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one scorer MLP on a Whisper embedding and return its scalar score."""
    # Match the MLP's device and parameter dtype before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    batch = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(batch).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotional/perceptual dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with
    Whisper, then runs each per-dimension MLP on the shared embedding,
    loading and unloading MLPs one at a time to bound memory use. Finally
    prints softmax probabilities for the top primary emotions.

    Args:
        audio_file_path_str: path to any audio file librosa can decode.

    Returns:
        Mapping of dimension key -> raw score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # Sample rate is fixed by sr=SAMPLING_RATE, so the returned rate is unused.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report and return empty rather than crashing a batch run.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Hoisted defensive filter: set membership instead of an O(n) scan of the
    # dict values on every iteration.
    known_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in known_keys: # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    _report_top_emotions(all_scores)
    return all_scores


def _report_top_emotions(all_scores: Dict[str, float], top_k: int = 3) -> None:
    """Print softmax probabilities of the top-k primary emotions from raw scores."""
    # Compute the present-key list once; it stays aligned with the raw-score
    # tensor fed to softmax (the original built this list twice).
    present_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if not present_keys:
        return
    raw = torch.tensor([all_scores[k] for k in present_keys], dtype=torch.float32)
    probs = torch.softmax(raw, dim=0)
    print("\nTop 3 Emotions (Softmax Probabilities):")
    ranked = sorted(zip(present_keys, probs.tolist()), key=lambda kv: kv[1], reverse=True)
    for i, (emotion, prob) in enumerate(ranked[:top_k]):
        print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar regression head over a flattened Whisper encoder embedding.

    NOTE(review): verbatim duplicate of the FullEmbeddingMLP defined earlier
    in this file (extraction artifact); keep the copies in sync.
    """
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))  # final scalar output
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load Whisper once and download/map the per-dimension MLP checkpoints.

    NOTE(review): verbatim duplicate of initialize_models() defined earlier in
    this file (extraction artifact).
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a waveform into a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) embedding.

    NOTE(review): verbatim duplicate of get_whisper_embedding() defined
    earlier in this file (extraction artifact). The padding tensor assumes
    batch size 1, which matches the single-waveform usage here.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Pad or truncate the time axis to the length the MLP heads expect.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint into it.

    NOTE(review): verbatim duplicate of load_single_mlp() defined earlier in
    this file (extraction artifact). torch.load here is pickle-based on a
    downloaded file — consider weights_only=True where torch supports it.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one scorer MLP on a Whisper embedding and return its scalar score.

    NOTE(review): verbatim duplicate of predict_with_mlp() defined earlier in
    this file (extraction artifact).
    """
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension and report top emotions.

    NOTE(review): verbatim duplicate of process_audio_file() defined earlier
    in this file (extraction artifact).

    Returns a mapping of dimension key -> raw score; empty dict when the
    audio cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report and return empty rather than raising.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # MLPs are loaded one at a time and freed after use to bound memory.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar regression head over a flattened Whisper encoder embedding.

    NOTE(review): third verbatim copy of this class in the file (extraction
    artifact); keep the copies in sync.
    """
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load Whisper once and download/map the MLP checkpoints.

    NOTE(review): third verbatim copy of this function in the file
    (extraction artifact).
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a fixed-size Whisper embedding.

    Returns the encoder's last hidden state, padded with zeros or truncated
    along the time axis to exactly (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).
    Raises RuntimeError if initialize_models() has not been called.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Cast features to the model's parameter dtype so half-precision models also work.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalise the time dimension: zero-pad short clips, truncate long ones.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP and load one dimension's checkpoint.

    `target_key` is only used for logging. Returns the model on MLP_DEVICE
    in eval mode.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on a Whisper embedding and return its scalar score."""
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    The clip (truncated to MAX_AUDIO_SECONDS) is embedded once with Whisper;
    each per-dimension MLP is then loaded, applied, and freed in turn to
    bound peak memory. Returns {dimension_key: raw_score}, or {} if the
    audio cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw audio as soon as the embedding exists.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List  # Required by the Dict/List annotations below

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Geometry of the Whisper encoder output consumed by the MLP heads.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Primary-emotion subset reported via softmax in process_audio_file().
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Flatten a (seq_len, embed_dim) embedding, project it down, and
    regress a single scalar through a small MLP head.

    `mlp_dropout_rates` must supply one rate for the projection output plus
    one per hidden layer (len(mlp_hidden_dims) + 1 entries total).
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        # Head: ReLU/Dropout after the projection, then Linear/ReLU/Dropout
        # per hidden layer, ending in a single-unit regression output.
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for width, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(rate)]
            in_features = width
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Lazily populated by initialize_models(); kept at module level so the
# heavyweight Whisper model is loaded at most once per process.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """One-time setup: load the Whisper encoder and index MLP checkpoints.

    Fills the module-level globals `whisper_model_global`,
    `whisper_processor_global`, and `all_mlp_model_paths_dict`
    (dimension key -> local .pth path). Already-initialised state is left
    untouched, so repeated calls are cheap.

    Raises RuntimeError if no checkpoint filename could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper encoder (skipped when already loaded).
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # Checkpoints already downloaded and indexed.

    # Fetch the per-dimension MLP checkpoints. Only paths are recorded here;
    # the models themselves are loaded on demand by load_single_mlp().
    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        repo_type="model",
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
    )
    print("MLP checkpoints downloaded.")

    # Filenames look like model_<DimensionName>_best.pth; map each one to
    # its human-readable dimension key.
    for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {pth_file.name}")
            continue
        target_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(filename_part)
        if target_key is not None:
            all_mlp_model_paths_dict[target_key] = pth_file
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Embed a mono 16 kHz waveform with the Whisper encoder.

    The encoder's last hidden state is zero-padded or truncated on the
    time axis so the result is always (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).
    Raises RuntimeError when initialize_models() has not been called.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    # Move to the encoder's device and match its parameter dtype
    # (keeps half-precision models working).
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep only the first WHISPER_SEQ_LEN frames.
        return hidden[:, :WHISPER_SEQ_LEN, :]
    if seq_len < WHISPER_SEQ_LEN:
        # Too short: append zero frames up to the target length.
        tail = torch.zeros((1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
                           device=WHISPER_DEVICE, dtype=hidden.dtype)
        hidden = torch.cat((hidden, tail), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP and load one dimension's checkpoint.

    Parameters:
        model_path: path to the ``.pth`` state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns the model on MLP_DEVICE in eval mode.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True avoids arbitrary-code execution from an untrusted
    # pickle; fall back for torch versions that predate the argument.
    try:
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    except TypeError:
        state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training.
    # Only the leading occurrence is stripped so key bodies are never mangled.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run a single MLP head on a Whisper embedding and return its scalar score.

    The embedding is moved to MLP_DEVICE and cast to the model's parameter
    dtype before the forward pass; gradients are disabled throughout.
    """
    target_dtype = next(mlp_model.parameters()).dtype
    # Match device first, then dtype, so the cast happens on the MLP's device.
    moved = embedding.to(MLP_DEVICE)
    return mlp_model(moved.to(target_dtype)).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    The clip (truncated to MAX_AUDIO_SECONDS) is embedded once with Whisper;
    each per-dimension MLP is then loaded, applied, and freed in turn to
    bound peak memory.

    Returns {dimension_key: raw_score}, or an empty dict if the audio
    cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw audio as soon as the embedding exists.
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: report a softmax over the primary-emotion subset.
    # NaN scores (failed MLPs) are excluded up front — a single NaN would
    # otherwise poison the whole softmax distribution. (v == v is the
    # import-free NaN test.)
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT
                   if k in all_scores and all_scores[k] == all_scores[k]]
    if scored_keys:
        raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}, keys aligned with the tensor above.
        emotion_softmax_dict = {k: p.item() for k, p in zip(scored_keys, softmax_probs)}
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Geometry of the Whisper encoder output consumed by the MLP heads.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Primary-emotion subset reported via softmax in process_audio_file().
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Flatten a (seq_len, embed_dim) embedding, project it down, and
    regress a single scalar through a small MLP head."""
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Lazily populated by initialize_models(); kept at module level so the
# heavyweight Whisper model is loaded at most once per process.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder once and index the downloaded MLP checkpoints.

    Populates the module-level globals `whisper_model_global`,
    `whisper_processor_global`, and `all_mlp_model_paths_dict`
    (dimension key -> local .pth path). Safe to call repeatedly: already
    populated globals are left untouched.

    Raises RuntimeError if no checkpoint filename could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        # Expected filename shape: model_<DimensionName>_best.pth
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a fixed-size Whisper embedding.

    Returns the encoder's last hidden state, padded with zeros or truncated
    along the time axis to exactly (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).
    Raises RuntimeError if initialize_models() has not been called.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Cast features to the model's parameter dtype so half-precision models also work.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalise the time dimension: zero-pad short clips, truncate long ones.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP and load one dimension's checkpoint.

    `target_key` is only used for logging. Returns the model on MLP_DEVICE
    in eval mode.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on a Whisper embedding and return its scalar score."""
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    The clip (truncated to MAX_AUDIO_SECONDS) is embedded once with Whisper;
    each per-dimension MLP is then loaded, applied, and freed in turn to
    bound peak memory. Returns {dimension_key: raw_score}, or {} if the
    audio cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw audio as soon as the embedding exists.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import gc  # For memory management
from pathlib import Path
from typing import Dict, List  # used by annotations below; was missing (NameError at import)

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download  # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; librosa resamples to this rate before feature extraction
MAX_AUDIO_SECONDS = 30.0  # audio longer than this is truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"  # encoder used for embeddings
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry and MLP head hyperparameters (must match training).
WHISPER_SEQ_LEN = 1500  # fixed encoder sequence length the MLP heads expect
WHISPER_EMBED_DIM = 768  # encoder embedding dimension
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys are the "<DimensionName>" part of "model_<DimensionName>_best.pth"
# checkpoint filenames; values are the human-readable dimension names used
# everywhere else (report keys, score dict keys).
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Emotion dimensions included in the softmax summary printed by
# process_audio_file(). Non-emotion attribute dimensions (e.g. Age, Gender,
# Valence, Recording_Quality) are deliberately absent.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar regression head over a flattened Whisper encoder embedding.

    The (seq_len, embed_dim) embedding is flattened, linearly projected down
    to ``projection_dim``, then passed through a small ReLU/Dropout MLP that
    emits a single score per input.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate follows the projection, plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for hidden_dim, drop_rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(drop_rate))
            in_features = hidden_dim
        layers.append(nn.Linear(in_features, 1))  # scalar output
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        # Tolerate an extra singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); module-level so repeated calls to
# process_audio_file() reuse the already-loaded Whisper model.
whisper_model_global = None  # WhisperForConditionalGeneration once initialized
whisper_processor_global = None  # WhisperProcessor once initialized
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and map downloaded MLP checkpoints.

    Populates the module-level globals. Idempotent: Whisper is loaded only if
    not already loaded, and checkpoint paths are mapped only if the path dict
    is still empty.

    Raises:
        RuntimeError: If no downloaded checkpoint could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        # Filenames look like "model_<DimensionName>_best.pth"; the middle part
        # is looked up in FILENAME_PART_TO_TARGET_KEY_MAP. Parts not in the map
        # are silently skipped; unparseable names only print a warning.
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Embed a waveform with the Whisper encoder.

    Returns a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor on WHISPER_DEVICE:
    shorter outputs are zero-padded along the time axis, longer ones truncated.
    Requires initialize_models() to have been called first.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    processed = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    features = processed.input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        # Too long: keep the first WHISPER_SEQ_LEN frames.
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        # Too short: zero-pad the time axis to the fixed length the MLPs expect.
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint from disk.

    Simplified loading for example (Colab Cell 2 has more robust loading).
    Assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: Path to the .pth state-dict file.
        target_key: Human-readable dimension name (used for logging only).

    Returns:
        The loaded model, on MLP_DEVICE and in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True: the checkpoint comes from a downloaded repo, so refuse
    # to unpickle arbitrary objects (plain tensor state dicts load fine).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the '_orig_mod.' prefix left by torch.compile during training.
    # removeprefix (not str.replace) so the substring is removed only when it
    # actually is a key prefix.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score one embedding with one MLP head; return the scalar prediction."""
    # Move the embedding to the MLP device and match the model's parameter
    # dtype before the forward pass (simplified dtype handling).
    target_dtype = next(mlp_model.parameters()).dtype
    batch = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(batch).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with the
    Whisper encoder, then for each dimension loads its MLP, scores the shared
    embedding, and frees the MLP again to keep memory usage flat.

    Args:
        audio_file_path_str: Path to an audio file readable by librosa.

    Returns:
        Mapping of dimension name -> raw score; empty dict on audio load failure.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # Resample to the Whisper input rate, downmix to mono, cap the length.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report the failure and return no scores.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    # One encoder pass; the same embedding is reused by every MLP head.
    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load, score, and immediately free each MLP (CPU-offloading pattern).
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    # NOTE(review): the -inf default can never apply, since the same
    # `k in all_scores` filter already guards the comprehension.
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Duplicate listing below (scrape artifact); configuration should match Cell 2 of the Colab ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Duplicate listing below (scrape artifact, truncated); configuration should match Cell 2 of the Colab ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a waveform into a fixed-length Whisper encoder embedding.

    Args:
        audio_waveform_np: mono waveform (numpy array) sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the time axis.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Zero-pad along time. Batch size, device and dtype are taken from the
        # embedding itself (the original hard-coded a batch size of 1).
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load its checkpoint from disk.

    Simplified loading for example (Colab Cell 2 has more robust loading).
    Assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: path to the .pth state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The MLP in eval mode on MLP_DEVICE.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # NOTE(review): torch.load on untrusted checkpoints can execute arbitrary
    # code; prefer weights_only=True on torch >= 1.13.
    state_dict = torch.load(model_path, map_location='cpu')
    # Strip the '_orig_mod.' prefix left when the model was torch.compile'd
    # during training. Only the leading prefix is removed -- the original
    # str.replace would also have clobbered the token mid-key.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run a single scalar-output MLP on a Whisper embedding.

    The embedding is moved to MLP_DEVICE and cast to the model's parameter
    dtype before the forward pass (simplified dtype handling).
    """
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def _report_top_emotions(all_scores: Dict[str, float]) -> None:
    """Print the top-3 primary emotions by softmax over their raw scores."""
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if not scored_keys:
        return
    raw_scores = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
    softmax_probs = torch.softmax(raw_scores, dim=0)
    print("\nTop 3 Emotions (Softmax Probabilities):")
    ranked = sorted(
        zip(scored_keys, (p.item() for p in softmax_probs)),
        key=lambda item: item[1], reverse=True,
    )
    for i, (emotion, prob) in enumerate(ranked[:3]):
        print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")


def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper encoder embedding, then runs each per-dimension MLP sequentially,
    loading and unloading it to keep memory bounded.

    Args:
        audio_file_path_str: path to any audio file librosa can decode.

    Returns:
        Mapping of target key -> raw MLP score; empty dict if loading failed.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us;
        # the returned sample rate is therefore redundant.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort API: report the failure and return an empty result.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Hoist the membership set: `in dict.values()` was an O(n) scan per key.
    known_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in known_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: report softmax over the primary-emotion subset.
    _report_top_emotions(all_scores)
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc  # For memory management
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download  # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000       # Hz; audio is resampled to this before Whisper
MAX_AUDIO_SECONDS = 30.0    # longer audio is truncated in process_audio_file
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500      # encoder output length; embeddings are padded/truncated to this
WHISPER_EMBED_DIM = 768     # encoder hidden size expected by the MLP heads
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models (one per hidden layer + one after projection)

# Mapping from the sanitized name embedded in each checkpoint filename
# (model_<part>_best.pth) to the human-readable dimension key used in results.

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Primary emotion dimensions reported via softmax in process_audio_file.
# Deliberately excludes non-emotion attributes (Age, Gender, recording
# quality, voice-style axes) present in the map above.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar-output regressor over a flattened (seq_len, embed_dim) embedding.

    The full embedding is flattened, linearly projected down to
    projection_dim, then passed through a small ReLU/Dropout MLP that ends
    in a single output unit.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate per hidden layer plus one right after the projection.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for rate, out_dim in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(rate)]
            in_dim = out_dim
        stack.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None       # set by initialize_models()
whisper_processor_global = None   # set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: target key -> checkpoint Path
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder/processor and index the MLP checkpoint files.

    Idempotent: Whisper is loaded only when not already resident, and the
    checkpoint download/mapping step runs only while the path dict is empty.
    The MLPs themselves are loaded lazily by load_single_mlp().

    Raises:
        RuntimeError: if no downloaded .pth file matched a known target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper encoder/processor: load once and keep resident.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # MLP checkpoints: download once, then record paths only (models are
    # instantiated on demand to keep memory bounded).
    if all_mlp_model_paths_dict:
        return

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        # NOTE(review): local_dir_use_symlinks is deprecated in newer
        # huggingface_hub releases; kept to mirror the Colab behavior.
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # Map each model_<part>_best.pth file to its human-readable target key
    # (simplified from Colab Cell 2).
    for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            name_part = checkpoint.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {checkpoint.name}")
            continue
        mapped_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(name_part)
        if mapped_key is not None:
            all_mlp_model_paths_dict[mapped_key] = checkpoint
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a waveform into a fixed-length Whisper encoder embedding.

    Args:
        audio_waveform_np: mono waveform (numpy array) sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the time axis.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Zero-pad along time. Batch size, device and dtype are taken from the
        # embedding itself (the original hard-coded a batch size of 1).
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load its checkpoint from disk.

    Simplified loading for example (Colab Cell 2 has more robust loading).
    Assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: path to the .pth state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The MLP in eval mode on MLP_DEVICE.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # NOTE(review): torch.load on untrusted checkpoints can execute arbitrary
    # code; prefer weights_only=True on torch >= 1.13.
    state_dict = torch.load(model_path, map_location='cpu')
    # Strip the '_orig_mod.' prefix left when the model was torch.compile'd
    # during training. Only the leading prefix is removed -- the original
    # str.replace would also have clobbered the token mid-key.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run a single scalar-output MLP on a Whisper embedding.

    The embedding is moved to MLP_DEVICE and cast to the model's parameter
    dtype before the forward pass (simplified dtype handling).
    """
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def _report_top_emotions(all_scores: Dict[str, float]) -> None:
    """Print the top-3 primary emotions by softmax over their raw scores."""
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if not scored_keys:
        return
    raw_scores = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
    softmax_probs = torch.softmax(raw_scores, dim=0)
    print("\nTop 3 Emotions (Softmax Probabilities):")
    ranked = sorted(
        zip(scored_keys, (p.item() for p in softmax_probs)),
        key=lambda item: item[1], reverse=True,
    )
    for i, (emotion, prob) in enumerate(ranked[:3]):
        print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")


def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper encoder embedding, then runs each per-dimension MLP sequentially,
    loading and unloading it to keep memory bounded.

    Args:
        audio_file_path_str: path to any audio file librosa can decode.

    Returns:
        Mapping of target key -> raw MLP score; empty dict if loading failed.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us;
        # the returned sample rate is therefore redundant.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort API: report the failure and return an empty result.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Hoist the membership set: `in dict.values()` was an O(n) scan per key.
    known_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in known_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: report softmax over the primary-emotion subset.
    _report_top_emotions(all_scores)
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000       # Hz; audio is resampled to this before Whisper
MAX_AUDIO_SECONDS = 30.0    # longer audio is truncated in process_audio_file
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500      # encoder output length; embeddings are padded/truncated to this
WHISPER_EMBED_DIM = 768     # encoder hidden size expected by the MLP heads
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models (one per hidden layer + one after projection)

# Mapping from the sanitized name embedded in each checkpoint filename
# (model_<part>_best.pth) to the human-readable dimension key used in results.

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Primary emotion dimensions reported via softmax in process_audio_file.
# Deliberately excludes non-emotion attributes (Age, Gender, recording
# quality, voice-style axes) present in the map above.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Scalar-output regressor over a flattened (seq_len, embed_dim) embedding.

    The full embedding is flattened, linearly projected down to
    projection_dim, then passed through a small ReLU/Dropout MLP that ends
    in a single output unit.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate per hidden layer plus one right after the projection.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)

        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for rate, out_dim in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(rate)]
            in_dim = out_dim
        stack.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None       # set by initialize_models()
whisper_processor_global = None   # set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: target key -> checkpoint Path
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder/processor and index the MLP checkpoint files.

    Idempotent: Whisper is loaded only when not already resident, and the
    checkpoint download/mapping step runs only while the path dict is empty.
    The MLPs themselves are loaded lazily by load_single_mlp().

    Raises:
        RuntimeError: if no downloaded .pth file matched a known target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper encoder/processor: load once and keep resident.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # MLP checkpoints: download once, then record paths only (models are
    # instantiated on demand to keep memory bounded).
    if all_mlp_model_paths_dict:
        return

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        # NOTE(review): local_dir_use_symlinks is deprecated in newer
        # huggingface_hub releases; kept to mirror the Colab behavior.
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # Map each model_<part>_best.pth file to its human-readable target key
    # (simplified from Colab Cell 2).
    for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            name_part = checkpoint.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {checkpoint.name}")
            continue
        mapped_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(name_part)
        if mapped_key is not None:
            all_mlp_model_paths_dict[mapped_key] = checkpoint
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a waveform into a fixed-length Whisper encoder embedding.

    Args:
        audio_waveform_np: mono waveform (numpy array) sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the time axis.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Zero-pad along time. Batch size, device and dtype are taken from the
        # embedding itself (the original hard-coded a batch size of 1).
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load its checkpoint from disk.

    Simplified loading for example (Colab Cell 2 has more robust loading).
    Assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: path to the .pth state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The MLP in eval mode on MLP_DEVICE.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # NOTE(review): torch.load on untrusted checkpoints can execute arbitrary
    # code; prefer weights_only=True on torch >= 1.13.
    state_dict = torch.load(model_path, map_location='cpu')
    # Strip the '_orig_mod.' prefix left when the model was torch.compile'd
    # during training. Only the leading prefix is removed -- the original
    # str.replace would also have clobbered the token mid-key.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run a single scalar-output MLP on a Whisper embedding.

    The embedding is moved to MLP_DEVICE and cast to the model's parameter
    dtype before the forward pass (simplified dtype handling).
    """
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def _report_top_emotions(all_scores: Dict[str, float]) -> None:
    """Print the top-3 primary emotions by softmax over their raw scores."""
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if not scored_keys:
        return
    raw_scores = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
    softmax_probs = torch.softmax(raw_scores, dim=0)
    print("\nTop 3 Emotions (Softmax Probabilities):")
    ranked = sorted(
        zip(scored_keys, (p.item() for p in softmax_probs)),
        key=lambda item: item[1], reverse=True,
    )
    for i, (emotion, prob) in enumerate(ranked[:3]):
        print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")


def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper encoder embedding, then runs each per-dimension MLP sequentially,
    loading and unloading it to keep memory bounded.

    Args:
        audio_file_path_str: path to any audio file librosa can decode.

    Returns:
        Mapping of target key -> raw MLP score; empty dict if loading failed.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE and downmixes to mono for us;
        # the returned sample rate is therefore redundant.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort API: report the failure and return an empty result.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Hoist the membership set: `in dict.values()` was an O(n) scan per key.
    known_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in known_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: report softmax over the primary-emotion subset.
    _report_top_emotions(all_scores)
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc  # For memory management
from pathlib import Path
from typing import Dict, List  # Required by the annotations below (Dict[str, str], List[str])

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download  # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; target rate passed to librosa.load and the Whisper processor
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry: embeddings are padded/truncated to this fixed shape,
# and each MLP's projection input size is WHISPER_SEQ_LEN * WHISPER_EMBED_DIM.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys are the underscore-safe tokens embedded in checkpoint filenames
# (model_<key>_best.pth); values are the dimension names used for reporting.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Subset of dimension keys treated as the "primary emotions"; only these are
# fed into the softmax ranking printed by process_audio_file.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full (seq_len, embed_dim) embedding.

    Flattens the embedding, linearly projects it to `projection_dim`, then runs
    a small ReLU/Dropout MLP ending in a single scalar output.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate after the projection plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, out_features in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(rate)]
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Tolerate a singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration; set by initialize_models()
whisper_processor_global = None  # WhisperProcessor; set by initialize_models()
all_mlp_model_paths_dict = {}  # target dimension key -> Path of its .pth checkpoint; populated by initialize_models()
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model and download/map the per-dimension MLP checkpoints.

    Idempotent: Whisper is loaded only if not already loaded, and the checkpoint
    download + path mapping runs only while the path dict is still empty.

    Raises:
        RuntimeError: if no downloaded checkpoint could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper backbone (loaded at most once).
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = (
            WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
            .to(WHISPER_DEVICE)
            .eval()
        )
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # MLP checkpoint paths already mapped.

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # Checkpoints are named model_<DimensionName>_best.pth; map each filename
    # part to its human-readable dimension key (paths only — models themselves
    # are loaded on demand by load_single_mlp).
    for ckpt in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            part = ckpt.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {ckpt.name}")
            continue
        target_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(part)
        if target_key is not None:
            all_mlp_model_paths_dict[target_key] = ckpt
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: mono waveform (numpy array) sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
        WHISPER_DEVICE, zero-padded or truncated along the sequence axis.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Force a fixed sequence length so the MLPs' flattened input size is constant.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Use the actual batch size (the original hard-coded 1, which would
        # break torch.cat for batched inputs) and the embedding's own device.
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint for one dimension.

    Simplified loading for this example (the Colab has more robust handling);
    assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: path to the model_<dimension>_best.pth checkpoint.
        target_key: human-readable dimension name (for logging only).

    Returns:
        The MLP on MLP_DEVICE in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True restricts unpickling to tensors/primitives — safer for
    # checkpoints downloaded from a hub (requires torch >= 1.13).
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Strip the leading '_orig_mod.' prefix left by torch.compile during
    # training, if present (count=1 so only the prefix itself is removed).
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score a Whisper embedding with one MLP head and return the scalar."""
    # Match the model's device and parameter dtype before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    x = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(x).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single Whisper
    embedding, then loads each per-dimension MLP one at a time, scores the
    embedding, and unloads the MLP again to keep peak memory low. Finally
    prints a softmax ranking over the primary-emotion subset.

    Returns:
        {dimension_key: raw_score}; empty dict if the audio could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # Resample to the model rate, force mono, and clip to the max duration.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # The raw waveform is no longer needed once the embedding exists.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load each MLP on demand and free it right after scoring.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# === Code example (Python / transformers) — duplicate copy of the configuration/example above ===
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# === Code example (Python / transformers) — duplicate copy of the configuration/example above ===
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono SAMPLING_RATE waveform into a fixed-length Whisper encoder
    embedding of shape (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).

    Raises RuntimeError if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Log-mel features, moved to the Whisper device and cast to the model dtype.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Only the encoder is used; no decoding/generation happens here.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Zero-pad or truncate the time axis to WHISPER_SEQ_LEN so the flattened
    # embedding always matches the MLP heads' expected input size.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    """Build a FullEmbeddingMLP head and load the checkpoint at model_path onto MLP_DEVICE.

    target_key is only used for logging. Returns the model in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # NOTE(review): torch.load unpickles arbitrary data; only load trusted
    # checkpoints (consider weights_only=True on newer torch).
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on a Whisper embedding; returns the scalar raw score."""
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (resampled to SAMPLING_RATE, truncated to MAX_AUDIO_SECONDS),
    encodes it once with Whisper, then loads each MLP head on demand, scores the
    embedding, and frees the head again. Returns {dimension key: raw score};
    returns {} if the audio file cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa handles resampling and mono downmix; `sr` is unused afterwards.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw waveform immediately; only the embedding is needed below.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        # NOTE(review): keys of all_mlp_model_paths_dict already come from the
        # map's values, so this filter should never actually skip anything.
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Heads are loaded one at a time and discarded to keep peak memory low.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    # NOTE(review): a NaN score here would propagate through softmax and make
    # every probability NaN; consider filtering NaNs before this step.
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc  # For memory management
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download  # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is resampled to this before Whisper
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"  # emotion-tuned Whisper checkpoint
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # encoder output frames the MLP heads expect
WHISPER_EMBED_DIM = 768  # encoder hidden size the MLP heads expect
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys: sanitized fragments as they appear in checkpoint filenames
# ("model_<fragment>_best.pth"); values: human-readable dimension names used
# as scoring keys. NOTE(review): `Dict` requires `from typing import Dict`.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The primary emotion dimensions summarized via softmax in process_audio_file().
# NOTE(review): `List` requires `from typing import List`.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head for Whisper embeddings.

    Flattens a (seq_len, embed_dim) embedding, projects it down to
    projection_dim, then maps it through a small ReLU/Dropout MLP to a
    single scalar score. Module layout (attribute names and Sequential
    indices) matches the released checkpoints' state-dict keys.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for out_features, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(rate)]
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        flat = self.flatten(x)
        return self.mlp(self.proj(flat))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); kept at module level so repeated
# calls to process_audio_file() reuse the same Whisper instance.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # target key -> Path of the per-dimension .pth checkpoint
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Lazily load the Whisper encoder/processor and map MLP checkpoint paths.

    Idempotent: Whisper is loaded only once, and the checkpoint download plus
    filename-to-dimension mapping runs only while all_mlp_model_paths_dict is
    empty. Raises RuntimeError when no checkpoint could be mapped.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Whisper encoder + feature extractor (loaded once per process).
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # checkpoints already downloaded and mapped

    # Fetch only the .pth weight files; actual modules are built on demand later.
    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model"
    )
    print("MLP checkpoints downloaded.")

    # "model_<fragment>_best.pth" -> human-readable dimension key.
    for ckpt_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            fragment = ckpt_path.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {ckpt_path.name}")
            continue
        dimension_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(fragment)
        if dimension_key is not None:
            all_mlp_model_paths_dict[dimension_key] = ckpt_path
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono SAMPLING_RATE waveform into a fixed-length Whisper embedding.

    Args:
        audio_waveform_np: 1-D float numpy waveform at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, hidden_dim), zero-padded or
        truncated along the time axis so the flattened size matches what the
        MLP heads expect.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Log-mel features, moved to the Whisper device and cast to the model dtype.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Only the encoder is used; no decoding/generation happens here.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Derive padding geometry from the embedding itself (batch size, hidden
    # dim, device, dtype) instead of module-level constants, so this cannot
    # go out of sync if the model is moved to another device or resized, and
    # it works for batched inputs as well.
    batch, current_seq_len, hidden_dim = embedding.shape
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((batch, WHISPER_SEQ_LEN - current_seq_len, hidden_dim),
                              device=embedding.device, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP head and load the checkpoint at model_path.

    Args:
        model_path: path to a .pth state-dict checkpoint.
        target_key: dimension name, used only for logging.

    Returns:
        The loaded model on MLP_DEVICE in eval mode.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading).
    # Assumes USE_HALF_PRECISION_FOR_MLPS=False and USE_TORCH_COMPILE_FOR_MLPS=False.
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True restricts torch.load's unpickler to plain tensor data;
    # these checkpoints are bare state dicts, so nothing else is needed and
    # this closes the arbitrary-code-execution hole of full pickle loading.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # torch.compile prefixes checkpoint keys with '_orig_mod.'; strip only the
    # leading prefix (replace() would also mangle an occurrence mid-key).
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression head on a Whisper embedding; returns the scalar score."""
    # Match the head's device and parameter dtype before the forward pass.
    target_dtype = next(mlp_model.parameters()).dtype
    score_tensor = mlp_model(embedding.to(MLP_DEVICE).to(target_dtype))
    return score_tensor.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    The audio is resampled to SAMPLING_RATE, truncated to MAX_AUDIO_SECONDS,
    encoded once with Whisper, and then scored by each MLP head in turn
    (heads are loaded on demand and freed after use to keep peak memory low).

    Args:
        audio_file_path_str: path to any audio format librosa can decode.

    Returns:
        {dimension key: raw score}; empty dict if the audio cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples and downmixes to mono; the returned rate is unused.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # The raw waveform is no longer needed once the embedding exists.
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Keys of all_mlp_model_paths_dict are produced from the values of
    # FILENAME_PART_TO_TARGET_KEY_MAP, so no extra membership filter is needed.
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary emotions for a quick report.
    # NaN scores (failed heads) are excluded up front: a single NaN would
    # otherwise propagate through softmax and poison every probability.
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT
                   if k in all_scores and all_scores[k] == all_scores[k]]  # x == x is False only for NaN
    if scored_keys:
        raw_scores = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw_scores, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = {k: p.item() for k, p in zip(scored_keys, softmax_probs)}
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is resampled to this before Whisper
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"  # emotion-tuned Whisper checkpoint
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # encoder output frames the MLP heads expect
WHISPER_EMBED_DIM = 768  # encoder hidden size the MLP heads expect
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys: sanitized fragments as they appear in checkpoint filenames
# ("model_<fragment>_best.pth"); values: human-readable dimension names used
# as scoring keys. NOTE(review): `Dict` requires `from typing import Dict`.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The primary emotion dimensions summarized via softmax in process_audio_file().
# NOTE(review): `List` requires `from typing import List`.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head that flattens a (seq_len, embed_dim) embedding, projects it
    to projection_dim, and maps it through a small MLP to a single scalar score."""
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        # Drop an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); kept at module level so repeated
# calls to process_audio_file() reuse the same Whisper instance.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # target key -> Path of the per-dimension .pth checkpoint
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Lazily load the Whisper encoder and download/map the per-dimension MLP checkpoints.

    Idempotent: Whisper is only loaded if not already resident, and the
    checkpoint download/mapping only runs while all_mlp_model_paths_dict is
    empty. Raises RuntimeError if no checkpoint filename could be mapped.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,  # NOTE(review): deprecated/ignored in newer huggingface_hub
            allow_patterns=["*.pth"],  # only the MLP weight files, not the full repo
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                # "model_<fragment>_best.pth" -> "<fragment>"
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono SAMPLING_RATE waveform into a fixed-length Whisper encoder
    embedding of shape (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).

    Raises RuntimeError if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    # Log-mel features, moved to the Whisper device and cast to the model dtype.
    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Only the encoder is used; no decoding/generation happens here.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Zero-pad or truncate the time axis to WHISPER_SEQ_LEN so the flattened
    # embedding always matches the MLP heads' expected input size.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    """Build a FullEmbeddingMLP head and load the checkpoint at model_path onto MLP_DEVICE.

    target_key is only used for logging. Returns the model in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # NOTE(review): torch.load unpickles arbitrary data; only load trusted
    # checkpoints (consider weights_only=True on newer torch).
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on a Whisper embedding; returns the scalar raw score."""
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (resampled to SAMPLING_RATE, truncated to MAX_AUDIO_SECONDS),
    encodes it once with Whisper, then loads each MLP head on demand, scores the
    embedding, and frees the head again. Returns {dimension key: raw score};
    returns {} if the audio file cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa handles resampling and mono downmix; `sr` is unused afterwards.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw waveform immediately; only the embedding is needed below.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        # NOTE(review): keys of all_mlp_model_paths_dict already come from the
        # map's values, so this filter should never actually skip anything.
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Heads are loaded one at a time and discarded to keep peak memory low.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    # NOTE(review): a NaN score here would propagate through softmax and make
    # every probability NaN; consider filtering NaNs before this step.
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; librosa resamples all input audio to this rate
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated in process_audio_file
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"  # encoder used to produce embeddings
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Whisper encoder output geometry; get_whisper_embedding pads/truncates the
# time axis to exactly WHISPER_SEQ_LEN frames.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models; must be len(MLP_HIDDEN_DIMS) + 1

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Maps the "<part>" token of a checkpoint filename ("model_<part>_best.pth")
# to its human-readable dimension key.
# Fixed: annotated with the builtin generic `dict` (PEP 585) — the original
# used `typing.Dict` without importing it, raising NameError at import time.
FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions over which process_audio_file reports a
# softmax. Fixed: builtin generic `list` (PEP 585) — the original used
# `typing.List` without importing it, raising NameError at import time.
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it down to
    ``projection_dim``, then scores it with a small ReLU MLP ending in a
    single output unit.

    The submodule names (``flatten``, ``proj``, ``mlp``) and the ordering of
    layers inside ``mlp`` must stay exactly as-is so the published
    checkpoints load via ``load_state_dict``.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for rate, out_features in zip(mlp_dropout_rates[1:], mlp_hidden_dims):
            stack += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(rate)]
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))  # final scalar score
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration; set by initialize_models()
whisper_processor_global = None  # WhisperProcessor; set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: dimension key -> checkpoint Path
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder and map downloaded MLP checkpoint paths.

    Idempotent: Whisper is loaded only when not yet loaded, and checkpoints
    are downloaded/mapped only while the path dict is empty. Populates the
    module globals; the MLP weights themselves are loaded lazily later by
    load_single_mlp().

    Raises:
        RuntimeError: if no downloaded checkpoint filename could be mapped.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        # NOTE(review): local_dir_use_symlinks is deprecated in recent
        # huggingface_hub releases — confirm the pinned version accepts it.
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Filenames look like "model_<part>_best.pth"; <part> is looked up in
        # FILENAME_PART_TO_TARGET_KEY_MAP and unknown parts are skipped silently.
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a Whisper encoder embedding.

    Returns a tensor of shape (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
    WHISPER_DEVICE, zero-padded or truncated along the time axis.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    frames = hidden.shape[1]
    if frames > WHISPER_SEQ_LEN:
        return hidden[:, :WHISPER_SEQ_LEN, :]
    if frames < WHISPER_SEQ_LEN:
        # Shorter than expected: pad the time axis with zeros (same device/dtype).
        tail = hidden.new_zeros((1, WHISPER_SEQ_LEN - frames, WHISPER_EMBED_DIM))
        hidden = torch.cat((hidden, tail), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at ``model_path``.

    Simplified loading for the example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: Path to a ``.pth`` state-dict checkpoint.
        target_key: Human-readable dimension name (used only for logging).

    Returns:
        The loaded MLP, on MLP_DEVICE and in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # Security fix: the checkpoint is downloaded from the Hub, so restrict
    # torch.load to tensors (weights_only=True, torch >= 1.13) instead of
    # unpickling arbitrary objects. State dicts load fine under this mode.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during
    # training; strip only the leading occurrence of the prefix.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on a Whisper embedding and return the scalar score."""
    # Move to the MLP's device, then cast to its parameter dtype before the
    # forward pass (simplified dtype handling).
    target_dtype = next(mlp_model.parameters()).dtype
    moved = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(moved).item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper encoder embedding, then runs each MLP head on it one at a time,
    loading and unloading heads to bound memory. Finally prints a softmax
    over the primary-emotion scores.

    Args:
        audio_file_path_str: Path to any audio file librosa can decode.

    Returns:
        Mapping of dimension key -> raw score; empty dict when the audio
        cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE; the returned rate is unused.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Hoisted out of the loop: the original tested membership against
    # dict.values() on every iteration (O(n) per test).
    mapped_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in mapped_keys: # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # Defensive: load_single_mlp currently never returns a falsy model.
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: softmax over the primary emotions that actually got scores.
    # (The original's `.get(k, -inf)` default was dead code: keys were already
    # filtered by `if k in all_scores`.)
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_keys:
        raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = {
            key: prob.item() for key, prob in zip(scored_keys, softmax_probs)
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---  (Python example, uses transformers)
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; librosa resamples all input audio to this rate
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated in process_audio_file
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"  # encoder used to produce embeddings
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Whisper encoder output geometry; get_whisper_embedding pads/truncates the
# time axis to exactly WHISPER_SEQ_LEN frames.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models; must be len(MLP_HIDDEN_DIMS) + 1

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Maps the "<part>" token of a checkpoint filename ("model_<part>_best.pth")
# to its human-readable dimension key. Annotated with the builtin generic
# `dict` (PEP 585) because `typing.Dict` is not imported in this file.
FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions over which process_audio_file reports a
# softmax. Annotated with the builtin generic `list` (PEP 585) because
# `typing.List` is not imported in this file.
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head: flattens a (seq_len, embed_dim) embedding, projects
    it to projection_dim, then scores it with a small ReLU MLP ending in a
    single output unit. Submodule names (flatten/proj/mlp) and the layer
    ordering inside ``mlp`` match the published checkpoints' state-dict keys.
    """
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))  # final scalar score
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration; set by initialize_models()
whisper_processor_global = None  # WhisperProcessor; set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: dimension key -> checkpoint Path
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder and map downloaded MLP checkpoint paths.

    Idempotent: Whisper is loaded only when not yet loaded, and checkpoints
    are downloaded/mapped only while the path dict is empty. Populates the
    module globals; the MLP weights themselves are loaded lazily later by
    load_single_mlp().

    Raises:
        RuntimeError: if no downloaded checkpoint filename could be mapped.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        # NOTE(review): local_dir_use_symlinks is deprecated in recent
        # huggingface_hub releases — confirm the pinned version accepts it.
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Filenames look like "model_<part>_best.pth"; unknown parts are
        # skipped silently.
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform into a Whisper encoder embedding.

    Returns a tensor of shape (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on
    WHISPER_DEVICE, zero-padded or truncated along the time axis.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Encoder only: decoder/generation is never run for embedding extraction.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the time axis to exactly WHISPER_SEQ_LEN frames.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load the checkpoint at ``model_path``.

    Simplified loading for the example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: Path to a ``.pth`` state-dict checkpoint.
        target_key: Human-readable dimension name (used only for logging).

    Returns:
        The loaded MLP, on MLP_DEVICE and in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # Security fix: the checkpoint is downloaded from the Hub, so restrict
    # torch.load to tensors (weights_only=True, torch >= 1.13) instead of
    # unpickling arbitrary objects. State dicts load fine under this mode.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during
    # training; strip only the leading occurrence of the prefix.
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one MLP head on a Whisper embedding and return the scalar score."""
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified): cast input to the MLP's parameter dtype.
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes one Whisper
    encoder embedding, then runs each MLP head on it one at a time, loading
    and unloading heads to bound memory. Prints a softmax over the primary
    emotions at the end.

    Returns:
        Mapping of dimension key -> raw score; empty dict if audio loading fails.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE; the returned sr is unused.
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw waveform before the per-head loop to keep peak memory down.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Heads are loaded/unloaded one at a time (CPU offloading).
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    # NOTE(review): the -inf default is dead code — keys are already filtered
    # by `if k in all_scores`.
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---  (Python example, uses transformers)
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; librosa resamples all input audio to this rate
MAX_AUDIO_SECONDS = 30.0  # clips longer than this are truncated in process_audio_file
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"  # encoder used to produce embeddings
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Whisper encoder output geometry; get_whisper_embedding pads/truncates the
# time axis to exactly WHISPER_SEQ_LEN frames.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models; must be len(MLP_HIDDEN_DIMS) + 1

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Maps the "<part>" token of a checkpoint filename ("model_<part>_best.pth")
# to its human-readable dimension key. Annotated with the builtin generic
# `dict` (PEP 585) because `typing.Dict` is not imported in this file.
FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions over which process_audio_file reports a
# softmax. Annotated with the builtin generic `list` (PEP 585) because
# `typing.List` is not imported in this file.
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head: flattens a (seq_len, embed_dim) embedding, projects
    it to projection_dim, then scores it with a small ReLU MLP ending in a
    single output unit. Submodule names (flatten/proj/mlp) and the layer
    ordering inside ``mlp`` match the published checkpoints' state-dict keys.
    """
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))  # final scalar score
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration; set by initialize_models()
whisper_processor_global = None  # WhisperProcessor; set by initialize_models()
all_mlp_model_paths_dict = {} # To be populated: dimension key -> checkpoint Path
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder and map downloaded MLP checkpoint paths.

    Idempotent: Whisper is loaded only when not yet loaded, and checkpoints
    are downloaded/mapped only while the path dict is empty. Populates the
    module globals; the MLP weights themselves are loaded lazily later by
    load_single_mlp().

    Raises:
        RuntimeError: if no downloaded checkpoint filename could be mapped.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        # NOTE(review): local_dir_use_symlinks is deprecated in recent
        # huggingface_hub releases — confirm the pinned version accepts it.
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2).
        # Filenames look like "model_<part>_best.pth"; unknown parts are
        # skipped silently.
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform with Whisper and return a fixed-size embedding.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM), zero-padded
        or truncated along the sequence axis so every clip is the same length.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence length to WHISPER_SEQ_LEN (pad with zeros / truncate).
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Derive batch size, device and dtype from the embedding itself
        # (the original hard-coded batch 1 and WHISPER_DEVICE).
        padding = torch.zeros((embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=embedding.device, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load its checkpoint from `model_path`.

    Simplified loading for this example (the Colab Cell 2 is more robust);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and USE_TORCH_COMPILE_FOR_MLPS=False.
    Returns the model on MLP_DEVICE in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    # Strip the '_orig_mod.' prefix added when the model was torch.compile'd
    # during training. Strip it only as a *leading* prefix — the original
    # replace() would also rewrite the substring if it appeared mid-key.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score one embedding with one MLP head and return the scalar prediction."""
    # Move the embedding to the MLP's device, then match the MLP's parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with
    Whisper, then runs every per-dimension MLP sequentially — loading and
    unloading each model to keep peak memory low.

    Returns:
        {target_key: raw_score}; empty dict if the audio could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # `sr` return value is unused (librosa already resampled to SAMPLING_RATE).
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound peak memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: report softmax over the primary-emotion raw scores.
    _report_top_emotions(all_scores)
    return all_scores


def _report_top_emotions(all_scores: Dict[str, float], top_k: int = 3) -> None:
    """Print the top-k primary emotions by softmax probability over raw scores."""
    # Build the filtered key list once so keys and scores stay aligned
    # (the original computed this filter twice).
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if not scored_keys:
        return
    raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
    softmax_probs = torch.softmax(raw, dim=0)
    print("\nTop 3 Emotions (Softmax Probabilities):")
    ranked = sorted(zip(scored_keys, softmax_probs.tolist()), key=lambda kv: kv[1], reverse=True)
    for i, (emotion, prob) in enumerate(ranked[:top_k]):
        print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Sample rate (Hz) passed to librosa.load and the Whisper processor
MAX_AUDIO_SECONDS = 30.0  # Audio is truncated to this duration before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # Target encoder sequence length (embeddings are padded/truncated to this)
WHISPER_EMBED_DIM = 768  # Encoder embedding dimension used to size the MLP input
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models; one rate per layer incl. projection

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Keys are the <part> of "model_<part>_best.pth" checkpoint names; values are the
# human-readable dimension labels used everywhere else in this script.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The primary-emotion subset of dimensions over which a softmax is reported
# (excludes non-emotion attributes such as Age, Gender, Recording_Quality).
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it down to
    `projection_dim`, then runs a small MLP that outputs one scalar score.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for out_dim, drop in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(drop)]
            in_dim = out_dim
        stack.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel axis: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None  # WhisperForConditionalGeneration; set by initialize_models()
whisper_processor_global = None  # WhisperProcessor; set by initialize_models()
all_mlp_model_paths_dict = {} # {target_key: Path to .pth checkpoint}; populated by initialize_models()
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder and download/map the per-dimension MLP checkpoints.

    Idempotent: Whisper is only loaded when not already loaded, and the MLP
    path mapping is only built when `all_mlp_model_paths_dict` is empty. The
    MLP weights themselves are loaded lazily, one at a time, by
    load_single_mlp().

    Raises:
        RuntimeError: if no downloaded checkpoint could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper once.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only; weights are loaded on demand).
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,  # NOTE(review): deprecated in newer huggingface_hub; confirm against pinned version
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map each model_<part>_best.pth file to its human-readable target key.
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            # The glob guarantees both the prefix and the suffix are present,
            # so plain slicing is safe (the old split("model_") approach was
            # fragile if "model_" ever appeared inside the dimension name).
            filename_part = pth_file.name[len("model_"):-len("_best.pth")]
            if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                all_mlp_model_paths_dict[target_key] = pth_file
            else:
                # Previously such files were skipped silently; surface them.
                print(f"Warning: no target key mapping for checkpoint {pth_file.name} (part '{filename_part}')")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
            raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform with Whisper and return a fixed-size embedding.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM), zero-padded
        or truncated along the sequence axis so every clip is the same length.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence length to WHISPER_SEQ_LEN (pad with zeros / truncate).
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Derive batch size, device and dtype from the embedding itself
        # (the original hard-coded batch 1 and WHISPER_DEVICE).
        padding = torch.zeros((embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=embedding.device, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load its checkpoint from `model_path`.

    Simplified loading for this example (the Colab Cell 2 is more robust);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and USE_TORCH_COMPILE_FOR_MLPS=False.
    Returns the model on MLP_DEVICE in eval mode.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    # Strip the '_orig_mod.' prefix added when the model was torch.compile'd
    # during training. Strip it only as a *leading* prefix — the original
    # replace() would also rewrite the substring if it appeared mid-key.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score one embedding with one MLP head and return the scalar prediction."""
    # Move the embedding to the MLP's device, then match the MLP's parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), embeds it once with
    Whisper, then runs every per-dimension MLP sequentially — loading and
    unloading each model to keep peak memory low.

    Returns:
        {target_key: raw_score}; empty dict if the audio could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # `sr` return value is unused (librosa already resampled to SAMPLING_RATE).
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to bound peak memory
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: report softmax over the primary-emotion raw scores.
    _report_top_emotions(all_scores)
    return all_scores


def _report_top_emotions(all_scores: Dict[str, float], top_k: int = 3) -> None:
    """Print the top-k primary emotions by softmax probability over raw scores."""
    # Build the filtered key list once so keys and scores stay aligned
    # (the original computed this filter twice).
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if not scored_keys:
        return
    raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
    softmax_probs = torch.softmax(raw, dim=0)
    print("\nTop 3 Emotions (Softmax Probabilities):")
    ranked = sorted(zip(scored_keys, softmax_probs.tolist()), key=lambda kv: kv[1], reverse=True)
    for i, (emotion, prob) in enumerate(ranked[:top_k]):
        print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Sample rate (Hz) passed to librosa.load and the Whisper processor
MAX_AUDIO_SECONDS = 30.0  # Audio is truncated to this duration before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500  # Target encoder sequence length (embeddings are padded/truncated to this)
WHISPER_EMBED_DIM = 768  # Encoder embedding dimension used to size the MLP input
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models; one rate per layer incl. projection

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
--- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List  # Required by the Dict/List annotations below

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
# Target sample rate (Hz) for all audio fed to the Whisper processor.
SAMPLING_RATE = 16000
# Audio longer than this is truncated before embedding (see process_audio_file).
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry the MLP heads expect: embeddings are padded or
# truncated to (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) in get_whisper_embedding().
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Checkpoint-name fragment -> human-readable dimension key.
# Annotated with the builtin generic `dict` (PEP 585): the original annotation
# used `Dict`, which is a NameError here because `typing` is never imported.
FILENAME_PART_TO_TARGET_KEY_MAP: dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions included in the softmax report.
# Annotated with the builtin generic `list` (PEP 585): the original annotation
# used `List`, which is a NameError here because `typing` is never imported.
TARGET_EMOTION_KEYS_FOR_REPORT: list[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regress one scalar score from a full Whisper encoder embedding.

    The (1, seq_len, embed_dim) embedding is flattened, projected down to
    `projection_dim`, then passed through a small ReLU MLP ending in a single
    output unit. `mlp_dropout_rates` must hold one more entry than
    `mlp_hidden_dims`: the first rate follows the projection, the remaining
    rates follow each hidden layer in order.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for hidden_size, drop_rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_features, hidden_size), nn.ReLU(), nn.Dropout(drop_rate)]
            in_features = hidden_size
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Tolerate a leading singleton channel dim: (1, 1, S, E) -> (1, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); None / empty until it runs.
whisper_model_global = None
whisper_processor_global = None
# Human-readable dimension key -> Path of its MLP checkpoint (.pth).
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and map MLP checkpoint paths.

    Idempotent: Whisper is only loaded while the global is still None, and the
    checkpoint download + filename mapping only runs while the path dict is
    empty. MLP weights themselves are loaded on demand by load_single_mlp().

    Raises:
        RuntimeError: if no downloaded checkpoint name could be mapped to a key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # Checkpoint paths already mapped; nothing more to do.

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # Checkpoints are named model_<part>_best.pth; <part> indexes the key map.
    for checkpoint_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            name_part = checkpoint_path.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {checkpoint_path.name}")
            continue
        if name_part in FILENAME_PART_TO_TARGET_KEY_MAP:
            all_mlp_model_paths_dict[FILENAME_PART_TO_TARGET_KEY_MAP[name_part]] = checkpoint_path
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform (numpy array at SAMPLING_RATE Hz) with Whisper.

    Returns the encoder's last hidden state, zero-padded or truncated along
    the time axis to exactly (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM).

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features
    features = features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        return hidden[:, :WHISPER_SEQ_LEN, :]
    if seq_len < WHISPER_SEQ_LEN:
        pad = torch.zeros(
            (1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
            device=WHISPER_DEVICE, dtype=hidden.dtype,
        )
        hidden = torch.cat((hidden, pad), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP and load the checkpoint at `model_path`.

    Simplified loading for this example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False.

    Args:
        model_path: Path to the .pth state-dict checkpoint.
        target_key: Human-readable dimension name (used for logging only).

    Returns:
        The loaded MLP in eval mode on MLP_DEVICE.
    """
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # weights_only=True (torch >= 1.13) restricts unpickling to tensors and
    # primitives: these .pth files are downloaded artifacts, and a full pickle
    # load would execute arbitrary code embedded in a malicious checkpoint.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during
    # training; strip only the leading occurrence.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one scalar-output MLP head on a Whisper embedding.

    Moves the embedding to MLP_DEVICE and casts it to the model's parameter
    dtype before the forward pass; returns the scalar prediction as a float.
    """
    param_dtype = next(mlp_model.parameters()).dtype
    model_input = embedding.to(MLP_DEVICE).to(param_dtype)
    return mlp_model(model_input).item()

def process_audio_file(audio_file_path_str: str) -> dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single Whisper
    embedding, then loads each MLP head one at a time (CPU-offloading pattern:
    load, predict, free), and finally prints a softmax-ranked top-3 over
    TARGET_EMOTION_KEYS_FOR_REPORT as a side effect.

    Returns:
        Mapping of dimension key -> raw MLP score; empty dict if the audio
        file could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        # Best-effort: report the failure and return an empty result instead
        # of raising, so a batch of files can keep going.
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw waveform (and GPU cache) before the per-head loop.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load one head at a time to keep peak memory low, then free it.
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    # NOTE(review): the -inf default is never used because of the `k in
    # all_scores` filter; a NaN score would propagate NaN through the softmax.
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}; the key list is
        # rebuilt with the same filter so it stays aligned with softmax_probs.
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform with the Whisper encoder.

    Args:
        audio_waveform_np: 1-D float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) on WHISPER_DEVICE,
        zero-padded or truncated along the sequence axis to exactly WHISPER_SEQ_LEN.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Only the encoder half of the seq2seq model is run; the decoder is unused.
    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    # Normalize the sequence length so every clip yields the same MLP input size.
    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one checkpoint onto MLP_DEVICE.

    Args:
        model_path: path to a "model_*_best.pth" state-dict checkpoint.
        target_key: human-readable dimension name (used for logging only).

    Returns:
        The loaded model in eval() mode.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # SECURITY NOTE: torch.load unpickles arbitrary objects; only load
    # checkpoints from a trusted repo (consider weights_only=True where available).
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Score a padded Whisper embedding with one MLP head; returns a Python float."""
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> "Dict[str, float]":
    """Run the full pipeline on one audio file and return {dimension: raw score}.

    Loads (and truncates to MAX_AUDIO_SECONDS) the audio, embeds it once with
    Whisper, then loads each MLP head one at a time (CPU-offloading friendly)
    and scores every mapped dimension. Also prints the top-3 emotions by
    softmax over the raw scores of TARGET_EMOTION_KEYS_FOR_REPORT.

    Returns an empty dict if the audio file cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw audio as early as possible; only the embedding is needed now.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load each head on demand and drop it right after scoring to keep
        # peak memory low (the projection layer alone is ~74M parameters).
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # NOTE(review): load_single_mlp always returns a model or raises, so
            # this NaN fallback looks unreachable as written -- confirm.
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions.
    # The -inf default is never hit: keys are pre-filtered with `if k in all_scores`.
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import gc # For memory management
from pathlib import Path
from typing import Dict, List  # Used by annotations below (Dict/List)

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; passed to librosa.load and the Whisper processor
MAX_AUDIO_SECONDS = 30.0  # audio is truncated to this length before embedding
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Encoder output geometry: embeddings are padded/truncated to WHISPER_SEQ_LEN frames.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models; one per ReLU stage (len == len(hidden)+1)

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

# Maps checkpoint filename stems ("model_<stem>_best.pth") to the human-readable
# dimension keys used throughout the pipeline and in the returned score dict.
# NOTE(review): the `Dict` annotation requires `from typing import Dict` at the
# top of the file -- confirm it is imported, otherwise this line raises NameError.
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions whose raw scores are softmaxed for the
# "Top 3 Emotions" report in process_audio_file().
# NOTE(review): `List` requires `from typing import List` -- confirm it is imported.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Flatten a (1, seq_len, embed_dim) embedding, project it down, and regress one scalar.

    Architecture: Flatten -> Linear(seq_len*embed_dim, projection_dim) ->
    [ReLU, Dropout] -> per hidden dim [Linear, ReLU, Dropout] -> Linear(_, 1).
    Attribute names (`proj`, `mlp`) and layer ordering are fixed so checkpoint
    state-dict keys resolve.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection stage plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for hidden_dim, drop_rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stack += [nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(drop_rate)]
            in_dim = hidden_dim
        stack.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Tolerate an extra leading channel axis of size 1, e.g. (1, 1, seq, dim).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Lazily populated by initialize_models(); module scope lets repeated
# process_audio_file() calls reuse the loaded Whisper model and path map.
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated: {human-readable key -> Path to .pth}
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper encoder/processor and download + index the MLP checkpoints.

    Idempotent: Whisper is loaded only while the global handle is still None,
    and the checkpoint download/mapping runs only while all_mlp_model_paths_dict
    is empty. Only checkpoint *paths* are stored; weights are loaded lazily by
    load_single_mlp().

    Raises:
        RuntimeError: if no downloaded .pth file could be mapped to a target key.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        seq2seq = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = seq2seq.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # Checkpoints already downloaded and mapped.

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # Derive "<stem>" from "model_<stem>_best.pth" and resolve it through the
    # filename -> human-readable-key table.
    for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            stem = checkpoint.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {checkpoint.name}")
            continue
        if stem in FILENAME_PART_TO_TARGET_KEY_MAP:
            all_mlp_model_paths_dict[FILENAME_PART_TO_TARGET_KEY_MAP[stem]] = checkpoint
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Embed a mono waveform with the Whisper encoder.

    Returns a (1, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM) tensor on WHISPER_DEVICE;
    the encoder output is zero-padded or truncated to exactly WHISPER_SEQ_LEN frames.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    processed = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    features = processed.input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    # Only the encoder half of the seq2seq model is needed for embeddings.
    hidden = whisper_model_global.get_encoder()(input_features=features).last_hidden_state

    # Force a fixed sequence length so every clip produces the same MLP input size.
    seq_len = hidden.shape[1]
    if seq_len > WHISPER_SEQ_LEN:
        hidden = hidden[:, :WHISPER_SEQ_LEN, :]
    elif seq_len < WHISPER_SEQ_LEN:
        filler = torch.zeros((1, WHISPER_SEQ_LEN - seq_len, WHISPER_EMBED_DIM),
                             device=WHISPER_DEVICE, dtype=hidden.dtype)
        hidden = torch.cat((hidden, filler), dim=1)
    return hidden

def load_single_mlp(model_path, target_key):
    """Build a FullEmbeddingMLP, load one checkpoint into it, and return it in eval mode.

    Simplified loading for example (Colab Cell 2 has more robust loading);
    assumes USE_HALF_PRECISION_FOR_MLPS=False and USE_TORCH_COMPILE_FOR_MLPS=False.
    """
    print(f"  Loading MLP for '{target_key}'...")
    mlp = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # SECURITY NOTE: torch.load unpickles data -- checkpoints must come from a trusted repo.
    weights = torch.load(model_path, map_location='cpu')
    # Checkpoints saved from a torch.compile'd model carry an "_orig_mod." prefix.
    if any(name.startswith("_orig_mod.") for name in weights):
        weights = {name.replace("_orig_mod.", ""): tensor for name, tensor in weights.items()}
    mlp.load_state_dict(weights)
    return mlp.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one scalar-output MLP head on the embedding and return the score as a float."""
    # Move to the MLP device first, then match the model's parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> "Dict[str, float]":
    """Run the full pipeline on one audio file and return {dimension: raw score}.

    Loads (and truncates to MAX_AUDIO_SECONDS) the audio, embeds it once with
    Whisper, then loads each MLP head one at a time (CPU-offloading friendly)
    and scores every mapped dimension. Also prints the top-3 emotions by
    softmax over the raw scores of TARGET_EMOTION_KEYS_FOR_REPORT.

    Returns an empty dict if the audio file cannot be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    # Free the raw audio as early as possible; only the embedding is needed now.
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        # Load each head on demand and drop it right after scoring to keep
        # peak memory low (the projection layer alone is ~74M parameters).
        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            # NOTE(review): load_single_mlp always returns a model or raises, so
            # this NaN fallback looks unreachable as written -- confirm.
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions.
    # The -inf default is never hit: keys are pre-filtered with `if k in all_scores`.
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- Configuration (should match Cell 2 of the Colab) ---
# Standard library
import gc  # For memory management
from pathlib import Path
from typing import Dict, List  # Added: Dict/List annotations are used below but were never imported

# Third-party
import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download  # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE: int = 16000          # Hz; audio is resampled to this rate before Whisper
MAX_AUDIO_SECONDS: float = 30.0     # audio longer than this is truncated in process_audio_file()
WHISPER_MODEL_ID: str = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID: str = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR: Path = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN: int = 1500         # encoder output is padded/truncated to this sequence length
WHISPER_EMBED_DIM: int = 768        # encoder hidden size expected by the MLP heads
PROJECTION_DIM_FOR_FULL_EMBED: int = 64 # For 'Small' models
MLP_HIDDEN_DIMS: List[int] = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS: List[float] = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts ("model_<part>_best.pth") to human-readable
# dimension keys used throughout the scoring pipeline.
# (Abridged, full map in Colab Cell 2)
FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# Emotion dimensions included in the softmax report printed by
# process_audio_file(); other mapped dimensions (e.g. Age, Gender,
# Recording_Quality) are still scored but excluded from that report.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head mapping a flattened Whisper embedding to one scalar.

    The (seq_len, embed_dim) embedding is flattened, linearly projected to
    ``projection_dim``, then fed through a small ReLU MLP whose final layer
    has a single output unit.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate follows the projection, plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stack = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_features = projection_dim
        for layer_idx, out_features in enumerate(mlp_hidden_dims):
            stack.append(nn.Linear(in_features, out_features))
            stack.append(nn.ReLU())
            stack.append(nn.Dropout(mlp_dropout_rates[layer_idx + 1]))
            in_features = out_features
        stack.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); module-level so repeated calls to
# process_audio_file() reuse the already-loaded Whisper model.
whisper_model_global = None        # WhisperForConditionalGeneration once initialized
whisper_processor_global = None    # WhisperProcessor once initialized
all_mlp_model_paths_dict = {} # To be populated (target key -> checkpoint Path)
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Lazily load the Whisper model and map MLP checkpoint paths.

    Idempotent: Whisper is loaded only if the module-level global is still
    None, and MLP checkpoints are downloaded and mapped only once.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper once; later calls reuse the cached global.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        loaded = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = loaded.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # checkpoints already downloaded and mapped

    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model",
    )
    print("MLP checkpoints downloaded.")

    # Derive the dimension key from each checkpoint name:
    # "model_<part>_best.pth" -> FILENAME_PART_TO_TARGET_KEY_MAP[<part>]
    for checkpoint_path in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            name_part = checkpoint_path.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {checkpoint_path.name}")
            continue
        mapped_key = FILENAME_PART_TO_TARGET_KEY_MAP.get(name_part)
        if mapped_key is not None:
            all_mlp_model_paths_dict[mapped_key] = checkpoint_path
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono waveform into a fixed-size Whisper encoder embedding.

    Args:
        audio_waveform_np: 1-D numpy float waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM); the
        encoder output is zero-padded or truncated along the sequence axis.

    Raises:
        RuntimeError: if initialize_models() has not been called.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Fix: size the padding from the embedding's own batch dim and device
        # instead of hard-coding batch=1 / WHISPER_DEVICE, so batched inputs
        # (and device-mapped models) also work.
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load one dimension's checkpoint.

    Args:
        model_path: path to a .pth state-dict checkpoint.
        target_key: human-readable dimension name (used only for logging).

    Returns:
        The loaded model on MLP_DEVICE in eval mode.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    # Fix: weights_only=True — the checkpoint is just a tensor state dict, and
    # this avoids unpickling arbitrary objects from a downloaded file.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    # Handle the '_orig_mod.' prefix left by torch.compile during training.
    # Fix: strip it as a *prefix* rather than str.replace, which would corrupt
    # any key merely containing the substring mid-key.
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model_instance.load_state_dict(state_dict)
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one scoring MLP on a Whisper embedding and return the scalar score."""
    # Match the MLP's device and dtype before the forward pass (simplified).
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a Whisper
    encoder embedding, then runs each per-dimension MLP sequentially,
    loading and unloading one model at a time to bound memory use.

    Args:
        audio_file_path_str: path to any audio file librosa can decode.

    Returns:
        Mapping of dimension name -> raw score; empty dict if the audio
        could not be loaded.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # The returned sample rate is fixed by sr=SAMPLING_RATE, so discard it.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values():  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        # Fix: explicit None check — truthiness of an nn.Module is not a
        # reliable "loaded successfully" signal.
        if current_mlp_model is not None:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to keep memory flat
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary emotion dimensions for a quick report.
    # Fix: compute the reported-key list once; the old -inf default was dead
    # code since keys were already filtered to those present in all_scores.
    reported_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    emotion_raw_scores = [all_scores[k] for k in reported_keys]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # {emotion_key: softmax_prob}, aligned by construction with reported_keys
        emotion_softmax_dict = {
            key: prob.item() for key, prob in zip(reported_keys, softmax_probs)
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- (duplicate listing) Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        padding = torch.zeros((1, WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
                              device=WHISPER_DEVICE, dtype=embedding.dtype)
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    state_dict = torch.load(model_path, map_location='cpu')
    # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model_instance.load_state_dict(state_dict)
    model_instance = model_instance.to(MLP_DEVICE).eval()
    return model_instance

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    embedding_for_mlp = embedding.to(MLP_DEVICE)
    # Ensure dtype matches (simplified)
    mlp_dtype = next(mlp_model.parameters()).dtype
    prediction = mlp_model(embedding_for_mlp.to(mlp_dtype))
    return prediction.item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    if not all_mlp_model_paths_dict:
        initialize_models() # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        waveform, sr = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in FILENAME_PART_TO_TARGET_KEY_MAP.values(): # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model # Unload after use
            gc.collect()
            if MLP_DEVICE.type == 'cuda': torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding; gc.collect();
    if WHISPER_DEVICE.type == 'cuda': torch.cuda.empty_cache()

    # Optional: Calculate Softmax for the 40 primary emotions
    emotion_raw_scores = [all_scores.get(k, -float('inf')) for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if emotion_raw_scores:
        softmax_probs = torch.softmax(torch.tensor(emotion_raw_scores, dtype=torch.float32), dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        # Create a dictionary of {emotion_key: softmax_prob}
        emotion_softmax_dict = {
            key: prob.item()
            for key, prob in zip(
                [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores], # only keys that had scores
                softmax_probs
            )
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
# --- (duplicate listing) Configuration (should match Cell 2 of the Colab) ---
import torch
import torch.nn as nn
import librosa
import numpy as np
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from huggingface_hub import snapshot_download # For downloading MLP models
import gc # For memory management

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000
MAX_AUDIO_SECONDS = 30.0
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models

# Mapping from .pth file name parts to human-readable dimension keys
# (Abridged, full map in Colab Cell 2)

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        current_dim = projection_dim
        for i, h_dim in enumerate(mlp_hidden_dims):
            layers.extend([nn.Linear(current_dim, h_dim), nn.ReLU(), nn.Dropout(mlp_dropout_rates[i+1])])
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        if x.ndim == 4 and x.shape[1] == 1: x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
whisper_model_global = None
whisper_processor_global = None
all_mlp_model_paths_dict = {} # To be populated
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID).to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    # Download and map MLPs (paths only, models loaded on-demand)
    if not all_mlp_model_paths_dict:
        print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
        LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=HF_MLP_REPO_ID,
            local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
            local_dir_use_symlinks=False,
            allow_patterns=["*.pth"],
            repo_type="model"
        )
        print("MLP checkpoints downloaded.")

        # Map .pth files to target keys (simplified from Colab Cell 2)
        for pth_file in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
            try:
                filename_part = pth_file.name.split("model_")[1].split("_best.pth")[0]
                if filename_part in FILENAME_PART_TO_TARGET_KEY_MAP:
                    target_key = FILENAME_PART_TO_TARGET_KEY_MAP[filename_part]
                    all_mlp_model_paths_dict[target_key] = pth_file
            except IndexError:
                print(f"Warning: Could not parse filename part from {pth_file.name}")
        print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
        if not all_mlp_model_paths_dict:
             raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform with the Whisper encoder.

    Args:
        audio_waveform_np: 1-D numpy waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM); the
        sequence axis is zero-padded or truncated to exactly WHISPER_SEQ_LEN.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Fix: derive batch size and device from the embedding itself rather
        # than hard-coding batch 1 and WHISPER_DEVICE, so padding always
        # matches the encoder output.
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load its checkpoint onto MLP_DEVICE.

    Returns the model in eval mode, or None when loading fails — the caller
    (process_audio_file) checks for a falsy result and records NaN for this
    dimension, so failures no longer abort the whole scoring loop.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    try:
        # weights_only=True stops torch.load from unpickling arbitrary objects;
        # these checkpoints are plain state dicts so this is safe.
        # (Requires torch >= 1.13; older versions raise and we fall through.)
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
        if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        model_instance.load_state_dict(state_dict)
    except Exception as e:
        print(f"Warning: failed to load MLP for '{target_key}' from {model_path}: {e}")
        return None
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression MLP on a Whisper embedding; return its scalar score."""
    # Move the embedding to the MLP device and match the model's parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then loads each per-dimension MLP one at a time (to
    bound memory), scores the embedding and unloads the MLP. Finally prints
    the top-3 primary emotions by softmax over their raw scores.

    Returns:
        Mapping of dimension key -> raw score; {} if the audio failed to load.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE, so the returned rate is unused.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Hoisted: build the membership set once instead of per iteration.
    mapped_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in mapped_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to keep memory bounded
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary emotions that were actually scored.
    # (The key list is computed once; the old -inf default was unreachable
    # because missing keys were filtered out anyway.)
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_keys:
        raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = {
            key: prob.item() for key, prob in zip(scored_keys, softmax_probs)
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")
### Code Example (Python, using `transformers`)
import gc  # For memory management
from pathlib import Path
from typing import Dict, List  # Required by the module-level type annotations below

import librosa
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download  # For downloading MLP models
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# --- Configuration (should match Cell 2 of the Colab) ---
SAMPLING_RATE = 16000  # Hz; audio is loaded/resampled to this rate before Whisper
MAX_AUDIO_SECONDS = 30.0  # longer clips are truncated to this length
WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1"
HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" # Or -Large if using those
LOCAL_MLP_MODELS_DOWNLOAD_DIR = Path("./empathic_insight_voice_small_models_downloaded")

# Whisper encoder output geometry used to size/pad embeddings and the MLP input.
WHISPER_SEQ_LEN = 1500
WHISPER_EMBED_DIM = 768
PROJECTION_DIM_FOR_FULL_EMBED = 64 # For 'Small' models
MLP_HIDDEN_DIMS = [64, 32, 16]    # For 'Small' models
MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] # For 'Small' models; must be len(MLP_HIDDEN_DIMS)+1

# Mapping from .pth file name parts to human-readable dimension keys.
# Checkpoints are named model_<part>_best.pth; <part> is looked up here.
# NOTE(review): the original comment said "(Abridged, full map in Colab
# Cell 2)" — this copy looks complete, but verify against the Colab.

FILENAME_PART_TO_TARGET_KEY_MAP: Dict[str, str] = {
    "Affection": "Affection", "Age": "Age", "Amusement": "Amusement", "Anger": "Anger",
    "Arousal": "Arousal", "Astonishment_Surprise": "Astonishment/Surprise",
    "Authenticity": "Authenticity", "Awe": "Awe", "Background_Noise": "Background_Noise",
    "Bitterness": "Bitterness", "Concentration": "Concentration",
    "Confident_vs._Hesitant": "Confident_vs._Hesitant", "Confusion": "Confusion",
    "Contemplation": "Contemplation", "Contempt": "Contempt", "Contentment": "Contentment",
    "Disappointment": "Disappointment", "Disgust": "Disgust", "Distress": "Distress",
    "Doubt": "Doubt", "Elation": "Elation", "Embarrassment": "Embarrassment",
    "Emotional_Numbness": "Emotional Numbness", "Fatigue_Exhaustion": "Fatigue/Exhaustion",
    "Fear": "Fear", "Gender": "Gender", "Helplessness": "Helplessness",
    "High-Pitched_vs._Low-Pitched": "High-Pitched_vs._Low-Pitched",
    "Hope_Enthusiasm_Optimism": "Hope/Enthusiasm/Optimism",
    "Impatience_and_Irritability": "Impatience and Irritability",
    "Infatuation": "Infatuation", "Interest": "Interest",
    "Intoxication_Altered_States_of_Consciousness": "Intoxication/Altered States of Consciousness",
    "Jealousy_&_Envy": "Jealousy / Envy", "Longing": "Longing",
    "Malevolence_Malice": "Malevolence/Malice",
    "Monotone_vs._Expressive": "Monotone_vs._Expressive", "Pain": "Pain",
    "Pleasure_Ecstasy": "Pleasure/Ecstasy", "Pride": "Pride",
    "Recording_Quality": "Recording_Quality", "Relief": "Relief", "Sadness": "Sadness",
    "Serious_vs._Humorous": "Serious_vs._Humorous", "Sexual_Lust": "Sexual Lust",
    "Shame": "Shame", "Soft_vs._Harsh": "Soft_vs._Harsh", "Sourness": "Sourness",
    "Submissive_vs._Dominant": "Submissive_vs._Dominant", "Teasing": "Teasing",
    "Thankfulness_Gratitude": "Thankfulness/Gratitude", "Triumph": "Triumph",
    "Valence": "Valence",
    "Vulnerable_vs._Emotionally_Detached": "Vulnerable_vs._Emotionally_Detached",
    "Warm_vs._Cold": "Warm_vs._Cold"
}

# The 40 primary emotion dimensions over which process_audio_file computes
# softmax probabilities for the top-3 report.
TARGET_EMOTION_KEYS_FOR_REPORT: List[str] = [
    "Amusement", "Elation", "Pleasure/Ecstasy", "Contentment", "Thankfulness/Gratitude",
    "Affection", "Infatuation", "Hope/Enthusiasm/Optimism", "Triumph", "Pride",
    "Interest", "Awe", "Astonishment/Surprise", "Concentration", "Contemplation",
    "Relief", "Longing", "Teasing", "Impatience and Irritability",
    "Sexual Lust", "Doubt", "Fear", "Distress", "Confusion", "Embarrassment", "Shame",
    "Disappointment", "Sadness", "Bitterness", "Contempt", "Disgust", "Anger",
    "Malevolence/Malice", "Sourness", "Pain", "Helplessness", "Fatigue/Exhaustion",
    "Emotional Numbness", "Intoxication/Altered States of Consciousness", "Jealousy / Envy"
]

# --- MLP Model Definition (from Colab Cell 2) ---
class FullEmbeddingMLP(nn.Module):
    """Regression head over a full Whisper embedding.

    Flattens a (seq_len, embed_dim) embedding, projects it to
    projection_dim, then runs a small MLP that emits one scalar score.
    """

    def __init__(self, seq_len, embed_dim, projection_dim, mlp_hidden_dims, mlp_dropout_rates):
        super().__init__()
        # One dropout rate for the projection output plus one per hidden layer.
        if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1:
            raise ValueError("Dropout rates length error.")
        self.flatten = nn.Flatten()
        self.proj = nn.Linear(seq_len * embed_dim, projection_dim)
        stages = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])]
        in_dim = projection_dim
        for width, rate in zip(mlp_hidden_dims, mlp_dropout_rates[1:]):
            stages += [nn.Linear(in_dim, width), nn.ReLU(), nn.Dropout(rate)]
            in_dim = width
        stages.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        # Accept an optional singleton channel dim: (B, 1, S, E) -> (B, S, E).
        if x.ndim == 4 and x.shape[1] == 1:
            x = x.squeeze(1)
        return self.mlp(self.proj(self.flatten(x)))

# --- Global Model Placeholders ---
# Populated lazily by initialize_models(); kept global so repeated calls
# reuse the loaded Whisper model and the checkpoint-path map.
whisper_model_global = None  # WhisperForConditionalGeneration once loaded
whisper_processor_global = None  # WhisperProcessor once loaded
all_mlp_model_paths_dict = {} # To be populated: dimension key -> .pth Path
WHISPER_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MLP_DEVICE = torch.device("cpu") # As per USE_CPU_OFFLOADING_FOR_MLPS in Colab

def initialize_models():
    """Load the Whisper model/processor and map MLP checkpoint paths.

    Idempotent: a previously loaded Whisper model and an already-populated
    path map are left untouched, so this is safe to call repeatedly.

    Raises:
        RuntimeError: if no downloaded checkpoint filename matched
            FILENAME_PART_TO_TARGET_KEY_MAP.
    """
    global whisper_model_global, whisper_processor_global, all_mlp_model_paths_dict

    print(f"Whisper will run on: {WHISPER_DEVICE}")
    print(f"MLPs will run on: {MLP_DEVICE}")

    # Load Whisper only once.
    if whisper_model_global is None:
        print(f"Loading Whisper model '{WHISPER_MODEL_ID}'...")
        whisper_processor_global = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
        whisper_model_global = model.to(WHISPER_DEVICE).eval()
        print("Whisper model loaded.")

    if all_mlp_model_paths_dict:
        return  # Checkpoint paths already mapped; nothing left to do.

    # Download checkpoints (paths only; models are loaded on demand later).
    print(f"Downloading MLP checkpoints from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DOWNLOAD_DIR}...")
    LOCAL_MLP_MODELS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=HF_MLP_REPO_ID,
        local_dir=LOCAL_MLP_MODELS_DOWNLOAD_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=["*.pth"],
        repo_type="model"
    )
    print("MLP checkpoints downloaded.")

    # Each checkpoint is named model_<part>_best.pth; translate <part> into
    # its human-readable dimension key.
    for checkpoint in LOCAL_MLP_MODELS_DOWNLOAD_DIR.glob("model_*_best.pth"):
        try:
            part = checkpoint.name.split("model_")[1].split("_best.pth")[0]
        except IndexError:
            print(f"Warning: Could not parse filename part from {checkpoint.name}")
            continue
        if part in FILENAME_PART_TO_TARGET_KEY_MAP:
            all_mlp_model_paths_dict[FILENAME_PART_TO_TARGET_KEY_MAP[part]] = checkpoint
    print(f"Mapped {len(all_mlp_model_paths_dict)} MLP model paths.")
    if not all_mlp_model_paths_dict:
        raise RuntimeError("No MLP model paths could be mapped. Check FILENAME_PART_TO_TARGET_KEY_MAP and downloaded files.")


@torch.no_grad()
def get_whisper_embedding(audio_waveform_np):
    """Encode a mono 16 kHz waveform with the Whisper encoder.

    Args:
        audio_waveform_np: 1-D numpy waveform sampled at SAMPLING_RATE.

    Returns:
        Tensor of shape (batch, WHISPER_SEQ_LEN, WHISPER_EMBED_DIM); the
        sequence axis is zero-padded or truncated to exactly WHISPER_SEQ_LEN.

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    if whisper_model_global is None or whisper_processor_global is None:
        raise RuntimeError("Whisper model not initialized. Call initialize_models() first.")

    input_features = whisper_processor_global(
        audio_waveform_np, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_features.to(WHISPER_DEVICE).to(whisper_model_global.dtype)

    encoder_outputs = whisper_model_global.get_encoder()(input_features=input_features)
    embedding = encoder_outputs.last_hidden_state

    current_seq_len = embedding.shape[1]
    if current_seq_len < WHISPER_SEQ_LEN:
        # Fix: derive batch size and device from the embedding itself rather
        # than hard-coding batch 1 and WHISPER_DEVICE, so padding always
        # matches the encoder output.
        padding = torch.zeros(
            (embedding.shape[0], WHISPER_SEQ_LEN - current_seq_len, WHISPER_EMBED_DIM),
            device=embedding.device, dtype=embedding.dtype,
        )
        embedding = torch.cat((embedding, padding), dim=1)
    elif current_seq_len > WHISPER_SEQ_LEN:
        embedding = embedding[:, :WHISPER_SEQ_LEN, :]
    return embedding

def load_single_mlp(model_path, target_key):
    """Instantiate a FullEmbeddingMLP and load its checkpoint onto MLP_DEVICE.

    Returns the model in eval mode, or None when loading fails — the caller
    (process_audio_file) checks for a falsy result and records NaN for this
    dimension, so failures no longer abort the whole scoring loop.
    """
    # Simplified loading for example (Colab Cell 2 has more robust loading)
    # For this example, assumes USE_HALF_PRECISION_FOR_MLPS=False, USE_TORCH_COMPILE_FOR_MLPS=False
    print(f"  Loading MLP for '{target_key}'...")
    model_instance = FullEmbeddingMLP(
        WHISPER_SEQ_LEN, WHISPER_EMBED_DIM, PROJECTION_DIM_FOR_FULL_EMBED,
        MLP_HIDDEN_DIMS, MLP_DROPOUTS
    )
    try:
        # weights_only=True stops torch.load from unpickling arbitrary objects;
        # these checkpoints are plain state dicts so this is safe.
        # (Requires torch >= 1.13; older versions raise and we fall through.)
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        # Handle potential '_orig_mod.' prefix if model was torch.compile'd during training
        if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        model_instance.load_state_dict(state_dict)
    except Exception as e:
        print(f"Warning: failed to load MLP for '{target_key}' from {model_path}: {e}")
        return None
    return model_instance.to(MLP_DEVICE).eval()

@torch.no_grad()
def predict_with_mlp(embedding, mlp_model):
    """Run one regression MLP on a Whisper embedding; return its scalar score."""
    # Move the embedding to the MLP device and match the model's parameter dtype.
    target_dtype = next(mlp_model.parameters()).dtype
    prepared = embedding.to(MLP_DEVICE).to(target_dtype)
    return mlp_model(prepared).item()

def process_audio_file(audio_file_path_str: str) -> Dict[str, float]:
    """Score one audio file on every mapped emotion/attribute dimension.

    Loads the audio (truncated to MAX_AUDIO_SECONDS), computes a single
    Whisper embedding, then loads each per-dimension MLP one at a time (to
    bound memory), scores the embedding and unloads the MLP. Finally prints
    the top-3 primary emotions by softmax over their raw scores.

    Returns:
        Mapping of dimension key -> raw score; {} if the audio failed to load.
    """
    if not all_mlp_model_paths_dict:
        initialize_models()  # Ensure models are ready

    print(f"Processing audio file: {audio_file_path_str}")
    try:
        # librosa resamples to SAMPLING_RATE, so the returned rate is unused.
        waveform, _ = librosa.load(audio_file_path_str, sr=SAMPLING_RATE, mono=True)
        max_samples = int(MAX_AUDIO_SECONDS * SAMPLING_RATE)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
        print(f"Audio loaded. Duration: {len(waveform)/SAMPLING_RATE:.2f}s")
    except Exception as e:
        print(f"Error loading audio {audio_file_path_str}: {e}")
        return {}

    embedding = get_whisper_embedding(waveform)
    del waveform
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    all_scores: Dict[str, float] = {}
    # Hoisted: build the membership set once instead of per iteration.
    mapped_keys = set(FILENAME_PART_TO_TARGET_KEY_MAP.values())
    for target_key, mlp_model_path in all_mlp_model_paths_dict.items():
        if target_key not in mapped_keys:  # Only process mapped keys
            continue

        current_mlp_model = load_single_mlp(mlp_model_path, target_key)
        if current_mlp_model:
            score = predict_with_mlp(embedding, current_mlp_model)
            all_scores[target_key] = score
            print(f"    {target_key}: {score:.4f}")
            del current_mlp_model  # Unload after use to keep memory bounded
            gc.collect()
            if MLP_DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
        else:
            all_scores[target_key] = float('nan')

    del embedding
    gc.collect()
    if WHISPER_DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # Optional: softmax over the primary emotions that were actually scored.
    # (The key list is computed once; the old -inf default was unreachable
    # because missing keys were filtered out anyway.)
    scored_keys = [k for k in TARGET_EMOTION_KEYS_FOR_REPORT if k in all_scores]
    if scored_keys:
        raw = torch.tensor([all_scores[k] for k in scored_keys], dtype=torch.float32)
        softmax_probs = torch.softmax(raw, dim=0)
        print("\nTop 3 Emotions (Softmax Probabilities):")
        emotion_softmax_dict = {
            key: prob.item() for key, prob in zip(scored_keys, softmax_probs)
        }
        sorted_emotions = sorted(emotion_softmax_dict.items(), key=lambda item: item[1], reverse=True)
        for i, (emotion, prob) in enumerate(sorted_emotions[:3]):
            print(f"  {i+1}. {emotion}: {prob:.4f} (Raw: {all_scores.get(emotion, float('nan')):.4f})")
    return all_scores

# --- Example Usage (Run this after defining functions and initializing models) ---
# Make sure to have an audio file (e.g., "sample.mp3") in your current directory or provide a full path.
# And ensure FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT are fully populated.
#
# initialize_models() # Call this once
#
# # Create a dummy sample.mp3 for testing if it doesn't exist
# if not Path("sample.mp3").exists():
#     print("Creating dummy sample.mp3 for testing...")
#     dummy_sr = 16000
#     dummy_duration = 5 # seconds
#     dummy_tone_freq = 440 # A4 note
#     t = np.linspace(0, dummy_duration, int(dummy_sr * dummy_duration), endpoint=False)
#     dummy_waveform = 0.5 * np.sin(2 * np.pi * dummy_tone_freq * t)
#     import soundfile as sf
#     sf.write("sample.mp3", dummy_waveform, dummy_sr)
#     print("Dummy sample.mp3 created.")
#
# if Path("sample.mp3").exists() and FILENAME_PART_TO_TARGET_KEY_MAP and TARGET_EMOTION_KEYS_FOR_REPORT:
#    results = process_audio_file("sample.mp3")
#    # print("\nFull Scores Dictionary:", results)
# else:
#    print("Skipping example usage: 'sample.mp3' not found or maps are not fully populated.")

Deploy This Model

Production-ready deployment in minutes

Together.ai

Instant API access to this model

Fastest API

Production-ready inference API. Start free, scale to millions.

Try Free API

Replicate

One-click model deployment

Easiest Setup

Run models in the cloud with simple API. No DevOps required.

Deploy Now

Disclosure: We may earn a commission from these partners. This helps keep LLMYourWay free.