Example implementation of Model-Agnostic Meta-Learning (MAML) in PyTorch. The inner loop adapts a copy of the model's parameters to each task on its support set; the outer loop updates the shared initialization by backpropagating the query-set losses of the adapted models through the adaptation steps.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict

class MAMLModel(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=64, output_dim=10):
        super(MAMLModel, self).__init__()
        self.net = nn.Sequential(OrderedDict([
            ('layer1', nn.Linear(input_dim, hidden_dim)),
            ('relu1', nn.ReLU()),
            ('layer2', nn.Linear(hidden_dim, hidden_dim)),
            ('relu2', nn.ReLU()),
            ('layer3', nn.Linear(hidden_dim, output_dim))
        ]))

    def forward(self, x):
        return self.net(x)

    def clone_params(self):
        # The clones stay connected to the original parameters in the autograd
        # graph, so the meta-loss can backpropagate through the adaptation
        # back into the model's own parameters.
        return OrderedDict(
            (name, param.clone()) for name, param in self.named_parameters()
        )

    def load_params(self, params):
        # Copy a set of (adapted) parameter values back into the module in place.
        own_params = dict(self.named_parameters())
        for name, param_val in params.items():
            if name in own_params:
                own_params[name].data.copy_(param_val.data)


class MAML:
    def __init__(self, model, inner_lr=0.01, meta_lr=0.001, num_inner_steps=1):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)
        self.num_inner_steps = num_inner_steps

    def inner_loop(self, support_x, support_y, params=None):
        # Task-specific adaptation: a few gradient-descent steps on the support set.
        adapted_params = self.model.clone_params() if params is None else params
        for _ in range(self.num_inner_steps):
            prediction = self.forward_with_params(support_x, adapted_params)
            loss = F.cross_entropy(prediction, support_y)
            # create_graph=True retains the inner-loop graph so the meta-update
            # can differentiate through the adaptation (second-order MAML).
            grads = torch.autograd.grad(loss, list(adapted_params.values()), create_graph=True)
            adapted_params = OrderedDict(
                (name, param - self.inner_lr * grad)
                for (name, param), grad in zip(adapted_params.items(), grads)
            )
        return adapted_params

    def forward_with_params(self, x_input, params_dict):
        # Functional forward pass through self.model.net using an explicit
        # parameter dictionary. Keys follow named_parameters() on the model,
        # e.g. 'net.layer1.weight', so they must include the 'net.' prefix.
        x = x_input
        for name, module_layer in self.model.net.named_children():
            if isinstance(module_layer, nn.Linear):
                weight = params_dict[f'net.{name}.weight']
                bias = params_dict[f'net.{name}.bias']
                x = F.linear(x, weight, bias)
            else:
                # Parameter-free layers (e.g. ReLU) are applied directly.
                x = module_layer(x)
        return x

    def meta_train(self, task_batch):
        # One meta-update: adapt to each task on its support set, evaluate the
        # adapted parameters on its query set, and step on the averaged query loss.
        self.meta_optimizer.zero_grad()
        meta_loss_accumulator = 0.0
        for task_data in task_batch:
            support_x, support_y = task_data['support']
            query_x, query_y = task_data['query']
            adapted_params = self.inner_loop(support_x, support_y)
            query_prediction = self.forward_with_params(query_x, adapted_params)
            meta_loss_accumulator += F.cross_entropy(query_prediction, query_y)
        average_meta_loss = meta_loss_accumulator / len(task_batch)
        average_meta_loss.backward()
        self.meta_optimizer.step()
        return average_meta_loss.item()
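

# Possible extension, not part of the original example: a first-order MAML
# (FOMAML) variant. If the inner-loop gradients are computed without
# create_graph=True, they carry no graph of their own, so the query loss
# reaches the initial parameters only through the `param` term of
# `param - inner_lr * grad` -- exactly the first-order approximation, which
# avoids second-order derivatives. The class name `FirstOrderMAML` is
# illustrative. A minimal sketch:
class FirstOrderMAML(MAML):
    def inner_loop(self, support_x, support_y, params=None):
        adapted_params = self.model.clone_params() if params is None else params
        for _ in range(self.num_inner_steps):
            prediction = self.forward_with_params(support_x, adapted_params)
            loss = F.cross_entropy(prediction, support_y)
            # No create_graph: the returned gradients are treated as constants.
            grads = torch.autograd.grad(loss, list(adapted_params.values()))
            adapted_params = OrderedDict(
                (name, param - self.inner_lr * grad)
                for (name, param), grad in zip(adapted_params.items(), grads)
            )
        return adapted_params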


class TaskGenerator:
    def __init__(self, n_way, k_shot, q_queries):
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_queries = q_queries

    def sample_task(self, all_data, all_labels):
        # Build one N-way, K-shot task: pick n_way classes, then k_shot support
        # and q_queries query examples per class, relabelled 0..n_way-1.
        unique_labels = np.unique(all_labels)
        selected_class_labels = np.random.choice(unique_labels, self.n_way, replace=False)
        support_x_list, support_y_list = [], []
        query_x_list, query_y_list = [], []
        for i, class_label in enumerate(selected_class_labels):
            class_indices = np.where(all_labels == class_label)[0]
            # Sample with replacement only when the class has fewer examples
            # than k_shot + q_queries.
            needs_replacement = len(class_indices) < (self.k_shot + self.q_queries)
            selected_indices_for_class = np.random.choice(
                class_indices,
                self.k_shot + self.q_queries,
                replace=needs_replacement
            )
            support_indices = selected_indices_for_class[:self.k_shot]
            query_indices = selected_indices_for_class[self.k_shot:]
            support_x_list.extend(all_data[support_indices])
            support_y_list.extend([i] * self.k_shot)
            query_x_list.extend(all_data[query_indices])
            query_y_list.extend([i] * self.q_queries)
        return {
            'support': (torch.tensor(np.array(support_x_list), dtype=torch.float32),
                        torch.tensor(support_y_list, dtype=torch.long)),
            'query': (torch.tensor(np.array(query_x_list), dtype=torch.float32),
                      torch.tensor(query_y_list, dtype=torch.long))
        }

    def sample_batch(self, all_data, all_labels, meta_batch_size):
        return [self.sample_task(all_data, all_labels) for _ in range(meta_batch_size)]
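

# Possible addition, not in the original example: a small meta-test helper that
# adapts the meta-trained initialization to one task on its support set and
# reports query-set accuracy. The name `evaluate_adaptation` and the `device`
# argument are illustrative assumptions. A minimal sketch:
def evaluate_adaptation(maml_trainer, task, device):
    support_x, support_y = (t.to(device) for t in task['support'])
    query_x, query_y = (t.to(device) for t in task['query'])
    # The adaptation itself needs gradients, so it runs outside torch.no_grad().
    adapted_params = maml_trainer.inner_loop(support_x, support_y)
    with torch.no_grad():
        logits = maml_trainer.forward_with_params(query_x, adapted_params)
        accuracy = (logits.argmax(dim=1) == query_y).float().mean().item()
    return accuracy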


if __name__ == "__main__":
    n_way_main = 5
    k_shot_main = 1
    q_queries_main = 15
    meta_batch_size_main = 4
    n_meta_iterations_main = 100  # Reduced for quick example run
    inner_lr_main = 0.01
    meta_lr_main = 0.001
    num_inner_steps_main = 1
    current_device_main = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Synthetic data with enough samples per class for k_shot + q_queries
    # without replacement.
    num_total_classes_main = 20
    samples_per_class_main = 30
    feature_dim_main = 28 * 28
    example_all_data = np.random.rand(num_total_classes_main * samples_per_class_main, feature_dim_main).astype(np.float32)
    example_all_labels = np.array([lbl for lbl in range(num_total_classes_main) for _ in range(samples_per_class_main)])

    maml_model_main = MAMLModel(input_dim=feature_dim_main, output_dim=n_way_main).to(current_device_main)
    maml_trainer = MAML(maml_model_main, inner_lr=inner_lr_main, meta_lr=meta_lr_main, num_inner_steps=num_inner_steps_main)
    task_generator_main = TaskGenerator(n_way_main, k_shot_main, q_queries_main)

    print("Starting MAML meta-training example...")
    for iteration_idx in range(n_meta_iterations_main):
        task_batch_data = task_generator_main.sample_batch(example_all_data, example_all_labels, meta_batch_size_main)
        # Move each task's tensors to the training device before the meta-update.
        device_task_batch = [{
            'support': (task['support'][0].to(current_device_main), task['support'][1].to(current_device_main)),
            'query': (task['query'][0].to(current_device_main), task['query'][1].to(current_device_main))
        } for task in task_batch_data]
        meta_loss_val = maml_trainer.meta_train(device_task_batch)
        if (iteration_idx + 1) % 20 == 0:
            print(f"Iteration {iteration_idx + 1}/{n_meta_iterations_main}, Meta-loss: {meta_loss_val:.4f}")
    print("MAML meta-training example finished.")