Few-Shot Learning

Overview

Few-shot learning is a machine learning paradigm where models must learn to recognize new classes from very few examples, typically 1-5 samples per class. This approach is particularly valuable in scenarios where collecting large amounts of labeled data is expensive, time-consuming, or impractical.

Unlike traditional deep learning, which requires extensive training data, few-shot learning leverages prior knowledge, usually acquired on related tasks, together with dedicated learning strategies to adapt quickly to new tasks from minimal data.

Core Concepts

  • Key Components

    • Support Set: The few examples provided for learning new classes
    • Query Set: The examples we want to classify using knowledge from the support set
    • N-way K-shot: A classification task with N classes and K support examples per class (e.g., 5-way 1-shot means five classes with one example each)
    • Meta-learning: Training across many such tasks so the model learns how to learn from few examples
  • Common Approaches

    1. Metric Learning:

    • Learn a similarity metric between examples
    • Compare query samples with support set samples
    • Examples: Siamese Networks, Prototypical Networks

    2. Model-based:

    • Modify model architecture to enable quick adaptation
    • Use memory mechanisms or fast weights
    • Examples: MANN, Meta Networks

    3. Optimization-based:

    • Learn an initialization from which a few gradient steps adapt to a new task
    • Meta-learn update rules or learning rates
    • Examples: MAML, Reptile (see the minimal sketch after this list)
  • Tips for Few-shot Learning

    • Data Augmentation:
      • Crucial for maximizing use of limited samples
      • Use domain-appropriate augmentation techniques
      • Consider using learned augmentation strategies
    • Model Selection:
      • Choose architectures suited for few-shot learning
      • Consider model complexity vs. available data
      • Use pre-trained feature extractors when possible
    • Training Strategy:
      • Use episodic training to simulate few-shot scenarios
      • Implement proper validation protocols
      • Consider meta-learning approaches
  • Common Challenges

    • Overfitting:
      • Risk of overfitting to the small support set
      • Need for proper regularization
      • Importance of a sound validation strategy
    • Domain Shift:
      • Handling distribution differences between training and test domains
      • Adapting to new domains with only a few samples
      • Learning domain-agnostic features
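
To make the optimization-based family concrete, below is a minimal MAML-style sketch in PyTorch. It is an illustration under stated assumptions, not a definitive implementation: the two-layer network, the single inner gradient step, and the shapes in the initialization are placeholders, and tasks stands in for a batch of episodes sampled the way the dataset code in the Implementation section samples them.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    def functional_forward(params, x):
        # Apply a tiny two-layer network functionally so adapted weights can be swapped in
        w1, b1, w2, b2 = params
        return F.linear(F.relu(F.linear(x, w1, b1)), w2, b2)
    
    def maml_outer_step(params, tasks, meta_optimizer, inner_lr=0.01):
        # tasks: iterable of (support_x, support_y, query_x, query_y) episodes
        meta_loss = 0.0
        for support_x, support_y, query_x, query_y in tasks:
            # Inner loop: one gradient step on the support set
            support_loss = F.cross_entropy(functional_forward(params, support_x), support_y)
            grads = torch.autograd.grad(support_loss, params, create_graph=True)
            adapted = [p - inner_lr * g for p, g in zip(params, grads)]
            # Outer objective: score the adapted weights on the query set
            meta_loss = meta_loss + F.cross_entropy(functional_forward(adapted, query_x), query_y)
        meta_optimizer.zero_grad()
        meta_loss.backward()  # differentiates through the inner update (second-order MAML)
        meta_optimizer.step()
        return meta_loss.item()
    
    # Shared initialization that meta-training tunes (shapes are illustrative: 784 -> 64 -> 5)
    params = [nn.Parameter(0.01 * torch.randn(64, 784)), nn.Parameter(torch.zeros(64)),
              nn.Parameter(0.01 * torch.randn(5, 64)), nn.Parameter(torch.zeros(5))]
    meta_optimizer = torch.optim.Adam(params, lr=1e-3)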

Implementation

  • Prototypical Networks Implementation

    Example implementation of Prototypical Networks for few-shot learning using PyTorch.
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    import numpy as np
    
    class PrototypicalNetwork(nn.Module):
        def __init__(self, input_dim=784, hidden_dim=64, latent_dim=32):
            super(PrototypicalNetwork, self).__init__()
            self.encoder = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, latent_dim)
            )
        
        def forward(self, x):
            return self.encoder(x)
        
        def compute_prototypes(self, support_embeddings):
            # support_embeddings shape: [n_classes, k_shot, latent_dim]
            # A class prototype is the mean of its support embeddings
            return support_embeddings.mean(dim=1)
        
        def compute_distances(self, prototypes, query_samples):
            # prototypes shape: [n_classes, feature_dim]
            # query_samples shape: [n_queries, feature_dim]
            n_classes = prototypes.shape[0]
            n_queries = query_samples.shape[0]
            
            prototypes = prototypes.unsqueeze(0)  # [1, n_classes, feature_dim]
            query_samples = query_samples.unsqueeze(1)  # [n_queries, 1, feature_dim]
            
            # Squared Euclidean distance from each query to each prototype
            return torch.sum((prototypes - query_samples) ** 2, dim=2)
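    
    # Illustrative shape walk-through for compute_distances with 5 classes,
    # 32-dim embeddings, and 75 queries:
    #   prototypes [5, 32] -> [1, 5, 32]; queries [75, 32] -> [75, 1, 32]
    #   broadcast subtraction -> [75, 5, 32]; sum over dim 2 -> distances [75, 5]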
    
    class FewShotDataset(Dataset):
        def __init__(self, data, labels, n_classes, k_shot, n_queries):
            self.data = data
            self.labels = labels
            self.n_classes = n_classes
            self.k_shot = k_shot
            self.n_queries = n_queries
            
            self.samples_by_class = {}
            for i, label in enumerate(self.labels):
                if label not in self.samples_by_class:
                    self.samples_by_class[label] = []
                self.samples_by_class[label].append(self.data[i])
        
        def __len__(self):
            return 1000  # Arbitrary number of episodes
        
        def __getitem__(self, idx):
            selected_classes = np.random.choice(
                list(self.samples_by_class.keys()),
                self.n_classes,
                replace=False
            )
            
            support_samples = []
            query_samples = []
            support_labels = []
            query_labels = []
            
            for i, cls_label in enumerate(selected_classes):
                cls_samples = self.samples_by_class[cls_label]
                
                selected_indices = np.random.choice(
                    len(cls_samples),
                    self.k_shot + self.n_queries,
                    replace=False
                )
                
                support_idx = selected_indices[:self.k_shot]
                query_idx = selected_indices[self.k_shot:]
                
                support_samples.extend([cls_samples[j] for j in support_idx])
                query_samples.extend([cls_samples[j] for j in query_idx])
                # Relabel classes 0..n_classes-1 within the episode
                support_labels.extend([i] * self.k_shot)
                query_labels.extend([i] * self.n_queries)
            
            return (
                torch.tensor(np.array(support_samples, dtype=np.float32)),
                torch.tensor(np.array(query_samples, dtype=np.float32)),
                torch.tensor(support_labels, dtype=torch.long),
                torch.tensor(query_labels, dtype=torch.long)
            )
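    
    # Illustrative episode shapes for the default 5-way, 1-shot, 15-query setup:
    #   support_samples: [5, feature_dim]    query_samples: [75, feature_dim]
    #   support_labels:  [5]                 query_labels:  [75]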
    
    def train_prototypical_network(
        model,
        train_loader,
        optimizer,
        n_episodes,
        device
    ):
        model.train()
        total_loss = 0
        total_accuracy = 0
        
        episode_iter = iter(train_loader)
        for episode_idx in range(n_episodes):
            optimizer.zero_grad()
            
            try:
                batch = next(episode_iter)
            except StopIteration:
                episode_iter = iter(train_loader)  # restart the loader when exhausted
                batch = next(episode_iter)
            support_samples, query_samples, support_labels, query_labels = batch
            
            # The DataLoader (batch_size=1) adds a leading batch dimension; drop it
            support_samples = support_samples.squeeze(0).to(device)
            query_samples = query_samples.squeeze(0).to(device)
            support_labels = support_labels.squeeze(0)
            query_labels = query_labels.squeeze(0).to(device)
            
            support_embeddings = model(support_samples)
            query_embeddings = model(query_samples)
            
            # Reshape support embeddings to [n_classes, k_shot, latent_dim]
            n_classes = len(torch.unique(support_labels))
            k_shot = support_samples.shape[0] // n_classes
            support_embeddings = support_embeddings.view(n_classes, k_shot, -1)
            
            prototypes = model.compute_prototypes(support_embeddings)
            distances = model.compute_distances(prototypes, query_embeddings)
            
            # Closer prototypes get higher probability: softmax over negative distances
            log_probas = F.log_softmax(-distances, dim=1)
            loss = F.nll_loss(log_probas, query_labels)
            
            loss.backward()
            optimizer.step()
            
            predictions = torch.argmin(distances, dim=1)  # nearest prototype wins
            accuracy = torch.mean((predictions == query_labels).float())
            
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            
            if (episode_idx + 1) % 100 == 0:
                print(f"Episode {episode_idx + 1}/{n_episodes}")
                print(f"Average Loss: {total_loss / 100:.4f}")
                print(f"Average Accuracy: {total_accuracy / 100:.4f}")
                total_loss = 0
                total_accuracy = 0
    
    if __name__ == "__main__":
        # Hyperparameters
        n_classes = 5
        k_shot = 1
        n_queries = 15
        n_episodes = 1000
        learning_rate = 0.001
        
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Dummy data for the example: 10 classes, 50 samples each, 28x28 features (MNIST-sized)
        num_total_classes = 10
        samples_per_class = 50
        feature_dim = 28 * 28
        all_data = np.random.rand(num_total_classes * samples_per_class, feature_dim).astype(np.float32)
        all_labels = np.array([i for i in range(num_total_classes) for _ in range(samples_per_class)])
        
        # Model, optimizer, episodic dataset, and loader
        model = PrototypicalNetwork(input_dim=feature_dim).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        dataset = FewShotDataset(all_data, all_labels, n_classes, k_shot, n_queries)
        train_loader = DataLoader(dataset, batch_size=1, shuffle=True)
        
        # Train model
        train_prototypical_network(model, train_loader, optimizer, n_episodes, device)
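    
    For completeness, below is a hedged sketch of how the trained network could be evaluated on episodes drawn from held-out classes. The function name and the idea of building eval_loader from a FewShotDataset over classes never seen during training are illustrative assumptions, not part of the original code.
    
    def evaluate_prototypical_network(model, eval_loader, n_episodes, device):
        model.eval()
        accuracies = []
        episode_iter = iter(eval_loader)
        with torch.no_grad():
            for _ in range(n_episodes):
                try:
                    batch = next(episode_iter)
                except StopIteration:
                    episode_iter = iter(eval_loader)  # restart the loader when exhausted
                    batch = next(episode_iter)
                support, query, support_labels, query_labels = [t.squeeze(0) for t in batch]
                support, query = support.to(device), query.to(device)
                query_labels = query_labels.to(device)
                
                support_emb = model(support)
                query_emb = model(query)
                n_classes = len(torch.unique(support_labels))
                # Group support embeddings by class before averaging into prototypes
                prototypes = model.compute_prototypes(
                    support_emb.view(n_classes, -1, support_emb.shape[-1])
                )
                distances = model.compute_distances(prototypes, query_emb)
                predictions = torch.argmin(distances, dim=1)
                accuracies.append((predictions == query_labels).float().mean().item())
        return float(np.mean(accuracies))
    
    Mean episode accuracy, commonly reported with a 95% confidence interval over a few hundred episodes, is the usual evaluation protocol in the few-shot literature.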
    

Interview Examples

What is few-shot learning and how does it differ from traditional learning?

Explain the concept of few-shot learning and its key differences from traditional machine learning approaches.

Practice Questions

1. What are the practical applications of few-shot learning? (Medium)

Hint: Consider both academic and industry use cases

2. How would you implement few-shot learning in a production environment? (Hard)

Hint: Consider scalability and efficiency

3. Explain the core concepts of few-shot learning. (Easy)

Hint: Think about the fundamental principles