import torch
import torch.nn as nn
import torch.nn.functional as F
class TNet(nn.Module):
"""Input and Feature Transform Network for PointNet."""
def __init__(self, k=3):
super(TNet, self).__init__()
self.k = k
self.conv1 = nn.Conv1d(k, 64, 1)
self.conv2 = nn.Conv1d(64, 128, 1)
self.conv3 = nn.Conv1d(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k * k)
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(256)
def forward(self, x):
batch_size = x.size(0)
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.relu(self.bn3(self.conv3(x)))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
x = self.relu(self.bn4(self.fc1(x)))
x = self.relu(self.bn5(self.fc2(x)))
x = self.fc3(x)
        # Add the identity so the predicted transform starts near the identity matrix
iden = torch.eye(self.k, dtype=x.dtype, device=x.device).view(1, self.k * self.k).repeat(batch_size, 1)
x = x + iden
x = x.view(-1, self.k, self.k)
return x
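
# Shape sanity check (a hedged sketch, not part of the original code):
#     t = TNet(k=3)(torch.randn(8, 3, 1024))
#     t.shape  # -> torch.Size([8, 3, 3]); one 3x3 transform per cloud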
class PointNetClassifier(nn.Module):
    """PointNet classification network: T-Nets + shared MLPs + max-pooled global feature."""
def __init__(self, num_classes=40, input_dims=3):
super(PointNetClassifier, self).__init__()
self.input_transform = TNet(k=input_dims) # For input points (e.g., 3D coordinates)
self.feature_transform = TNet(k=64) # For learned features
self.conv1 = nn.Conv1d(input_dims, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)  # Channel sizes follow the original PointNet: shared MLP (64, 64) before the feature transform, then MLP (64, 128, 1024) after it
self.conv3 = nn.Conv1d(64, 64, 1)
self.conv4 = nn.Conv1d(64, 128, 1)
self.conv5 = nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(64)
self.bn3 = nn.BatchNorm1d(64)
self.bn4 = nn.BatchNorm1d(128)
self.bn5 = nn.BatchNorm1d(1024)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        # Dedicated BatchNorm layers for the classification head; bn1/bn2 above
        # normalize 64 channels and cannot be reused on 512/256-dim activations
        self.bn6 = nn.BatchNorm1d(512)
        self.bn7 = nn.BatchNorm1d(256)
self.dropout = nn.Dropout(p=0.3)
self.relu = nn.ReLU()
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, x):
# x shape: (batch_size, num_points, input_dims)
x = x.transpose(2, 1) # (batch_size, input_dims, num_points)
# Input Transform
trans_input = self.input_transform(x)
x = torch.bmm(x.transpose(2, 1), trans_input).transpose(2, 1)
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x))) # First block of shared MLPs
# Feature Transform
trans_feat = self.feature_transform(x)
x = torch.bmm(x.transpose(2,1), trans_feat).transpose(2,1)
        point_features = x  # 64-dim per-point features; a segmentation head would concatenate these with the global feature (unused by this classifier)
x = self.relu(self.bn3(self.conv3(x))) # Second block of shared MLPs
x = self.relu(self.bn4(self.conv4(x)))
x = self.relu(self.bn5(self.conv5(x))) # (batch_size, 1024, num_points)
# Symmetric function: Max Pooling
x = torch.max(x, 2, keepdim=True)[0] # (batch_size, 1024, 1)
x = x.view(-1, 1024) # Global feature vector (batch_size, 1024)
# Classification MLP
        x = self.relu(self.bn6(self.fc1(x)))
x = self.dropout(x)
        x = self.relu(self.bn7(self.fc2(x)))
x = self.dropout(x)
x = self.fc3(x)
        return self.logsoftmax(x), trans_feat  # Return log-probabilities (pair with nn.NLLLoss) and the feature transform for the orthogonality regularizer
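
# The forward pass returns trans_feat so a training loop can add the
# orthogonality regularizer from the PointNet paper, L_reg = ||I - A A^T||_F^2.
# A minimal sketch (the function name is ours, not from the original code):
def feature_transform_regularizer(trans):
    """Mean squared Frobenius distance of each k x k transform from orthogonality."""
    k = trans.size(1)
    iden = torch.eye(k, dtype=trans.dtype, device=trans.device)
    diff = iden - torch.bmm(trans, trans.transpose(2, 1))  # I - A A^T, batched
    return (diff ** 2).sum(dim=(1, 2)).mean()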
# Example usage:
if __name__ == "__main__":
    point_cloud = torch.randn(16, 1024, 3)  # batch of 16 clouds, 1024 points each, xyz coordinates
    classifier = PointNetClassifier(num_classes=10)
    log_probs, feature_transform_matrix = classifier(point_cloud)
    print("Log-probability shape:", log_probs.shape)  # (16, 10)
    print("Feature transform shape:", feature_transform_matrix.shape)  # (16, 64, 64)