Example implementation of transfer learning using a pre-trained ResNet model for a new classification task.
import torch
import torch.nn as nn
import torchvision.models as models
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
class TransferLearningModel:
def __init__(self, num_classes, model_name='resnet50', feature_extract=True):
        if model_name == 'resnet50':
            # Load ImageNet-pretrained weights (the 'weights' API from torchvision >= 0.13).
            self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        elif model_name == 'resnet18':
            self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        else:
            raise ValueError(f"Unsupported model: {model_name}")
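        # With feature_extract=True, freeze every pretrained parameter so only the new head below is trained.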
if feature_extract:
for param in self.model.parameters():
param.requires_grad = False
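        # Swap the original 1000-class ImageNet head for a small classifier sized for the new task.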
num_features = self.model.fc.in_features
self.model.fc = nn.Sequential(
nn.Linear(num_features, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, num_classes)
)
def get_trainable_params(self):
return [p for p in self.model.parameters() if p.requires_grad]
class TransferLearningTrainer:
def __init__(self, model, device, learning_rate=0.001):
self.model = model.to(device)
self.device = device
self.criterion = nn.CrossEntropyLoss()
        # Only optimize parameters that are still trainable (just the new head while the backbone is frozen).
        self.optimizer = Adam([p for p in self.model.parameters() if p.requires_grad], lr=learning_rate)
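        # Reduce the learning rate when the validation loss stops improving.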
self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', patience=3)
def train_epoch(self, train_loader):
self.model.train()
total_loss = 0
correct_preds = 0
total_samples = 0
for inputs, labels in train_loader:
inputs, labels = inputs.to(self.device), labels.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer.step()
total_loss += loss.item()
_, predicted = outputs.max(1)
total_samples += labels.size(0)
correct_preds += predicted.eq(labels).sum().item()
return total_loss / len(train_loader), 100. * correct_preds / total_samples
def validate(self, val_loader):
self.model.eval()
total_loss = 0
correct_preds = 0
total_samples = 0
with torch.no_grad():
for inputs, labels in val_loader:
inputs, labels = inputs.to(self.device), labels.to(self.device)
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
total_loss += loss.item()
_, predicted = outputs.max(1)
total_samples += labels.size(0)
correct_preds += predicted.eq(labels).sum().item()
avg_loss = total_loss / len(val_loader)
accuracy = 100. * correct_preds / total_samples
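        # Step the plateau scheduler on validation loss so the LR drops after 'patience' epochs without improvement.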
self.scheduler.step(avg_loss)
return avg_loss, accuracy
def progressive_unfreezing(model, num_epochs_per_unfreeze=5):
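    """Build a schedule for gradually unfreezing the pretrained ResNet blocks.

    Returns a list of dicts, each giving the epoch at which a group of
    parameters should have requires_grad re-enabled, starting from the
    deepest residual stage and working backwards.
    """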
layers_to_unfreeze = []
for name, param in model.named_parameters():
if 'fc' not in name:
param.requires_grad = False
layers_to_unfreeze.append(name)
    layer_groups = {}
    for name in layers_to_unfreeze:
        block_candidate = name.split('.')[0]
        # Each residual stage (layer1-layer4) forms its own group; the stem
        # (conv1, bn1) is collected under 'general_pretrained'.
        if block_candidate.startswith('layer'):
            layer_groups.setdefault(block_candidate, []).append(name)
        else:
            layer_groups.setdefault('general_pretrained', []).append(name)
unfreeze_schedule = []
block_order = sorted(list(layer_groups.keys()), key=lambda x: (x != 'general_pretrained', x))
current_epoch_offset = 0
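    # Schedule the deepest (most task-specific) blocks first; the stem is unfrozen last.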
for block_name in reversed(block_order):
unfreeze_schedule.append({
'epoch_to_activate': current_epoch_offset,
'layers_to_enable_grad': layer_groups[block_name]
})
current_epoch_offset += num_epochs_per_unfreeze
return unfreeze_schedule
if __name__ == "__main__":
current_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes_main = 10
tl_model = TransferLearningModel(num_classes_main, model_name='resnet18', feature_extract=True)
tl_trainer = TransferLearningTrainer(tl_model.model, current_device, learning_rate=0.001)
unfreeze_plan = progressive_unfreezing(tl_model.model, num_epochs_per_unfreeze=3)
print("Progressive Unfreezing Plan:", unfreeze_plan)
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, num_samples=100, num_features=3*224*224, num_classes=10):
self.num_samples = num_samples
self.data = torch.randn(num_samples, num_features)
if num_features == 3*224*224:
self.data = self.data.reshape(num_samples, 3, 224, 224)
self.labels = torch.randint(0, num_classes, (num_samples,))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
    # The torchvision ResNets expect 3-channel images; the dummy data below uses the standard 3x224x224 ImageNet shape.
dummy_train_loader = torch.utils.data.DataLoader(DummyDataset(num_classes=num_classes_main), batch_size=4)
dummy_val_loader = torch.utils.data.DataLoader(DummyDataset(num_classes=num_classes_main), batch_size=4)
num_total_epochs = 15
# Training loop would go here
# for current_epoch in range(num_total_epochs):
# print(f"Epoch {current_epoch + 1}/{num_total_epochs}")
# for schedule_item in unfreeze_plan:
# if current_epoch == schedule_item['epoch_to_activate']:
# print(f"Unfreezing layers: {schedule_item['layers_to_enable_grad']}")
# for layer_name_to_unfreeze in schedule_item['layers_to_enable_grad']:
# for name, param in tl_model.model.named_parameters():
# if name == layer_name_to_unfreeze:
# param.requires_grad = True
    # tl_trainer.optimizer = Adam(tl_model.get_trainable_params(), lr=0.0001)
    # tl_trainer.scheduler = ReduceLROnPlateau(tl_trainer.optimizer, mode='min', patience=3)
# print("Optimizer updated with new trainable parameters.")
# train_loss, train_acc = tl_trainer.train_epoch(dummy_train_loader)
# val_loss, val_acc = tl_trainer.validate(dummy_val_loader)
# print(f" Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
# print(f" Val Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")