Example implementation of transfer learning in PyTorch: a pre-trained ResNet backbone is adapted to a new classification task by replacing its classification head, optionally freezing the backbone for feature extraction, and progressively unfreezing layer groups during fine-tuning.

import torch
import torch.nn as nn
import torchvision.models as models
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
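

# Wraps a torchvision ResNet backbone and replaces its classification head with
# a small MLP sized for the new task. With feature_extract=True the pre-trained
# backbone weights are frozen so only the new head receives gradient updates.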
class TransferLearningModel:
    def __init__(self, num_classes, model_name='resnet50', feature_extract=True):
        # torchvision >= 0.13: use the `weights` argument instead of the
        # deprecated `pretrained=True` flag.
        if model_name == 'resnet50':
            self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        elif model_name == 'resnet18':
            self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        else:
            raise ValueError(f"Unsupported model: {model_name}")
        
        if feature_extract:
            for param in self.model.parameters():
                param.requires_grad = False
        
        num_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
    
    def get_trainable_params(self):
        return [p for p in self.model.parameters() if p.requires_grad]
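

# Runs the optimization loop for the transfer model: cross-entropy loss, Adam
# over only the trainable parameters, and a ReduceLROnPlateau scheduler driven
# by the validation loss.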
class TransferLearningTrainer:
    def __init__(self, model, device, learning_rate=0.001):
        self.model = model.to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        # The trainer receives a plain nn.Module, so collect the trainable
        # parameters directly instead of assuming a wrapper method on the model.
        self.optimizer = Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=learning_rate
        )
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', patience=3)
    
    def train_epoch(self, train_loader):
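        """Train for one epoch; return (average loss, accuracy in percent)."""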
        self.model.train()
        total_loss = 0
        correct_preds = 0
        total_samples = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total_samples += labels.size(0)
            correct_preds += predicted.eq(labels).sum().item()
        
        return total_loss / len(train_loader), 100. * correct_preds / total_samples
    
    def validate(self, val_loader):
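        """Evaluate without gradients, step the LR scheduler on the average
        validation loss, and return (average loss, accuracy in percent)."""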
        self.model.eval()
        total_loss = 0
        correct_preds = 0
        total_samples = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total_samples += labels.size(0)
                correct_preds += predicted.eq(labels).sum().item()
        
        avg_loss = total_loss / len(val_loader)
        accuracy = 100. * correct_preds / total_samples
        self.scheduler.step(avg_loss)
        return avg_loss, accuracy
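

# Builds a progressive-unfreezing schedule: backbone parameters are grouped by
# top-level block and scheduled to be unfrozen one group at a time, starting
# with the blocks closest to the classifier head.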
def progressive_unfreezing(model, num_epochs_per_unfreeze=5):
    # Freeze every backbone parameter (everything outside the new 'fc' head)
    # and record its name so it can be unfrozen later on a schedule.
    layers_to_unfreeze = []
    for name, param in model.named_parameters():
        if 'fc' not in name:
            param.requires_grad = False
            layers_to_unfreeze.append(name)
    
    # Group parameter names by their top-level module (e.g. 'layer4', 'conv1').
    # Top-level names without a digit fall into a generic catch-all group.
    layer_groups = {}
    for name in layers_to_unfreeze:
        block_candidate = name.split('.')[0]
        if any(x.isdigit() for x in block_candidate):
            if block_candidate not in layer_groups:
                layer_groups[block_candidate] = []
            layer_groups[block_candidate].append(name)
        else:
            if 'general_pretrained' not in layer_groups:
                layer_groups['general_pretrained'] = []
            layer_groups['general_pretrained'].append(name)
            
    # Build the schedule from the deepest block backwards: groups closest to the
    # classifier head are unfrozen first, earlier blocks follow in later epochs.
    unfreeze_schedule = []
    block_order = sorted(layer_groups.keys(), key=lambda x: (x != 'general_pretrained', x))
    current_epoch_offset = 0
    for block_name in reversed(block_order):
        unfreeze_schedule.append({
            'epoch_to_activate': current_epoch_offset,
            'layers_to_enable_grad': layer_groups[block_name]
        })
        current_epoch_offset += num_epochs_per_unfreeze
    
    return unfreeze_schedule
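

# Demo: build a ResNet-18 transfer model, derive an unfreezing plan, and set up
# dummy loaders so the (commented-out) training loop below can be exercised.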
if __name__ == "__main__":
    current_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes_main = 10
    
    tl_model = TransferLearningModel(num_classes_main, model_name='resnet18', feature_extract=True)
    tl_trainer = TransferLearningTrainer(tl_model.model, current_device, learning_rate=0.001)
    
    unfreeze_plan = progressive_unfreezing(tl_model.model, num_epochs_per_unfreeze=3)
    print("Progressive Unfreezing Plan:", unfreeze_plan)
    class DummyDataset(torch.utils.data.Dataset):
        def __init__(self, num_samples=100, num_classes=10):
            self.num_samples = num_samples
            # Random image-sized tensors matching the ResNet input resolution.
            self.data = torch.randn(num_samples, 3, 224, 224)
            self.labels = torch.randint(0, num_classes, (num_samples,))
        def __len__(self):
            return self.num_samples
        def __getitem__(self, idx):
            return self.data[idx], self.labels[idx]
    dummy_train_loader = torch.utils.data.DataLoader(DummyDataset(num_classes=num_classes_main), batch_size=4)
    dummy_val_loader = torch.utils.data.DataLoader(DummyDataset(num_classes=num_classes_main), batch_size=4)
    num_total_epochs = 15
    
    # Training loop would go here
    # for current_epoch in range(num_total_epochs):
    #     print(f"Epoch {current_epoch + 1}/{num_total_epochs}")
        
    #     for schedule_item in unfreeze_plan:
    #         if current_epoch == schedule_item['epoch_to_activate']:
    #             print(f"Unfreezing layers: {schedule_item['layers_to_enable_grad']}")
    #             for layer_name_to_unfreeze in schedule_item['layers_to_enable_grad']:
    #                 for name, param in tl_model.model.named_parameters():
    #                     if name == layer_name_to_unfreeze:
    #                         param.requires_grad = True
    #             tl_trainer.optimizer = Adam(tl_model.get_trainable_params(), lr=0.0001)
    #             print("Optimizer updated with new trainable parameters.")
    #     train_loss, train_acc = tl_trainer.train_epoch(dummy_train_loader)
    #     val_loss, val_acc = tl_trainer.validate(dummy_val_loader)
        
    #     print(f"  Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
    #     print(f"  Val Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")