import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, bundled as one reusable unit.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Number of filters, i.e. channel count of the result.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        padding: Zero padding on each spatial side (default 1).

    With the defaults (3x3 kernel, stride 1, padding 1) the spatial size of
    the input is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        # Registration order (conv, bn, relu) is kept so state_dict ordering
        # is the same as before.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        # In-place activation: overwrites the batch-norm output buffer
        # instead of allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU to ``x`` in turn."""
        features = self.conv(x)
        normalized = self.bn(features)
        return self.relu(normalized)
class ResidualBlock(nn.Module):
    """Two stacked ConvBlocks wrapped in an additive skip connection.

    The main path is ``conv1`` (which carries the stride) followed by
    ``conv2``; its result has ``out_channels`` channels. The skip path is the
    input itself, or ``downsample(x)`` when a projection module is supplied.
    The two paths are summed and passed through a final ReLU.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by both ConvBlocks, and hence by
            the block as a whole.
        stride: Stride of the first ConvBlock (default 1).
        downsample: Optional module applied to the input to build the skip
            tensor. It must produce a tensor with the same shape as the main
            path's output (``out_channels`` channels, same spatial size),
            otherwise the addition fails. When ``None`` the raw input is used,
            so ``in_channels`` must equal ``out_channels`` and ``stride`` must
            be 1.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, stride=stride)
        self.conv2 = ConvBlock(out_channels, out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)) + skip(x))``."""
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        # The add must be out-of-place. ``out`` is the output of conv2's ReLU,
        # which autograd saves to compute the ReLU gradient; ``out += identity``
        # would overwrite that saved tensor and make ``backward()`` raise
        # "one of the variables needed for gradient computation has been
        # modified by an inplace operation".
        out = out + identity
        return F.relu(out)
class ModernCNN(nn.Module):
    """Residual image classifier: conv stem, four residual stages, linear head.

    Layout:
        * stem: 7x7 stride-2 ConvBlock to 64 channels, then 3x3 stride-2 max-pool
        * four stages of 3, 4, 6 and 3 ResidualBlocks with widths
          64, 128, 256 and 512; stages 2-4 halve the spatial size
        * global average pool to 1x1, flatten, fully connected layer

    ResidualBlock emits exactly ``out_channels`` feature maps (it has no
    bottleneck expansion), so each stage's input width equals the previous
    stage's width and the classifier consumes 512 features.

    Args:
        num_classes: Size of the output logit vector (default 1000).
        input_channels: Channel count of the input images (default 3).
    """

    def __init__(self, num_classes=1000, input_channels=3):
        super(ModernCNN, self).__init__()
        # Initial layers
        self.conv1 = ConvBlock(input_channels, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Residual layers: every stage's in_channels is the width the previous
        # stage actually outputs.
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)
        # Global pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` ResidualBlocks as an ``nn.Sequential``.

        Only the first block sees ``in_channels`` and ``stride``; the rest map
        ``out_channels`` to ``out_channels`` at stride 1. At least one block is
        always created, even when ``blocks`` is less than 1.

        Args:
            in_channels: Channel count entering the stage.
            out_channels: Channel count produced by every block in the stage.
            blocks: Number of ResidualBlocks in the stage.
            stride: Stride of the first block (default 1).
        """
        downsample = None
        # The skip path needs a projection whenever the first block changes
        # the tensor shape (spatially via stride, or in channel count), so
        # that the skip tensor matches the main-path output for the addition.
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride),
                nn.BatchNorm2d(out_channels),
            )
        layers = [ResidualBlock(in_channels, out_channels, stride, downsample)]
        layers.extend(
            ResidualBlock(out_channels, out_channels) for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: Input batch; presumably shaped (N, input_channels, H, W) --
                the flatten below keeps dim 0 as the batch dimension.

        Returns:
            Tensor with ``num_classes`` values per batch element.
        """
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Collapse (N, 512, 1, 1) to (N, 512) for the linear classifier.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x