Transfer Learning

Overview

Transfer Learning is a machine learning technique where a model developed for one task is reused as the starting point for a model on a second task. It is particularly popular in deep learning, where pre-trained models are used as feature extractors or fine-tuned for new tasks, saving significant compute and reducing the amount of labeled data required.

This approach is especially powerful when the target task has limited training data but shares some commonality with the source task where the model was originally trained.
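
In PyTorch, the difference between using a pre-trained network as a frozen feature extractor and fine-tuning it comes down to which parameters receive gradient updates. The sketch below is illustrative only: `build_model` and `count_trainable` are hypothetical helpers, and the small randomly initialized `backbone` stands in for a network whose weights would normally be loaded from a pre-trained checkpoint.

```python
import torch.nn as nn


def build_model(num_classes: int, feature_extract: bool) -> nn.Module:
    # Hypothetical stand-in for a pre-trained backbone; real code would load saved weights.
    backbone = nn.Sequential(
        nn.Linear(32, 64), nn.ReLU(),
        nn.Linear(64, 64), nn.ReLU(),
    )
    if feature_extract:
        # Feature extraction: backbone weights are frozen and act as a fixed encoder.
        for param in backbone.parameters():
            param.requires_grad = False
    # The new task-specific head is always trainable.
    head = nn.Linear(64, num_classes)
    return nn.Sequential(backbone, head)


def count_trainable(model: nn.Module) -> int:
    # Number of scalar weights the optimizer would update.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


frozen = build_model(num_classes=10, feature_extract=True)
finetuned = build_model(num_classes=10, feature_extract=False)
print(count_trainable(frozen), count_trainable(finetuned))
```

With feature extraction only the 650 head parameters (64 × 10 weights plus 10 biases) are updated; fine-tuning updates all 6,922.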

Core Concepts

  • Types of Transfer Learning

    • Feature Extraction: Using a pre-trained model as a fixed feature extractor by removing the last few layers and adding new ones for the target task.
    • Fine-tuning: Continuing training of a pre-trained model on a new task, often with a smaller learning rate to preserve useful features.
    • Domain Adaptation: Adapting a model trained on one domain to work well on a different but related domain.
  • Common Strategies

    1. Feature Extraction:

    • Remove the last layer(s) of a pre-trained network
    • Use the remaining network as a fixed feature extractor
    • Train only a new classifier on these features
    • Best when the target task is similar to the original task and the target dataset is small

    2. Fine-tuning:

    • Start with a pre-trained model
    • Replace the classifier and train it on the new task
    • Optionally fine-tune some or all of the pre-trained layers
    • Use a smaller learning rate to preserve learned features

    3. Progressive Fine-tuning:

    • Gradually unfreeze and fine-tune layers from the top (output side) down to the bottom (input side)
    • Use discriminative learning rates (higher for new layers, lower for pre-trained ones)
    • Helps reduce catastrophic forgetting
  • Tips for Transfer Learning

    • Model Selection:
      • Choose pre-trained models from similar domains
      • Larger models pre-trained on more data tend to learn more generalizable features
      • Consider computational requirements for your deployment scenario
    • Fine-tuning Strategy:
      • For small datasets: freeze most layers
      • For larger datasets: fine-tune more layers
      • Use smaller learning rates for pre-trained layers
    • Data Considerations:
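      • Match preprocessing (input resolution, normalization statistics) to what the pre-trained model saw during training
      • Augment the target dataset with label-preserving transformations suited to its domain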
      • Address domain shift if source and target domains differ
  • Common Challenges

    • Negative Transfer:
      • Knowledge from the source task hurts performance on the target task
      • Often occurs when the source and target domains or tasks are too dissimilar
      • Monitor validation performance closely, e.g. against a model trained from scratch
    • Catastrophic Forgetting:
      • Fine-tuning overwrites useful pre-trained representations
      • More likely with high learning rates
      • Gradual unfreezing and discriminative learning rates help mitigate it
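
For the "Match preprocessing" tip above, an image backbone pre-trained on ImageNet expects inputs normalized with the ImageNet channel statistics used by torchvision's classification models (mean 0.485, 0.456, 0.406; std 0.229, 0.224, 0.225). A minimal sketch, assuming RGB tensors already scaled to [0, 1]; `normalize_like_imagenet` is a hypothetical helper. In practice, torchvision's weight enums also expose a `transforms()` method that bundles the matching resize and normalization.

```python
import torch

# Per-channel statistics used by torchvision's ImageNet-pretrained classifiers.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def normalize_like_imagenet(images: torch.Tensor) -> torch.Tensor:
    """Normalize RGB images in [0, 1] with shape (3, H, W) or (N, 3, H, W)."""
    if images.dim() not in (3, 4) or images.shape[-3] != 3:
        raise ValueError("expected shape (3, H, W) or (N, 3, H, W)")
    # (3, 1, 1) statistics broadcast over height, width and an optional batch dimension.
    return (images - IMAGENET_MEAN) / IMAGENET_STD


batch = torch.rand(2, 3, 224, 224)  # placeholder images, not real data
normalized = normalize_like_imagenet(batch)
print(normalized.shape)
```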

Implementation

  • Transfer Learning with PyTorch

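    Example implementation of transfer learning with a pre-trained ResNet on a new classification task, combining feature extraction with progressive unfreezing. The pre-trained ImageNet weights are downloaded on first use.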
    
    import torch
    import torch.nn as nn
    import torchvision.models as models
    from torch.optim import Adam
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    from torch.utils.data import DataLoader, Dataset

    class TransferLearningModel:
        """Loads a pre-trained ResNet and replaces its classifier head."""

        def __init__(self, num_classes, model_name='resnet50', feature_extract=True):
            # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13).
            if model_name == 'resnet50':
                self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            elif model_name == 'resnet18':
                self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
            else:
                raise ValueError(f"Unsupported model: {model_name}")

            # Feature extraction: freeze all pre-trained parameters.
            if feature_extract:
                for param in self.model.parameters():
                    param.requires_grad = False

            # New classifier head; freshly created layers are trainable by default.
            num_features = self.model.fc.in_features
            self.model.fc = nn.Sequential(
                nn.Linear(num_features, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, num_classes)
            )

        def get_trainable_params(self):
            return [p for p in self.model.parameters() if p.requires_grad]

    class TransferLearningTrainer:
        """Trains an nn.Module, optimizing only parameters with requires_grad=True."""

        def __init__(self, model, device, learning_rate=0.001):
            self.model = model.to(device)
            self.device = device
            self.criterion = nn.CrossEntropyLoss()
            self.reset_optimizer(learning_rate)

        def reset_optimizer(self, learning_rate):
            # `model` is a plain nn.Module (it has no get_trainable_params), so filter here.
            # Call again after unfreezing layers so the optimizer tracks them.
            trainable = [p for p in self.model.parameters() if p.requires_grad]
            self.optimizer = Adam(trainable, lr=learning_rate)
            # The scheduler holds a reference to the optimizer, so rebuild it as well.
            self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', patience=3)

        def train_epoch(self, train_loader):
            # Note: train mode also updates the running statistics of frozen BatchNorm layers.
            self.model.train()
            total_loss = 0
            correct_preds = 0
            total_samples = 0

            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total_samples += labels.size(0)
                correct_preds += predicted.eq(labels).sum().item()

            # Mean per-batch loss and accuracy in percent.
            return total_loss / len(train_loader), 100. * correct_preds / total_samples

        def validate(self, val_loader):
            self.model.eval()
            total_loss = 0
            correct_preds = 0
            total_samples = 0

            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, labels)

                    total_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total_samples += labels.size(0)
                    correct_preds += predicted.eq(labels).sum().item()

            avg_loss = total_loss / len(val_loader)
            accuracy = 100. * correct_preds / total_samples
            # Lower the learning rate when the validation loss plateaus.
            self.scheduler.step(avg_loss)
            return avg_loss, accuracy

    def progressive_unfreezing(model, num_epochs_per_unfreeze=5, head_name='fc'):
        """Freeze the backbone and return a schedule that unfreezes it block by block.

        Parameters are grouped by top-level child module (for ResNet: conv1, bn1,
        layer1 ... layer4). Blocks are unfrozen from the output side toward the input,
        and the head trains alone for the first `num_epochs_per_unfreeze` epochs.
        """
        layer_groups = {}  # insertion order follows named_parameters(): input -> output
        for name, param in model.named_parameters():
            block_name = name.split('.')[0]
            if block_name == head_name:
                continue  # the new classifier head stays trainable
            param.requires_grad = False
            layer_groups.setdefault(block_name, []).append(name)

        unfreeze_schedule = []
        for stage, block_name in enumerate(reversed(list(layer_groups)), start=1):
            unfreeze_schedule.append({
                'block': block_name,
                'epoch_to_activate': stage * num_epochs_per_unfreeze,
                'layers_to_enable_grad': layer_groups[block_name],
            })
        return unfreeze_schedule

    if __name__ == "__main__":
        current_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        num_classes_main = 10

        tl_model = TransferLearningModel(num_classes_main, model_name='resnet18', feature_extract=True)
        tl_trainer = TransferLearningTrainer(tl_model.model, current_device, learning_rate=0.001)

        unfreeze_plan = progressive_unfreezing(tl_model.model, num_epochs_per_unfreeze=3)
        print("Progressive unfreezing plan:")
        for item in unfreeze_plan:
            print(f"  epoch {item['epoch_to_activate']}: unfreeze {item['block']}")

        class DummyDataset(Dataset):
            """Random 3x224x224 images with random labels, only for smoke-testing the loop."""

            def __init__(self, num_samples=100, num_classes=10):
                self.data = torch.randn(num_samples, 3, 224, 224)
                self.labels = torch.randint(0, num_classes, (num_samples,))

            def __len__(self):
                return len(self.labels)

            def __getitem__(self, idx):
                return self.data[idx], self.labels[idx]

        dummy_train_loader = DataLoader(DummyDataset(num_classes=num_classes_main), batch_size=4, shuffle=True)
        dummy_val_loader = DataLoader(DummyDataset(num_classes=num_classes_main), batch_size=4)

        num_total_epochs = 15
        params_by_name = dict(tl_model.model.named_parameters())

        for current_epoch in range(num_total_epochs):
            print(f"Epoch {current_epoch + 1}/{num_total_epochs}")

            # Stages scheduled at epoch >= num_total_epochs never fire; with 3 epochs
            # per stage, conv1 and bn1 stay frozen in this run.
            for schedule_item in unfreeze_plan:
                if current_epoch == schedule_item['epoch_to_activate']:
                    print(f"  Unfreezing {schedule_item['block']}")
                    for layer_name in schedule_item['layers_to_enable_grad']:
                        params_by_name[layer_name].requires_grad = True
                    # Rebuild optimizer and scheduler so the newly trainable parameters are
                    # updated, using a lower learning rate to protect pre-trained features.
                    tl_trainer.reset_optimizer(1e-4)

            train_loss, train_acc = tl_trainer.train_epoch(dummy_train_loader)
            val_loss, val_acc = tl_trainer.validate(dummy_val_loader)

            print(f"  Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
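
The training loop above gives every unfrozen layer the same learning rate, whereas the Common Strategies list recommends discriminative learning rates. One way to sketch that in PyTorch is with optimizer parameter groups whose learning rate shrinks with each block's distance from the head. The helper `discriminative_param_groups`, the geometric `decay` factor of 0.5, and `ToyNet` (a stand-in that mimics ResNet's `layer1`/`layer2`/`fc` child names) are illustrative assumptions, not values taken from the code above.

```python
import torch
import torch.nn as nn
from torch.optim import Adam


def discriminative_param_groups(model, block_names, head_lr=1e-3, decay=0.5):
    """Return optimizer param groups, one per child module.

    `block_names` lists child modules from input to output, ending with the head.
    The head gets `head_lr`; each earlier block gets `decay` times the rate of the
    block after it. Frozen parameters are left out.
    """
    groups = []
    lr = head_lr
    for name in reversed(block_names):
        params = [p for p in getattr(model, name).parameters() if p.requires_grad]
        if params:
            groups.append({'params': params, 'lr': lr})
        lr *= decay
    return groups


class ToyNet(nn.Module):
    # Hypothetical stand-in with ResNet-like child module names.
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 16)
        self.layer2 = nn.Linear(16, 16)
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        return self.fc(torch.relu(self.layer2(torch.relu(self.layer1(x)))))


model = ToyNet()
optimizer = Adam(discriminative_param_groups(model, ['layer1', 'layer2', 'fc']))
print([group['lr'] for group in optimizer.param_groups])
```

For the torchvision ResNet in this section, the block names in input-to-output order would be `['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']`. Because frozen blocks drop out of the groups, the helper can be called again each time a stage is unfrozen to rebuild the optimizer.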
    

Interview Examples

What is transfer learning and when should it be used?

Explain the concept of transfer learning and when it's appropriate to apply it in deep learning projects.

Practice Questions

1. How would you implement Transfer Learning in a production environment? (Hard)

Hint: Consider scalability and efficiency

2. What are the practical applications of Transfer Learning? (Medium)

Hint: Consider both academic and industry use cases

3. Explain the core concepts of Transfer Learning. (Easy)

Hint: Think about the fundamental principles