Cross-Modal Retrieval

Overview

Cross-modal retrieval is the task of retrieving relevant information in one modality (e.g., images) using a query from a different modality (e.g., text), or vice versa. It aims to bridge the semantic gap between different types of data by learning a common representation space in which data points from different modalities with similar semantic meaning lie close to each other.

For instance, a user might provide a textual description (e.g., "a red apple on a table") to search for matching images, or provide an image of a landscape to find similar scenes described in text documents.

Core Concepts

  • Key Objectives

    • Semantic Alignment: Learning a shared or aligned embedding space where semantically similar items from different modalities are close together.
    • Efficient Search: Enabling fast and accurate retrieval across large-scale multimodal datasets.
    • Modality Translation: Implicitly learning to translate concepts from one modality to another.
  • Common Modality Pairs

    • Text-to-Image Retrieval (and Image-to-Text)
    • Text-to-Video Retrieval (and Video-to-Text)
    • Audio-to-Image Retrieval (and Image-to-Audio)
    • Sketch-to-Image Retrieval
    • Speech-to-Text Retrieval (often treated as automatic speech recognition followed by text retrieval)
  • Challenges

    • Heterogeneity Gap (Modality Gap): Different modalities have vastly different statistical properties and structures, making it hard to compare them directly.
    • Semantic Gap: High-level semantic concepts may be represented very differently in low-level features of various modalities.
    • Embedding Space Learning: Designing effective shared or coordinated embedding spaces and appropriate similarity/distance metrics.
    • Scalability: Efficiently indexing and searching in high-dimensional embedding spaces for large datasets.
    • Evaluation: Defining good metrics (e.g., Recall@K, Mean Average Precision) and benchmarks that reflect real-world utility.
  • Shared Embedding Space Models (e.g., Deep CCA, Siamese Networks, CLIP-style models)

    The dominant approach is to learn a shared (or common) embedding space where features from different modalities can be directly compared.

    • Deep Canonical Correlation Analysis (DCCA): Learns non-linear transformations for two modalities such that their resulting representations are maximally correlated.
    • Siamese Networks / Triplet Networks: These networks are trained on pairs or triplets of data. With pairs, the network learns to pull semantically matching cross-modal pairs together and push mismatched pairs apart. With triplets (e.g., an anchor from modality A, a positive sample from modality B, and a negative sample from modality B), a triplet loss ensures the anchor ends up closer to the positive than to the negative in the learned space.
    • Contrastive Learning (e.g., CLIP, VSE++): Contrastive losses are highly effective, as demonstrated by models like CLIP. Given a batch of paired data, they maximize the similarity of correct pairs while minimizing the similarity of all incorrect pairings in the batch. VSE++ is another example focused on image-text retrieval, notable for emphasizing hard negatives in its ranking loss.
    • Adversarial Learning: Using adversarial training to encourage modality-invariant features in the shared space.

    Once the shared space is learned, retrieval is performed by embedding the query (from one modality) into this space and finding the nearest neighbors from the other modality using a distance metric like cosine similarity or Euclidean distance.
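
    A minimal sketch of this nearest-neighbor lookup, assuming the gallery items from the target modality have already been embedded and L2-normalized (so the dot product equals cosine similarity); the tensors here are hypothetical placeholders:

    import torch
    import torch.nn.functional as F

    # Hypothetical embeddings: one query and a gallery from the other modality,
    # both L2-normalized so that dot product == cosine similarity.
    query_embedding = F.normalize(torch.randn(1, 256), p=2, dim=-1)
    gallery_embeddings = F.normalize(torch.randn(10000, 256), p=2, dim=-1)

    # Cosine similarity between the query and every gallery item
    similarities = query_embedding @ gallery_embeddings.t()  # shape: (1, 10000)

    # Indices of the top-10 most similar cross-modal items
    top_scores, top_indices = similarities.topk(k=10, dim=-1)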

  • Cross-Attention Models / Transformers

    Transformer-based architectures with cross-attention mechanisms allow for fine-grained interaction between features from different modalities.

    For retrieval, a query from one modality can attend to elements of items from another modality to compute a relevance or similarity score directly, without necessarily projecting everything into a single static shared vector space for all items beforehand. This can be more computationally intensive at query time but allows for richer interactions.

    Models like ViLBERT and LXMERT, while often used for VQA, can also be adapted for cross-modal retrieval by scoring the match between image regions and text tokens.
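
    A rough sketch of this kind of cross-attention scoring, assuming pre-extracted text token features and image region features of matching dimension; the pooling and scoring head are illustrative choices rather than any specific model's design:

    import torch
    import torch.nn as nn

    class CrossAttentionScorer(nn.Module):
        """Scores an (image, text) pair by letting text tokens attend to image regions."""
        def __init__(self, dim=512, num_heads=8):
            super().__init__()
            self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.score_head = nn.Linear(dim, 1)  # maps pooled fused features to a relevance score

        def forward(self, text_tokens, image_regions):
            # text_tokens: (batch, num_tokens, dim); image_regions: (batch, num_regions, dim)
            fused, _ = self.cross_attn(query=text_tokens, key=image_regions, value=image_regions)
            pooled = fused.mean(dim=1)                  # average over text tokens
            return self.score_head(pooled).squeeze(-1)  # one relevance score per pair

    # scorer = CrossAttentionScorer()
    # scores = scorer(text_tokens, image_regions)  # higher score = better match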

  • Graph-based Methods

    Multimodal data is represented as a heterogeneous graph in which nodes are data instances (e.g., images, text documents) and edges encode intra-modal or inter-modal relationships. Graph neural networks (GNNs) can then be used to learn node embeddings for retrieval.

    These methods can effectively model complex relationships and leverage a small amount of labeled data by propagating information through the graph structure.
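
    A minimal, dependency-free sketch of one message-passing step over such a graph, assuming a dense float adjacency matrix that encodes intra- and inter-modal edges (real systems would typically use a GNN library with sparse adjacency):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimpleGraphLayer(nn.Module):
        """One GCN-style step: aggregate neighbor features, then transform."""
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, out_dim)

        def forward(self, node_features, adjacency):
            # adjacency: (num_nodes, num_nodes) float tensor, 1 where an intra-/inter-modal edge exists
            adjacency = adjacency + torch.eye(adjacency.size(0))       # add self-loops
            degree = adjacency.sum(dim=-1, keepdim=True).clamp(min=1)  # avoid division by zero
            aggregated = (adjacency @ node_features) / degree          # mean over neighbors
            return F.relu(self.linear(aggregated))

    # node_features could stack image and text instance embeddings; the refined node
    # embeddings are then used for nearest-neighbor retrieval as in the shared-space setup.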

  • Common Metrics for Retrieval Tasks

    • Recall@K (R@K): The percentage of queries for which at least one correct item is found within the top K retrieved results. Commonly reported for K=1, 5, 10.
    • Precision@K (P@K): The proportion of the top K retrieved items that are relevant. Less common as a primary metric in large-scale retrieval when each query has only one ground-truth match.
    • Mean Average Precision (mAP): The mean of average precision scores for a set of queries. Average Precision rewards ranking relevant items higher.
    • Normalized Discounted Cumulative Gain (nDCG@K): A measure of ranking quality that assigns higher scores to relevant items appearing earlier in the ranked list, considering graded relevance.
    • Median Rank (MedR) / Mean Rank (MeanR): The median or mean rank of the first correctly retrieved item. Lower values are better.
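
    These ranking metrics can be computed directly from a query-by-gallery similarity matrix; a sketch of Recall@K and Median Rank, assuming the ground-truth match for query i is gallery item i (as in typical paired image-caption test sets):

    import torch

    def recall_at_k(similarity, k):
        """similarity: (num_queries, num_gallery); ground truth for query i is item i."""
        num_queries = similarity.size(0)
        _, top_indices = similarity.topk(k=k, dim=-1)
        ground_truth = torch.arange(num_queries).unsqueeze(-1)
        hits = (top_indices == ground_truth).any(dim=-1)   # did the correct item appear in the top K?
        return hits.float().mean().item()

    def median_rank(similarity):
        """Median rank (1-indexed) of the correct item per query; lower is better."""
        sorted_indices = similarity.argsort(dim=-1, descending=True)
        ground_truth = torch.arange(similarity.size(0)).unsqueeze(-1)
        ranks = (sorted_indices == ground_truth).float().argmax(dim=-1) + 1
        return ranks.median().item()

    # similarity = image_embeddings @ text_embeddings.t()  # from a trained model
    # print(recall_at_k(similarity, k=5), median_rank(similarity))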

Implementation

  • Conceptual Siamese Network for Image-Text Retrieval (PyTorch)

    Illustrates a Siamese network structure with a contrastive or triplet loss to learn aligned embeddings.
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    # Assume pre-defined ImageEncoder and TextEncoder similar to those in VLM or AVL examples
    class ImageEncoder(nn.Module): # Simplified placeholder
        def __init__(self, embedding_dim):
            super().__init__()
            self.fc = nn.Linear(2048, embedding_dim) # Assuming input from a ResNet-like backbone
        def forward(self, x): return self.fc(x)
    
    class TextEncoder(nn.Module): # Simplified placeholder
        def __init__(self, embedding_dim):
            super().__init__()
            self.fc = nn.Linear(768, embedding_dim) # Assuming input from a BERT-like backbone
        def forward(self, x): return self.fc(x)
    
    class CrossModalRetrievalNet(nn.Module):
        def __init__(self, image_feature_dim, text_feature_dim, shared_embedding_dim):
            super().__init__()
            self.image_encoder = ImageEncoder(image_feature_dim) # This should be a proper vision model
            self.text_encoder = TextEncoder(text_feature_dim)   # This should be a proper language model
            
            # Projection heads to map to the shared embedding space
            self.image_projection = nn.Linear(image_feature_dim, shared_embedding_dim)
            self.text_projection = nn.Linear(text_feature_dim, shared_embedding_dim)
    
        def encode_image(self, image_features_raw):
            # In a real scenario, image_features_raw would be processed by a full vision backbone
            projected = self.image_projection(self.image_encoder(image_features_raw))
            return F.normalize(projected, p=2, dim=-1)
    
        def encode_text(self, text_features_raw):
            # In a real scenario, text_features_raw would be processed by a full text model
            projected = self.text_projection(self.text_encoder(text_features_raw))
            return F.normalize(projected, p=2, dim=-1)
    
        def forward(self, image_features_raw, text_features_raw):
            image_embeddings = self.encode_image(image_features_raw)
            text_embeddings = self.encode_text(text_features_raw)
            return image_embeddings, text_embeddings
    
    # Conceptual Triplet Loss Training Snippet
    # model = CrossModalRetrievalNet(image_feature_dim=2048, text_feature_dim=768, shared_embedding_dim=256)
    # triplet_loss = nn.TripletMarginLoss(margin=0.2, p=2)
    # optimizer = torch.optim.Adam(model.parameters())
    
    # # Assume: 
    # # anchor_images_features: Features of anchor images
    # # positive_texts_features: Features of texts corresponding to anchor images
    # # negative_texts_features: Features of texts not corresponding to anchor images
    
    # anchor_img_embed = model.encode_image(anchor_images_features)
    # positive_text_embed = model.encode_text(positive_texts_features)
    # negative_text_embed = model.encode_text(negative_texts_features)
    
    # loss = triplet_loss(anchor_img_embed, positive_text_embed, negative_text_embed)
    # optimizer.zero_grad()
    # loss.backward()
    # optimizer.step()
    
    # # Symmetrically, you could have text anchors and image positives/negatives.
    # # For a CLIP-style contrastive loss instead (logit_scale is a learnable temperature parameter):
    # # image_embeddings, text_embeddings = model(image_batch_features, text_batch_features)
    # # logits = image_embeddings @ text_embeddings.t() * logit_scale.exp()
    # # labels = torch.arange(image_embeddings.size(0))
    # # loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
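
  • Indexing Embeddings for Large-Scale Retrieval (conceptual)

    For large galleries, exhaustively comparing a query against every embedding becomes the bottleneck, so the gallery is usually indexed with a nearest-neighbor search library. A minimal sketch assuming the FAISS library is installed and embeddings are L2-normalized (so inner product equals cosine similarity); the array shapes are hypothetical:

    import numpy as np
    import faiss  # assumes the FAISS library is installed, e.g., pip install faiss-cpu

    embedding_dim = 256
    # Hypothetical precomputed, L2-normalized gallery embeddings (e.g., all text items)
    gallery_embeddings = np.random.randn(100000, embedding_dim).astype('float32')
    gallery_embeddings /= np.linalg.norm(gallery_embeddings, axis=1, keepdims=True)

    index = faiss.IndexFlatIP(embedding_dim)  # exact inner-product search; an ANN index can be swapped in for more scale
    index.add(gallery_embeddings)

    # At query time: embed the query (e.g., an image) with the trained model, normalize, search
    query_embedding = np.random.randn(1, embedding_dim).astype('float32')
    query_embedding /= np.linalg.norm(query_embedding, axis=1, keepdims=True)
    scores, indices = index.search(query_embedding, 10)  # top-10 gallery items for the query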
    

Interview Examples

What is the 'modality gap' in cross-modal retrieval and how do researchers try to address it?

Explain how a triplet loss function works in the context of learning embeddings for cross-modal retrieval.

Why is Recall@K a commonly used metric for evaluating cross-modal retrieval systems?

Practice Questions

1. How would you implement this in a production environment? (Hard)

Hint: Consider scalability and efficiency

2. Explain the core concepts of Cross-Modal Retrieval. (Easy)

Hint: Think about the fundamental principles

3. What are the practical applications of Cross-Modal Retrieval? (Medium)

Hint: Consider both academic and industry use cases