An example implementation of Differentiable Architecture Search (DARTS; Liu et al., ICLR 2019) in PyTorch.
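DARTS relaxes the discrete choice of operation on each edge of the cell into a softmax-weighted mixture over a fixed candidate set $\mathcal{O}$. For edge parameters $\alpha$, the edge computes

$$\bar{o}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o)}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'})}\, o(x),$$

so both $\alpha$ and the ordinary network weights can be trained by gradient descent. The MixedOp class below implements this mixture.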
import torch
import torch.nn as nn
import torch.nn.functional as F
# Candidate operations considered on every edge of the cell DAG.
PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
]

class Zero(nn.Module):
    """The 'none' operation: a zero output, downsampled if stride > 1."""
    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)

class MixedOp(nn.Module):
    """Weighted mixture of all candidate operations on one edge."""
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMITIVES:
            # Every primitive (including 'none') maps to a real module, so
            # self._ops stays index-aligned with the architecture weights.
            self._ops.append(self._create_op(primitive, C, stride))
    def _create_op(self, primitive, C, stride):
        if primitive == 'none':
            return Zero(stride)
        elif primitive == 'max_pool_3x3':
            return nn.Sequential(
                nn.MaxPool2d(3, stride=stride, padding=1),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'avg_pool_3x3':
            return nn.Sequential(
                nn.AvgPool2d(3, stride=stride, padding=1),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'skip_connect':
            if stride == 1:
                return nn.Identity()
            return nn.Sequential(
                nn.Conv2d(C, C, 1, stride=stride, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'sep_conv_3x3':
            # Depthwise separable conv: depthwise 3x3 followed by pointwise 1x1.
            return nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(C, C, 3, stride=stride, padding=1, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'sep_conv_5x5':
            return nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(C, C, 5, stride=stride, padding=2, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'dil_conv_3x3':
            # padding = 2 preserves the spatial size with dilation = 2.
            return nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(C, C, 3, stride=stride, padding=2, dilation=2, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'dil_conv_5x5':
            return nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(C, C, 5, stride=stride, padding=4, dilation=2, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        raise ValueError(f'unknown primitive: {primitive}')
    def forward(self, x, weights):
        # Continuous relaxation: every candidate contributes, weighted by alpha.
        return sum(w * op(x) for w, op in zip(weights, self._ops))
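A minimal smoke test for MixedOp; the channel count, batch size, and spatial size below are illustrative assumptions, not part of the original code:

# Illustrative usage: at stride 1 the mixture preserves the input shape.
mixed = MixedOp(C=16, stride=1)
weights = F.softmax(torch.randn(len(PRIMITIVES)), dim=-1)  # one weight per op
y = mixed(torch.randn(2, 16, 32, 32), weights)
assert y.shape == (2, 16, 32, 32)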
class Cell(nn.Module):
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction
        self.steps = steps
        self.multiplier = multiplier  # used in forward() for the output concat
        # Input nodes: project the two previous cells' outputs to C channels.
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
        # DAG: one MixedOp per edge.
        self._ops = nn.ModuleList()
        self._compile(C, reduction)

    def _compile(self, C, reduction):
        # Intermediate node i receives edges from both inputs and all earlier nodes.
        for i in range(self.steps):
            ops = nn.ModuleList()
            for j in range(2 + i):
                # In a reduction cell, only edges from the two input nodes downsample.
                stride = 2 if reduction and j < 2 else 1
                ops.append(MixedOp(C, stride))
            self._ops.append(ops)
    def forward(self, s0, s1, weights):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        offset = 0
        for i in range(self.steps):
            # Each intermediate node sums its mixed edges over all predecessors.
            s = sum(self._ops[i][j](h, weights[offset + j])
                    for j, h in enumerate(states))
            offset += len(states)
            states.append(s)
        # The cell output concatenates the last `multiplier` intermediate nodes.
        return torch.cat(states[-self.multiplier:], dim=1)
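A shape sketch for one normal cell; the sizes are illustrative, and the snippet assumes the helper classes defined at the end of the file (ReLUConvBN, FactorizedReduce) are already in scope:

# Illustrative: with 4 steps there are 2 + 3 + 4 + 5 = 14 edges per cell.
steps, multiplier, C = 4, 4, 16
cell = Cell(steps, multiplier, C_prev_prev=48, C_prev=48, C=C,
            reduction=False, reduction_prev=False)
k = sum(2 + i for i in range(steps))
weights = F.softmax(torch.randn(k, len(PRIMITIVES)), dim=-1)  # one row per edge
s0 = s1 = torch.randn(2, 48, 32, 32)
out = cell(s0, s1, weights)
assert out.shape == (2, multiplier * C, 32, 32)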
class Network(nn.Module):
    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
        super(Network, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier
        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # Reduction cells (double channels, halve resolution) at 1/3 and 2/3 depth.
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells.append(cell)
            C_prev_prev, C_prev = C_prev, multiplier * C_curr
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self._initialize_alphas()
    def _initialize_alphas(self):
        # One row of alphas per edge: 2 + 3 + ... + (steps + 1) edges in total.
        k = sum(2 + i for i in range(self._steps))
        num_ops = len(PRIMITIVES)
        # Plain leaf tensors rather than nn.Parameters, so the alphas are
        # excluded from model.parameters() and can be optimized separately.
        self.alphas_normal = (1e-3 * torch.randn(k, num_ops)).requires_grad_(True)
        self.alphas_reduce = (1e-3 * torch.randn(k, num_ops)).requires_grad_(True)
        self._arch_parameters = [
            self.alphas_normal,
            self.alphas_reduce,
        ]
    def forward(self, x):
        s0 = s1 = self.stem(x)
        for cell in self.cells:
            # Normal and reduction cells each share their own set of alphas.
            if cell.reduction:
                weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                weights = F.softmax(self.alphas_normal, dim=-1)
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits
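A quick smoke test for the full model; the hyperparameters are illustrative, and it likewise assumes the helper classes at the end of the file are in scope:

net = Network(C=16, num_classes=10, layers=8, criterion=nn.CrossEntropyLoss())
logits = net(torch.randn(2, 3, 32, 32))  # CIFAR-10-sized input
assert logits.shape == (2, 10)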
class Architect:
    """Optimizes the architecture parameters on validation data.

    This is the first-order DARTS approximation: eta, network_optimizer,
    w_momentum, and w_weight_decay would only be needed for the unrolled
    (second-order) gradient, which is omitted here.
    """
    def __init__(self, model, w_momentum, w_weight_decay):
        self.model = model
        self.w_momentum = w_momentum
        self.w_weight_decay = w_weight_decay
        self.optimizer = torch.optim.Adam(
            self.model._arch_parameters,
            lr=3e-4,
            betas=(0.5, 0.999),
            weight_decay=0
        )

    def step(self, input_valid, target_valid, eta, network_optimizer):
        # eta and network_optimizer are unused in the first-order approximation.
        self.optimizer.zero_grad()
        self._backward_step(input_valid, target_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        # The validation loss drives the architecture (alpha) gradients.
        loss = self.model._criterion(self.model(input_valid), target_valid)
        loss.backward()
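A sketch of the alternating bi-level update; train_loader, valid_loader, and the hyperparameters are illustrative placeholders, and the snippet assumes the full module (including the helpers below) is defined:

criterion = nn.CrossEntropyLoss()
model = Network(C=16, num_classes=10, layers=8, criterion=criterion)
w_optimizer = torch.optim.SGD(model.parameters(), lr=0.025,
                              momentum=0.9, weight_decay=3e-4)
architect = Architect(model, w_momentum=0.9, w_weight_decay=3e-4)

for (x_train, y_train), (x_valid, y_valid) in zip(train_loader, valid_loader):
    # 1) Update the alphas on a validation batch.
    architect.step(x_valid, y_valid, eta=0.025, network_optimizer=w_optimizer)
    # 2) Update the network weights on a training batch.
    w_optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    w_optimizer.step()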
# Helper components
class ReLUConvBN(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(C_out)
        )

    def forward(self, x):
        return self.op(x)
class FactorizedReduce(nn.Module):
    """Halves the spatial resolution with two offset stride-2 1x1 convs."""
    def __init__(self, C_in, C_out):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        x = self.relu(x)
        # The second conv sees the input shifted by one pixel, so between the
        # two branches every spatial position contributes to the output.
        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
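Finally, a sketch of reading a discrete architecture off the learned alphas. The original DARTS derivation rule additionally keeps only the two strongest incoming edges per node; this simplified, illustrative version just takes the best non-'none' primitive on every edge:

def simple_genotype(alphas):
    # alphas: (num_edges, num_ops) tensor of architecture parameters.
    weights = F.softmax(alphas, dim=-1)
    none_idx = PRIMITIVES.index('none')
    gene = []
    for row in weights:
        row = row.clone()
        row[none_idx] = -1.0  # never keep 'none' in the final network
        gene.append(PRIMITIVES[row.argmax().item()])
    return gene

# e.g. simple_genotype(model.alphas_normal) with the `model` from the training
# sketch above -> a list like ['sep_conv_3x3', 'skip_connect', ...]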