The following illustrates a simplified model for learning audio-visual correspondence, trained either with a contrastive (InfoNCE-style) objective or as a binary classifier over matched and mismatched audio-video pairs.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assume pre-trained or custom encoders for audio and video
class AudioEncoder(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # Example: a simple CNN over spectrograms; a more complex pre-trained
        # model such as VGGish could be substituted here.
        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Adaptive pooling fixes the feature-map size, so the linear head works
        # for any input spectrogram size with no manual shape calculation.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(128 * 4 * 4, embedding_dim)

    def forward(self, audio_spectrogram):
        # audio_spectrogram: (batch, 1, freq_bins, time_frames)
        x = self.conv_stack(audio_spectrogram)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch, 128 * 4 * 4)
        return self.fc(x)
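
# Quick shape check for the audio branch (a minimal sketch; the 128 x 128
# spectrogram is an arbitrary example size, since the adaptive pool accepts
# any input resolution):
audio_enc = AudioEncoder(embedding_dim=256)
dummy_spec = torch.randn(4, 1, 128, 128)  # (batch, 1, freq_bins, time_frames)
print(audio_enc(dummy_spec).shape)        # torch.Size([4, 256])
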
class VideoEncoder(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # Example: a pre-trained backbone could be adapted instead, e.g.:
        # from torchvision.models import resnet18, ResNet18_Weights
        # self.visual_model = resnet18(weights=ResNet18_Weights.DEFAULT)
        # self.visual_model.fc = nn.Linear(self.visual_model.fc.in_features, embedding_dim)
        # For simplicity, a small 3D CNN:
        self.conv_stack = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=3, stride=1, padding=1),  # input: (batch, 3, T, H, W)
            nn.ReLU(),
            nn.MaxPool3d((1, 2, 2)),
            nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d((1, 1, 1)),  # Global spatio-temporal pooling
        )
        self.fc = nn.Linear(32, embedding_dim)

    def forward(self, video_frames):
        # video_frames: (batch, channels, num_frames, height, width)
        x = self.conv_stack(video_frames)
        x = x.view(x.size(0), -1)  # Flatten to (batch, 32)
        return self.fc(x)
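
# Analogous shape check for the video branch (16 frames at 64 x 64 is an
# arbitrary example; the global pooling makes the head size-independent):
video_enc = VideoEncoder(embedding_dim=256)
dummy_clip = torch.randn(4, 3, 16, 64, 64)  # (batch, channels, T, H, W)
print(video_enc(dummy_clip).shape)          # torch.Size([4, 256])
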
class AudioVisualCorrespondenceModel(nn.Module):
    def __init__(self, audio_embedding_dim, video_embedding_dim, projection_dim):
        super().__init__()
        self.audio_encoder = AudioEncoder(audio_embedding_dim)
        self.video_encoder = VideoEncoder(video_embedding_dim)
        # Projection heads into a shared space (standard for contrastive learning)
        self.audio_projection = nn.Linear(audio_embedding_dim, projection_dim)
        self.video_projection = nn.Linear(video_embedding_dim, projection_dim)
        # For binary classification of correspondence, add a head instead, e.g.:
        # self.classifier = nn.Linear(audio_embedding_dim + video_embedding_dim, 1)
        # or, on projected features: self.classifier = nn.Linear(projection_dim * 2, 1)

    def forward(self, audio_input, video_input):
        audio_features = self.audio_encoder(audio_input)
        video_features = self.video_encoder(video_input)
        # Project into the shared embedding space
        audio_projected = self.audio_projection(audio_features)
        video_projected = self.video_projection(video_features)
        # L2-normalize so dot products are cosine similarities (as InfoNCE expects)
        audio_projected = F.normalize(audio_projected, p=2, dim=-1)
        video_projected = F.normalize(video_projected, p=2, dim=-1)
        return audio_projected, video_projected
        # Binary-classification variant of forward (return the logit and train
        # with F.binary_cross_entropy_with_logits rather than applying sigmoid):
        # combined_features = torch.cat((audio_features, video_features), dim=-1)
        # return self.classifier(combined_features)
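
# Sketch of the binary-classification alternative referenced in the comments
# above. This is an illustrative assumption, not part of the class as written:
# it instantiates the commented-out classifier head on concatenated raw
# features and builds in-batch negatives by rolling the video features.
bc_model = AudioVisualCorrespondenceModel(256, 256, 128)
classifier = nn.Linear(256 + 256, 1)  # the commented-out correspondence head
audio_feats = bc_model.audio_encoder(torch.randn(4, 1, 128, 128))    # (4, 256)
video_feats = bc_model.video_encoder(torch.randn(4, 3, 16, 64, 64))  # (4, 256)
neg_video_feats = torch.roll(video_feats, shifts=1, dims=0)  # mismatched pairs
pos_logits = classifier(torch.cat((audio_feats, video_feats), dim=-1))
neg_logits = classifier(torch.cat((audio_feats, neg_video_feats), dim=-1))
logits = torch.cat((pos_logits, neg_logits)).squeeze(-1)      # (8,)
targets = torch.cat((torch.ones(4), torch.zeros(4)))          # 1 = matching pair
bce_loss = F.binary_cross_entropy_with_logits(logits, targets)
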
# Conceptual training step for contrastive learning (InfoNCE / CLIP-style).
# Dummy batches stand in for N corresponding (positive) audio-video pairs.
model = AudioVisualCorrespondenceModel(audio_embedding_dim=256, video_embedding_dim=256, projection_dim=128)
logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))  # Learnable temperature, as in CLIP
optimizer = torch.optim.Adam(list(model.parameters()) + [logit_scale])

audio_batch = torch.randn(8, 1, 128, 128)    # (N, 1, freq_bins, time_frames)
video_batch = torch.randn(8, 3, 16, 64, 64)  # (N, channels, T, H, W)

audio_embeds, video_embeds = model(audio_batch, video_batch)
# audio_embeds, video_embeds: (N, projection_dim), L2-normalized

# Similarity logits: an (N, N) matrix whose diagonal holds the positive pairs
scale = logit_scale.exp()
logits_av = scale * audio_embeds @ video_embeds.t()
logits_va = scale * video_embeds @ audio_embeds.t()

N = audio_embeds.size(0)
labels = torch.arange(N)  # Ground truth: the i-th audio matches the i-th video

# Symmetric cross-entropy over both retrieval directions
loss_a = F.cross_entropy(logits_av, labels)
loss_v = F.cross_entropy(logits_va, labels)
total_loss = (loss_a + loss_v) / 2.0

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
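
# Quick sanity check on the same batch: in-batch retrieval accuracy. This is
# an assumed evaluation, commonly used for contrastive correspondence models,
# not something the training step above prescribes.
with torch.no_grad():
    sims = audio_embeds @ video_embeds.t()       # (N, N) cosine similarities
    retrieved = sims.argmax(dim=1)               # best-matching video per audio clip
    accuracy = (retrieved == torch.arange(N)).float().mean()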