"""Example implementation of Prototypical Networks for few-shot learning using PyTorch."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
class PrototypicalNetwork(nn.Module):
    """Embedding network for Prototypical Networks (Snell et al., 2017).

    Maps flat input vectors into a latent space where each class is
    represented by the mean (prototype) of its support embeddings and
    queries are classified by squared Euclidean distance to prototypes.
    """

    def __init__(self, input_dim=784, hidden_dim=64, latent_dim=32):
        super().__init__()
        # Three-layer MLP encoder: input -> hidden -> hidden -> latent.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x):
        """Embed a batch: [batch, input_dim] -> [batch, latent_dim]."""
        return self.encoder(x)

    def compute_prototypes(self, support_set, k_shot):
        """Average each class's k-shot embeddings into one prototype.

        support_set: [n_classes, k_shot, feature_dim] -> [n_classes, feature_dim].
        (`k_shot` is retained for interface compatibility; the mean over
        dim 1 already covers it.)
        """
        return support_set.mean(dim=1)

    def compute_distances(self, prototypes, query_samples):
        """Squared Euclidean distance from every query to every prototype.

        prototypes: [n_classes, feature_dim]
        query_samples: [n_queries, feature_dim]
        Returns: [n_queries, n_classes].
        """
        # Broadcast [n_queries, 1, f] against [1, n_classes, f].
        diff = query_samples.unsqueeze(1) - prototypes.unsqueeze(0)
        return (diff ** 2).sum(dim=-1)
class FewShotDataset(Dataset):
    """Episodic sampler for few-shot training.

    Every __getitem__ call draws a fresh N-way episode: `n_classes`
    classes are picked without replacement, then `k_shot` support and
    `n_queries` query examples are drawn from each. Labels are re-indexed
    to 0..n_classes-1 within the episode.
    """

    def __init__(self, data, labels, n_classes, k_shot, n_queries):
        self.data = data
        self.labels = labels
        self.n_classes = n_classes
        self.k_shot = k_shot
        self.n_queries = n_queries
        # Bucket samples by their original label for fast per-class draws.
        self.samples_by_class = {}
        for sample, label in zip(self.data, self.labels):
            self.samples_by_class.setdefault(label, []).append(sample)

    def __len__(self):
        # Episodes are generated on the fly; this just bounds an "epoch".
        return 1000

    def __getitem__(self, idx):
        # `idx` is ignored on purpose: every call is an independent episode.
        episode_classes = np.random.choice(
            list(self.samples_by_class.keys()),
            self.n_classes,
            replace=False,
        )
        support, query = [], []
        support_labels, query_labels = [], []
        for new_label, cls in enumerate(episode_classes):
            pool = self.samples_by_class[cls]
            # Draw k_shot + n_queries distinct samples, then split them.
            picks = np.random.choice(
                len(pool),
                self.k_shot + self.n_queries,
                replace=False,
            )
            support.extend(pool[j] for j in picks[:self.k_shot])
            query.extend(pool[j] for j in picks[self.k_shot:])
            support_labels.extend([new_label] * self.k_shot)
            query_labels.extend([new_label] * self.n_queries)
        return (
            torch.tensor(np.array(support, dtype=np.float32)),
            torch.tensor(np.array(query, dtype=np.float32)),
            torch.tensor(support_labels, dtype=torch.long),
            torch.tensor(query_labels, dtype=torch.long),
        )
def train_prototypical_network(
    model,
    train_loader,
    optimizer,
    n_episodes,
    device,
):
    """Train `model` with episodic prototypical-network updates.

    Args:
        model: module exposing forward / compute_prototypes / compute_distances.
        train_loader: iterable yielding (support, query, support_labels,
            query_labels) episodes, optionally with a leading singleton
            batch dimension as produced by DataLoader(batch_size=1).
        optimizer: optimizer over model.parameters().
        n_episodes: number of episodes to train for.
        device: torch.device for model inputs.

    Prints running averages of loss and accuracy every 100 episodes.
    """
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    # Create the loader iterator ONCE and restart it only when exhausted.
    # The original called next(iter(train_loader)) every episode, which
    # rebuilds the iterator each time and re-fetches from the start for
    # any non-shuffling loader.
    episode_iter = iter(train_loader)
    for episode_idx in range(n_episodes):
        try:
            batch = next(episode_iter)
        except StopIteration:
            episode_iter = iter(train_loader)
            batch = next(episode_iter)
        support_samples, query_samples, support_labels, query_labels = batch
        # DataLoader(batch_size=1) prepends a singleton batch dimension.
        # Drop it so shape[0] is the true episode size; the original kept
        # it, so k_shot = 1 // n_classes == 0 and .view(n, 0, -1) crashed.
        if support_samples.dim() == 3:
            support_samples = support_samples.squeeze(0)
            query_samples = query_samples.squeeze(0)
            support_labels = support_labels.squeeze(0)
            query_labels = query_labels.squeeze(0)
        support_samples = support_samples.to(device)
        query_samples = query_samples.to(device)
        query_labels = query_labels.to(device)

        optimizer.zero_grad()
        support_embeddings = model(support_samples)
        query_embeddings = model(query_samples)

        # Recover the episode layout from the labels, then group the
        # support embeddings as [n_classes, k_shot, feature_dim].
        n_classes = len(torch.unique(support_labels))
        k_shot = support_samples.shape[0] // n_classes
        support_embeddings = support_embeddings.view(n_classes, k_shot, -1)
        prototypes = model.compute_prototypes(support_embeddings, k_shot)
        distances = model.compute_distances(prototypes, query_embeddings)

        # Nearest prototype <=> highest log-probability under -distance.
        log_probas = F.log_softmax(-distances, dim=1)
        loss = F.nll_loss(log_probas, query_labels)
        loss.backward()
        optimizer.step()

        predictions = torch.argmin(distances, dim=1)
        accuracy = (predictions == query_labels).float().mean()
        total_loss += loss.item()
        total_accuracy += accuracy.item()
        if (episode_idx + 1) % 100 == 0:
            print(f"Episode {episode_idx + 1}/{n_episodes}")
            print(f"Average Loss: {total_loss / 100:.4f}")
            print(f"Average Accuracy: {total_accuracy / 100:.4f}")
            total_loss = 0.0
            total_accuracy = 0.0
if __name__ == "__main__":
    # ----- Experiment configuration -----
    n_classes_main = 5        # N-way
    k_shot_main = 1           # K-shot
    n_queries_main = 15       # query samples per class per episode
    n_episodes_main = 1000
    learning_rate_main = 0.001

    current_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----- Synthetic stand-in data: 10 classes x 50 samples, 28*28 features -----
    num_total_classes = 10
    samples_per_class = 50
    feature_dim = 28 * 28
    all_data = np.random.rand(
        num_total_classes * samples_per_class, feature_dim
    ).astype(np.float32)
    all_labels = np.repeat(np.arange(num_total_classes), samples_per_class)

    # ----- Model, optimizer, and episodic data pipeline -----
    proto_model = PrototypicalNetwork(input_dim=feature_dim).to(current_device)
    adam_optimizer = torch.optim.Adam(proto_model.parameters(), lr=learning_rate_main)
    fsl_dataset = FewShotDataset(
        all_data, all_labels, n_classes_main, k_shot_main, n_queries_main
    )
    fsl_train_loader = DataLoader(fsl_dataset, batch_size=1, shuffle=True)

    train_prototypical_network(
        proto_model, fsl_train_loader, adam_optimizer, n_episodes_main, current_device
    )