from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image, ImageDraw
import requests
import torch
import torch.nn as nn # Required for nn.functional.interpolate
import numpy as np
# Example: load an image from a URL
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # Example COCO image
# try:
#     image = Image.open(requests.get(url, stream=True).raw)
# except Exception as e:
#     print(f"Error loading image: {e}. Using a placeholder.")
#     image = Image.new('RGB', (600, 400), color='green')  # Placeholder
# For local testing, create a dummy image
image = Image.new('RGB', (512, 512), color='lightgray')
draw = ImageDraw.Draw(image)
# Draw some shapes for potential segmentation
draw.ellipse((50, 50, 200, 200), fill='blue', outline='blue')
draw.rectangle((250, 100, 450, 300), fill='green', outline='green')
draw.line((50, 400, 450, 450), fill='red', width=10)
# 1. Load a pre-trained SegFormer model and its processor
model_checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
processor = SegformerImageProcessor.from_pretrained(model_checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(model_checkpoint)
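# Optional: for faster inference you could move the model to a GPU if one is available.
# This is left commented out to keep the example CPU-only; the preprocessed inputs
# would also need to be moved with .to(device).
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
model.eval()  # inference mode; disables dropout and other training-time behavior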
# 2. Preprocess the image
inputs = processor(images=image, return_tensors="pt")
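# The processor returns a dict-like BatchFeature. For this checkpoint the image is
# resized and normalized to 512x512, so pixel_values should be (1, 3, 512, 512);
# the print below is just a sanity check.
print(f"Preprocessed pixel_values shape: {inputs['pixel_values'].shape}")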
# 3. Perform inference
with torch.no_grad():
    outputs = model(**inputs)
# 4. Postprocess the outputs
# The model outputs logits. We need to upscale them to the original image size and argmax to get class predictions.
logits = outputs.logits # shape (batch_size, num_classes, height/4, width/4)
# Upsample logits to the original image size
# Note: SegFormer output logits are 1/4th of the input image resolution by default
original_size = image.size[::-1]  # PIL gives (width, height); interpolate expects (height, width)
upsampled_logits = nn.functional.interpolate(
    logits,
    size=original_size,  # (height, width)
    mode='bilinear',
    align_corners=False
)
# Get the predicted segmentation map by taking argmax along the class dimension
predicted_segmentation_map = upsampled_logits.argmax(dim=1)[0]  # argmax over the class dim, then take the first (only) image in the batch
print(f"Predicted segmentation map shape: {predicted_segmentation_map.shape}")
print(f"Unique class IDs in map: {torch.unique(predicted_segmentation_map)}")
# To visualize (requires matplotlib or other libraries):
# import matplotlib.pyplot as plt
# plt.imshow(predicted_segmentation_map.cpu().numpy())
# plt.title("Predicted Segmentation Map")
# plt.show()
# You can map class IDs to colors for a more meaningful visualization.
# The class labels and palette depend on the dataset the checkpoint was fine-tuned on (ADE20K here);
# model.config.id2label maps each class ID to its human-readable name, e.g.:
# print(model.config.id2label[torch.unique(predicted_segmentation_map)[0].item()])
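# Minimal visualization sketch: assign each class ID a reproducible random color and
# blend the result over the input image. A proper ADE20K palette would look nicer;
# the random palette here is purely an assumption of this sketch, not part of the model.
seg = predicted_segmentation_map.cpu().numpy()
rng = np.random.default_rng(seed=0)
palette = rng.integers(0, 255, size=(model.config.num_labels, 3), dtype=np.uint8)
color_map = palette[seg]  # (height, width, 3) uint8 image of per-class colors
overlay = Image.blend(image.convert('RGB'), Image.fromarray(color_map), alpha=0.5)
# overlay.show()  # or overlay.save("segmentation_overlay.png")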