On this page
article
Computer Vision with PyTorch
Build computer vision applications with PyTorch and torchvision — image classification, transfer learning, and data augmentation.
Computer vision enables machines to understand images. PyTorch and torchvision provide pre-trained models and tools for classification, detection, and segmentation.
Setup
pip install torch torchvision matplotlib
Load and Visualize Images
import torch
from torchvision import datasets, transforms
from torchvision.io import read_image
import matplotlib.pyplot as plt
# Decode the image file into a tensor and report its layout.
img = read_image("photo.jpg")
print(img.shape)  # [3, H, W] — channels first

# Reorder the axes to [H, W, 3] before handing the image to matplotlib.
channels_last = img.permute(1, 2, 0)
plt.imshow(channels_last)
plt.axis("off")
plt.show()
Image Transforms
# Per-channel statistics used by the normalisation step below.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Steps run in order on every image: fixed-size resize, random flip and
# rotation for augmentation, conversion to a tensor, then normalisation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)
# CIFAR-10 training split stored under ./data (download=True fetches it);
# every image is passed through `transform`.
train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 images, loaded by two worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
Pre-trained Model for Classification
from torchvision import models
import torch.nn as nn
# Resolve the weights once so that the model, its preprocessing pipeline and
# the label names all come from the same checkpoint.
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)
model.eval()

# Preprocessing pipeline bundled with these weights.
preprocess = weights.transforms()

from PIL import Image

# Force three channels: grayscale or RGBA files would otherwise produce a
# tensor whose channel count does not match the 3-channel normalisation.
img = Image.open("photo.jpg").convert("RGB")
batch = preprocess(img).unsqueeze(0)  # add the batch dimension

# No gradients are needed for inference.
with torch.no_grad():
    prediction = model(batch).squeeze(0)
probabilities = torch.nn.functional.softmax(prediction, dim=0)

# ImageNet class labels
categories = weights.meta["categories"]

# Print the five most likely classes with their probabilities.
top5 = probabilities.topk(5)
for prob, idx in zip(top5.values, top5.indices):
    print(f"{categories[idx]}: {prob:.2%}")
Transfer Learning — Custom Dataset
Train on your own images using a pre-trained backbone:
NUM_CLASSES = 10

# Start from the pre-trained network.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze all layers
model.requires_grad_(False)

# Replace final layer (the new nn.Linear is created after the freeze, so its
# parameters are trainable).
head_inputs = model.fc.in_features
model.fc = nn.Linear(head_inputs, NUM_CLASSES)

# Only train the new layer initially
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
Custom Dataset Class
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomImageDataset(Dataset):
    """Image dataset read from a folder with one sub-directory per class.

    Every sub-directory of ``root_dir`` is a class; class indices follow the
    sorted directory names. Files inside a class directory are samples when
    their name ends in one of ``IMAGE_EXTENSIONS`` (case-insensitive).

    Args:
        root_dir: Directory containing the class sub-directories.
        transform: Optional callable applied to each loaded RGB image.

    Attributes are plain and public: ``classes`` (sorted class names),
    ``class_to_idx`` (name -> integer label) and ``samples`` (list of
    ``(path, label)`` tuples).
    """

    # File-name suffixes accepted as images.
    IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Only directories are classes. Stray files in root_dir (.DS_Store,
        # README, ...) would otherwise become bogus classes and make the
        # os.listdir() call below fail.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        self.samples = []
        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            label = self.class_to_idx[cls]
            # os.listdir() order is arbitrary; sort so that sample indices
            # are reproducible across runs and machines.
            for fname in sorted(os.listdir(cls_dir)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(cls_dir, fname), label))

    def __len__(self):
        """Return the number of image files found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened, converted to RGB and, when a transform was
        given, passed through it.
        """
        path, label = self.samples[idx]
        # convert() returns an independent image, so the file handle can be
        # closed as soon as the block exits.
        with Image.open(path) as handle:
            image = handle.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
Expected directory structure:
data/train/
├── cats/
│   ├── cat1.jpg
│   └── cat2.jpg
└── dogs/
    ├── dog1.jpg
    └── dog2.jpg
Fine-Tuning — Unfreeze Layers
After training the head, unfreeze deeper layers:
# Make the parameters of layer4 trainable again.
model.layer4.requires_grad_(True)

# One parameter group per part of the network: the classification head keeps
# the larger learning rate, the unfrozen layer4 gets a ten times smaller one.
head_group = {"params": model.fc.parameters(), "lr": 1e-3}
layer4_group = {"params": model.layer4.parameters(), "lr": 1e-4}
optimizer = torch.optim.Adam([head_group, layer4_group])
Data Augmentation Strategies
| Transform | Purpose |
|---|---|
| RandomHorizontalFlip | Mirror images |
| RandomRotation | Rotation invariance |
| ColorJitter | Lighting variation |
| RandomCrop | Scale invariance |
| Normalize | Match pre-trained stats |
# Training-time pipeline, applied in order: resize, random 224x224 crop,
# random flip, brightness/contrast jitter, tensor conversion, normalisation.
augmentation_steps = [
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
train_transform = transforms.Compose(augmentation_steps)
Evaluation
def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Callable module mapping a batch of images to per-class scores
            of shape ``(batch, num_classes)``. It is switched to eval mode
            and left in that mode.
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``;
        ``0.0`` when ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # Predicted class = index of the highest score per sample.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total
Related Chapters
- PyTorch Basics
- PyTorch Training
- TensorFlow Training — CNN alternative
- Project: Serverless Image Processor
Computer vision powers face recognition, medical imaging, autonomous vehicles, and quality inspection — transfer learning makes it accessible with modest datasets.