# Flash Attention Implementation
import math

import torch
import torch.nn as nn
import torch.utils.checkpoint


def flash_attention(q, k, v, sm_scale=None):
    # q, k, v: (batch_size, num_heads, seq_len, head_dim)
    # Keys and values are processed in blocks with an online softmax, so the
    # full (seq_len, seq_len) attention matrix is never materialized at once.
    batch_size, num_heads, seq_len, head_dim = q.shape
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)
    # Running softmax statistics: m is the row-wise running maximum (it must
    # start at -inf, not 0, to handle negative scores), l is the denominator.
    out = torch.zeros_like(q)
    m = torch.full((batch_size, num_heads, seq_len), float("-inf"),
                   device=q.device, dtype=q.dtype)
    l = torch.zeros(batch_size, num_heads, seq_len, device=q.device, dtype=q.dtype)
    # Process keys and values in blocks to save memory
    block_size = min(256, seq_len)
    for block_start in range(0, seq_len, block_size):
        block_end = min(block_start + block_size, seq_len)
        # Load current block of keys and values
        k_block = k[:, :, block_start:block_end]
        v_block = v[:, :, block_start:block_end]
        # Compute attention scores of all queries against the current block
        current_scores = torch.matmul(q, k_block.transpose(-2, -1)) * sm_scale
        # Update the running maximum and the factor that rescales previously
        # accumulated values onto the new maximum
        m_new = torch.maximum(m, current_scores.max(dim=-1)[0])
        exp_diff = torch.exp(m - m_new)  # shape: (batch_size, num_heads, seq_len)
        # Rescale the previous accumulators and add this block's contribution
        current_exp = torch.exp(current_scores - m_new.unsqueeze(-1))
        l_new = l * exp_diff + current_exp.sum(dim=-1)
        out = out * exp_diff.unsqueeze(-1) + torch.matmul(current_exp, v_block)
        m = m_new
        l = l_new
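        # Online softmax identity behind the update above: for m_new >= m,
        #   sum_j exp(s_j - m_new) = exp(m - m_new) * sum_j exp(s_j - m),
        # so the previously accumulated l and out are rescaled by exp(m - m_new)
        # before the current block's contribution is added.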
    return out / l.unsqueeze(-1)
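

# Sanity-check sketch: on small random inputs the blocked computation above
# should match a naive softmax-attention reference up to floating-point
# tolerance. The two helpers below are illustrative additions, not part of
# the original API.
def _naive_attention_reference(q, k, v, sm_scale=None):
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def _check_flash_attention():
    q, k, v = (torch.randn(2, 4, 512, 64) for _ in range(3))
    assert torch.allclose(flash_attention(q, k, v),
                          _naive_attention_reference(q, k, v), atol=1e-4)

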
# Memory-Efficient Transformer with Gradient Checkpointing
class MemoryEfficientTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # TransformerLayer is assumed to be defined elsewhere in the codebase
        self.layers = nn.ModuleList([
            TransformerLayer(config) for _ in range(config.num_layers)
        ])

    def forward(self, x, use_checkpointing=True):
        # Wrap each layer so the checkpoint API receives a plain callable
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        if use_checkpointing and self.training:
            # Recompute each layer's activations during the backward pass
            # instead of storing them, trading compute for peak memory
            for layer in self.layers:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x,
                    use_reentrant=False,
                )
        else:
            for layer in self.layers:
                x = layer(x)
        return x
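

# Gradient-checkpointing sketch on a self-contained toy model, independent of
# TransformerLayer (which lives elsewhere in the codebase): each checkpointed
# call stores only its inputs and recomputes intermediate activations during
# the backward pass, trading compute for peak activation memory.
def _checkpointing_demo():
    layers = nn.ModuleList(
        nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4)
    )
    x = torch.randn(8, 64, requires_grad=True)
    h = x
    for layer in layers:
        # use_reentrant=False selects the non-reentrant, autograd-based variant
        h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
    h.sum().backward()  # activations inside each block are recomputed here
    return x.grad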