Neural Architecture Search

Overview

Neural Architecture Search (NAS) is an automated machine learning (AutoML) approach that searches for high-performing neural network architectures for a given task. Instead of relying on manual design, NAS uses a search algorithm to explore a space of candidate network structures (and sometimes their hyperparameters) and select the best one found within a given budget.

This approach has led to state-of-the-art results in various domains while reducing the need for expert knowledge in architecture design.

Core Concepts

  • Key Components

    • Search Space: The set of possible architectures
    • Search Strategy: Method to explore the search space
    • Performance Estimation: Evaluating candidate architectures
    • Resource Constraints: Time, compute, and memory limits
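
    These components can be seen in even the simplest search loop. The sketch below uses random search over a toy search space; the space itself and the placeholder estimate_performance function are illustrative assumptions, not part of any particular NAS system.

    import random
    
    # Search space: each architecture is a dict of discrete choices
    SEARCH_SPACE = {
        'num_layers': [2, 4, 8],
        'width': [64, 128, 256],
        'activation': ['relu', 'gelu'],
    }
    
    def sample_architecture():
        # Search strategy: here, plain uniform random sampling
        return {k: random.choice(v) for k, v in SEARCH_SPACE.items()}
    
    def estimate_performance(arch):
        # Performance estimation: placeholder for a short training run or other proxy metric
        return random.random()
    
    def random_search(budget=50):
        # Resource constraint: a fixed number of candidate evaluations
        best_arch, best_score = None, float('-inf')
        for _ in range(budget):
            arch = sample_architecture()
            score = estimate_performance(arch)
            if score > best_score:
                best_arch, best_score = arch, score
        return best_arch, best_score
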
  • Common Approaches

    1. Reinforcement Learning:

    • Use RL agent to generate architectures
    • Validation accuracy as reward signal
    • Examples: NASNet, ENAS

    2. Evolutionary Methods:

    • Evolve architectures using genetic algorithms
    • Mutation and crossover operations
    • Examples: AmoebaNet, NEAT (a minimal mutation/selection sketch follows this list)

    3. Gradient-based:

    • Continuous relaxation of architecture
    • Optimize using gradient descent
    • Examples: DARTS, SNAS
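
    To make the evolutionary approach (item 2 above) concrete, here is a minimal mutation-and-selection sketch. The list-of-widths encoding, the mutation rule, and the placeholder fitness function are illustrative assumptions; real systems such as AmoebaNet train each candidate before scoring it and use aging/tournament selection rather than simple truncation.

    import random
    
    WIDTHS = [32, 64, 128, 256]  # per-layer width choices (toy search space)
    
    def fitness(arch):
        # Placeholder: stands in for the validation accuracy of the trained candidate
        return random.random()
    
    def mutate(arch):
        # Mutation: re-sample the width of one randomly chosen layer
        child = list(arch)
        child[random.randrange(len(child))] = random.choice(WIDTHS)
        return child
    
    def evolve(num_layers=6, population_size=20, generations=10):
        population = [[random.choice(WIDTHS) for _ in range(num_layers)]
                      for _ in range(population_size)]
        for _ in range(generations):
            ranked = sorted(population, key=fitness, reverse=True)
            parents = ranked[:population_size // 2]                        # selection
            children = [mutate(random.choice(parents)) for _ in parents]   # mutation
            population = parents + children
        return max(population, key=fitness)
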
  • Efficiency Strategies

    • Weight sharing: Reuse weights across candidate architectures (ENAS, DARTS)
    • Early stopping: Identify and terminate unpromising searches
    • Progressive search: Begin with small cells/blocks and gradually increase complexity
    • Proxy tasks: Train on reduced dataset or for fewer epochs during search
    • Hybrid approaches: Combine different search methods for better results
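
    As an illustration of how proxy tasks and early stopping combine during candidate evaluation, here is a minimal sketch. The subset fraction, epoch budget, accuracy threshold, and the caller-supplied train_one_epoch/validate callables are all illustrative assumptions.

    from torch.utils.data import Subset
    
    def proxy_evaluate(model, dataset, train_one_epoch, validate,
                       proxy_fraction=0.1, max_epochs=5, early_stop_threshold=0.2):
        # Proxy task: train briefly on a small fraction of the data
        proxy_set = Subset(dataset, range(int(len(dataset) * proxy_fraction)))
        best_acc = 0.0
        for epoch in range(max_epochs):
            train_one_epoch(model, proxy_set)   # caller-supplied training step
            acc = validate(model)               # caller-supplied validation metric
            best_acc = max(best_acc, acc)
            # Early stopping: abandon clearly unpromising candidates
            if epoch >= 1 and acc < early_stop_threshold:
                break
        return best_acc
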
  • Practical Considerations

    • Computational cost: Consider available computational resources and time constraints
    • Search space design: Domain knowledge can help constrain the search space effectively
    • Evaluation strategy: Balance between accuracy of evaluation and speed
    • Transfer learning: Architectures found on one task often transfer well to related tasks
    • Multi-objective optimization: Consider both accuracy and efficiency metrics (latency, memory, etc.)
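
    A common way to handle the multi-objective case is to scalarize accuracy and an efficiency metric into a single reward. The sketch below applies a soft latency penalty; the target latency and exponent are illustrative assumptions rather than values from any specific method.

    def scalarized_reward(accuracy, latency_ms, target_latency_ms=50.0, beta=-0.07):
        # Accuracy scaled by a soft latency penalty: architectures slower than the
        # target are penalized, faster ones are mildly rewarded
        return accuracy * (latency_ms / target_latency_ms) ** beta
    
    # A slightly less accurate but much faster model can win under this reward:
    # scalarized_reward(0.92, 80.0) ≈ 0.890 < scalarized_reward(0.91, 40.0) ≈ 0.924
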
  • Application Domains

    • Computer Vision: Image classification, object detection, segmentation
    • Natural Language Processing: Transformer architectures, language models
    • Graph Neural Networks: Optimizing message passing architectures
    • Reinforcement Learning: Policy and value network architectures
    • Audio Processing: Speech recognition, music generation
  • Recent Advances

    • Once-for-All Networks: Train a single large network that can be specialized to different hardware platforms
    • Zero-Cost Proxies: Estimate architecture quality without full training, using cheap statistics (often gradient-based) computed on the untrained network
    • Hardware-aware NAS: Incorporate hardware constraints directly into the search process
    • Neural Architecture Transfer: Adapt architectures from one domain to another
    • Dynamic Architectures: Networks that can adapt their structure at inference time based on input complexity
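
    As an illustration of the zero-cost proxy idea, the sketch below scores an untrained model by the gradient norm of its loss on a single random batch. The batch size and input shape are illustrative assumptions, and gradient norm is only one simple member of this family of proxies.

    import torch
    import torch.nn.functional as F
    
    def grad_norm_score(model, batch_size=8, input_shape=(3, 32, 32), num_classes=10):
        # Score an untrained model without any training: one forward/backward pass
        # on random data, then sum the per-parameter gradient norms
        x = torch.randn(batch_size, *input_shape)
        y = torch.randint(num_classes, (batch_size,))
        model.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        return sum(p.grad.norm().item() for p in model.parameters() if p.grad is not None)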

Implementation

  • DARTS Implementation

    A simplified, first-order implementation of Differentiable Architecture Search (DARTS) in PyTorch. Architecture parameters (alphas) weight a mixture of candidate operations on every edge of a cell and are optimized by gradient descent alongside the network weights.
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    # Candidate operations; shared by MixedOp and the architecture parameters
    PRIMITIVES = [
        'none',
        'max_pool_3x3',
        'avg_pool_3x3',
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5'
    ]
    
    class MixedOp(nn.Module):
        def __init__(self, C, stride):
            super(MixedOp, self).__init__()
            # One candidate operation per primitive; order must match the alpha columns
            self._ops = nn.ModuleList()
            for primitive in PRIMITIVES:
                self._ops.append(self._create_op(primitive, C, stride))
        
        def _create_op(self, primitive, C, stride):
            if primitive == 'none':
                # 'none' outputs zeros; keeping it as a real module keeps
                # self._ops aligned with the columns of the alphas
                return Zero(stride)
            elif primitive == 'max_pool_3x3':
                return nn.Sequential(
                    nn.MaxPool2d(3, stride=stride, padding=1),
                    nn.BatchNorm2d(C)
                )
            elif primitive == 'avg_pool_3x3':
                return nn.Sequential(
                    nn.AvgPool2d(3, stride=stride, padding=1),
                    nn.BatchNorm2d(C)
                )
            elif primitive == 'skip_connect':
                if stride == 1:
                    return nn.Identity()
                return nn.Sequential(
                    nn.Conv2d(C, C, 1, stride=stride),
                    nn.BatchNorm2d(C)
                )
            elif primitive == 'sep_conv_3x3':
                return nn.Sequential(
                    nn.ReLU(),
                    nn.Conv2d(C, C, 3, stride=stride, padding=1, groups=C),
                    nn.Conv2d(C, C, 1),
                    nn.BatchNorm2d(C)
                )
            elif primitive == 'sep_conv_5x5':
                return nn.Sequential(
                    nn.ReLU(),
                    nn.Conv2d(C, C, 5, stride=stride, padding=2, groups=C),
                    nn.Conv2d(C, C, 1),
                    nn.BatchNorm2d(C)
                )
            elif primitive == 'dil_conv_3x3':
                return nn.Sequential(
                    nn.ReLU(),
                    nn.Conv2d(C, C, 3, stride=stride, padding=2, dilation=2, groups=C),
                    nn.Conv2d(C, C, 1),
                    nn.BatchNorm2d(C)
                )
            elif primitive == 'dil_conv_5x5':
                return nn.Sequential(
                    nn.ReLU(),
                    nn.Conv2d(C, C, 5, stride=stride, padding=4, dilation=2, groups=C),
                    nn.Conv2d(C, C, 1),
                    nn.BatchNorm2d(C)
                )
        
        def forward(self, x, weights):
            # Continuous relaxation: weighted sum of every candidate operation
            return sum(w * op(x) for w, op in zip(weights, self._ops))
    
    class Cell(nn.Module):
        def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
            super(Cell, self).__init__()
            self.reduction = reduction
            self.steps = steps
            self.multiplier = multiplier  # number of intermediate node outputs concatenated
            
            # Input nodes
            if reduction_prev:
                self.preprocess0 = FactorizedReduce(C_prev_prev, C)
            else:
                self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
            self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
            
            # DAG
            self._ops = nn.ModuleList()
            self._compile(C, reduction)
        
        def _compile(self, C, reduction):
            # Node i has an edge (a MixedOp) from each of the 2 inputs and the i earlier nodes
            for i in range(self.steps):
                ops = nn.ModuleList()
                for j in range(2 + i):
                    # Edges leaving the two input nodes use stride 2 in reduction cells
                    stride = 2 if reduction and j < 2 else 1
                    ops.append(MixedOp(C, stride))
                self._ops.append(ops)
        
        def forward(self, s0, s1, weights):
            s0 = self.preprocess0(s0)
            s1 = self.preprocess1(s1)
            
            states = [s0, s1]
            offset = 0
            
            for i in range(self.steps):
                s = sum(self._ops[i][j](h, weights[offset + j])
                       for j, h in enumerate(states))
                offset += len(states)
                states.append(s)
            
            return torch.cat(states[-self.multiplier:], dim=1)
    
    class Network(nn.Module):
        def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
            super(Network, self).__init__()
            self._C = C
            self._num_classes = num_classes
            self._layers = layers
            self._criterion = criterion
            self._steps = steps
            self._multiplier = multiplier
            
            C_curr = stem_multiplier * C
            self.stem = nn.Sequential(
                nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
                nn.BatchNorm2d(C_curr)
            )
            
            C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
            self.cells = nn.ModuleList()
            reduction_prev = False
            
            for i in range(layers):
                if i in [layers//3, 2*layers//3]:
                    C_curr *= 2
                    reduction = True
                else:
                    reduction = False
                
                cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
                reduction_prev = reduction
                self.cells += [cell]
                C_prev_prev, C_prev = C_prev, multiplier * C_curr
            
            self.global_pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(C_prev, num_classes)
            
            self._initialize_alphas()
        
        def _initialize_alphas(self):
            # One row of architecture weights per edge; one column per candidate op
            k = sum(1 for i in range(self._steps) for n in range(2 + i))
            num_ops = len(PRIMITIVES)
            
            # Leaf tensors with gradients, kept separate from the network weights
            self.alphas_normal = (1e-3 * torch.randn(k, num_ops)).requires_grad_()
            self.alphas_reduce = (1e-3 * torch.randn(k, num_ops)).requires_grad_()
            self._arch_parameters = [
                self.alphas_normal,
                self.alphas_reduce,
            ]
        
        def forward(self, input):
            s0 = s1 = self.stem(input)
            for i, cell in enumerate(self.cells):
                if cell.reduction:
                    weights = F.softmax(self.alphas_reduce, dim=-1)
                else:
                    weights = F.softmax(self.alphas_normal, dim=-1)
                s0, s1 = s1, cell(s0, s1, weights)
            
            out = self.global_pooling(s1)
            logits = self.classifier(out.view(out.size(0), -1))
            return logits
    
    class Architect:
        def __init__(self, model, w_momentum, w_weight_decay):
            self.model = model
            self.w_momentum = w_momentum
            self.w_weight_decay = w_weight_decay
            self.optimizer = torch.optim.Adam(
                self.model._arch_parameters,
                lr=3e-4,
                betas=(0.5, 0.999),
                weight_decay=0
            )
        
        def step(self, input_valid, target_valid):
            # First-order approximation: update the alphas directly on the validation loss
            self.optimizer.zero_grad()
            self._backward_step(input_valid, target_valid)
            self.optimizer.step()
        
        def _backward_step(self, input_valid, target_valid):
            loss = self.model._criterion(self.model(input_valid), target_valid)
            loss.backward()
    
    # Helper components
    class Zero(nn.Module):
        # The 'none' primitive: maps its input to zeros (downsampled if stride > 1)
        def __init__(self, stride):
            super(Zero, self).__init__()
            self.stride = stride
        
        def forward(self, x):
            if self.stride == 1:
                return x * 0.0
            return x[:, :, ::self.stride, ::self.stride] * 0.0
    
    class ReLUConvBN(nn.Module):
        def __init__(self, C_in, C_out, kernel_size, stride, padding):
            super(ReLUConvBN, self).__init__()
            self.op = nn.Sequential(
                nn.ReLU(inplace=False),
                nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(C_out)
            )
        
        def forward(self, x):
            return self.op(x)
    
    class FactorizedReduce(nn.Module):
        def __init__(self, C_in, C_out):
            super(FactorizedReduce, self).__init__()
            assert C_out % 2 == 0
            self.relu = nn.ReLU(inplace=False)
            self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
            self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(C_out)
        
        def forward(self, x):
            x = self.relu(x)
            out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
            out = self.bn(out)
            return out
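
    A minimal usage sketch of the classes above: one alternating update of the architecture parameters and the network weights on dummy CIFAR-10-shaped batches. The channel count, layer count, learning rates, and random tensors are illustrative placeholders; a real search loops over separate training and validation loaders.

    criterion = nn.CrossEntropyLoss()
    model = Network(C=16, num_classes=10, layers=8, criterion=criterion)
    architect = Architect(model, w_momentum=0.9, w_weight_decay=3e-4)
    w_optimizer = torch.optim.SGD(model.parameters(), lr=0.025,
                                  momentum=0.9, weight_decay=3e-4)
    
    # Dummy batches standing in for real train/validation loaders
    x_train, y_train = torch.randn(4, 3, 32, 32), torch.randint(10, (4,))
    x_valid, y_valid = torch.randn(4, 3, 32, 32), torch.randint(10, (4,))
    
    # 1) Update the architecture parameters (alphas) on the validation batch
    architect.step(x_valid, y_valid)
    
    # 2) Update the network weights on the training batch
    w_optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    w_optimizer.step()
    
    # Inspect the learned operation preferences for normal cells
    print(F.softmax(model.alphas_normal, dim=-1))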
    

Interview Examples

What is Neural Architecture Search and how does it work?

Explain the basic concept of NAS and the major search strategies.

Compare and contrast evolutionary methods vs. gradient-based methods for NAS.

Explain the differences, advantages, and disadvantages of these two NAS approaches.

Practice Questions

1. Explain the core concepts of Neural Architecture Search (Easy)

Hint: Think about the fundamental principles

2. What are the practical applications of Neural Architecture Search? (Medium)

Hint: Consider both academic and industry use cases

3. How would you implement this in a production environment? (Hard)

Hint: Consider scalability and efficiency