The code below illustrates the core idea: train separate encoders for images and text, then learn a joint embedding space with a contrastive loss.
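For reference, the symmetric contrastive (InfoNCE-style) objective that the training snippet at the bottom implements can be written as follows. This is a sketch: the symbols u_i and v_j (L2-normalized image and text embeddings of the i-th and j-th pair) and s_ij (their scaled similarities) are introduced here for illustration and do not appear in the code.

$$
s_{ij} = \exp(\text{logit\_scale}) \, u_i^{\top} v_j,
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N} \left[
  -\log \frac{e^{s_{ii}}}{\sum_{j=1}^{N} e^{s_{ij}}}
  \;-\; \log \frac{e^{s_{ii}}}{\sum_{j=1}^{N} e^{s_{ji}}}
\right]
$$

Minimizing this loss pulls each image toward its own caption and pushes it away from the other N-1 captions in the batch, and vice versa.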
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Assume torchvision models for the image encoder and transformers for the text encoder
from torchvision.models import resnet50
from transformers import BertModel, BertTokenizer

class ImageEncoder(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.model = resnet50(pretrained=True)
        # Replace the final fully connected layer to output desired embedding dimension
        self.model.fc = nn.Linear(self.model.fc.in_features, embedding_dim)

    def forward(self, images):
        return self.model(images)

class TextEncoder(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        # Add a linear layer to project BERT's [CLS] token output to the embedding dimension
        self.fc = nn.Linear(self.model.config.hidden_size, embedding_dim)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, texts):  # texts is a list of strings
        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        # Move inputs to the same device as the model
        inputs = {k: v.to(self.fc.weight.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        # Use the [CLS] token's representation
        cls_representation = outputs.last_hidden_state[:, 0, :]
        return self.fc(cls_representation)

class CLIPLikeModel(nn.Module):
    def __init__(self, image_embedding_dim, text_embedding_dim, shared_embedding_dim):
        super().__init__()
        self.image_encoder = ImageEncoder(image_embedding_dim)
        self.text_encoder = TextEncoder(text_embedding_dim)
        # Projection heads mapping into a shared embedding space are optional but common.
        # For simplicity, assume image_embedding_dim and text_embedding_dim are already
        # equal to shared_embedding_dim; otherwise, add projection layers, e.g.:
        # self.image_projection = nn.Linear(image_embedding_dim, shared_embedding_dim)
        # self.text_projection = nn.Linear(text_embedding_dim, shared_embedding_dim)
        # Learnable logit scale (inverse temperature), initialized to log(1/0.07) as in CLIP
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, images, texts):
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)
        # Normalize features to unit length
        image_features = F.normalize(image_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)
        # Cosine similarities scaled into logits; a higher logit_scale makes the softmax sharper
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text

# Conceptual Training Snippet (assumes `images` is a batch of image tensors and `texts_list` the matching list of caption strings)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = CLIPLikeModel(image_embedding_dim=512, text_embedding_dim=512, shared_embedding_dim=512).to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# N = images.shape[0] # Batch size
# # Ground-truth labels for the contrastive loss: the i-th image matches the i-th text,
# # so the correct pairs lie on the diagonal of the logits matrix
# labels = torch.arange(N).to(device)
# logits_per_image, logits_per_text = model(images.to(device), texts_list)
# loss_i = F.cross_entropy(logits_per_image, labels)
# loss_t = F.cross_entropy(logits_per_text, labels)
# total_loss = (loss_i + loss_t) / 2.0
# optimizer.zero_grad()
# total_loss.backward()
# optimizer.step()
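
To make usage concrete, here is a minimal inference sketch (illustrative only: the dummy image tensor and caption list are made up, and the untrained projection layers will produce arbitrary scores until the model has been trained as above):

# Conceptual Inference Snippet: score a batch of images against candidate captions
model = CLIPLikeModel(image_embedding_dim=512, text_embedding_dim=512, shared_embedding_dim=512)
model.eval()

dummy_images = torch.randn(2, 3, 224, 224)           # stand-in for a batch of preprocessed images
captions = ["a photo of a cat", "a photo of a dog"]  # candidate texts

with torch.no_grad():
    logits_per_image, _ = model(dummy_images, captions)  # shape (2, 2): images x captions
    probs = logits_per_image.softmax(dim=-1)             # each row is a distribution over captions
print(probs)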