How can I prevent overfitting in a convolutional neural network while training on a small dataset?

Maitrik
Updated 1 day ago

I’m training a CNN on a relatively small image dataset, and the training accuracy quickly reaches near 100%, but validation accuracy stagnates and then drops. I suspect overfitting is the issue.

Here’s a simplified version of my training code in PyTorch:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Dataset
train_dataset = datasets.ImageFolder('data/train', transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# Simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 16 channels * 32 * 32 after pooling, so inputs must be 3x64x64
        self.fc1 = nn.Linear(16*32*32, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        # Flatten everything except the batch dimension
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
 

I’ve read about techniques like data augmentation, dropout, and weight regularization, but I’m not sure how to integrate them effectively.

What strategies or best practices would you recommend for reducing overfitting when training CNNs on small datasets?
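
Since the symptom is validation accuracy that peaks and then drops, one commonly suggested option is early stopping: monitor validation loss each epoch and keep the best weights. The sketch below is only an illustration under some assumptions. The code above has no validation split, so val_loader is assumed to exist; the function names and the patience value are hypothetical; and the criterion is assumed to use mean reduction, as nn.CrossEntropyLoss does by default.

```python
import copy
import torch
import torch.nn as nn

def evaluate(model, loader, criterion):
    """Mean per-sample loss over a loader, with dropout disabled."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            loss = criterion(model(inputs), labels)
            # Assumes the criterion averages over the batch.
            total += loss.item() * inputs.size(0)
            count += inputs.size(0)
    return total / count

def train_with_early_stopping(model, train_loader, val_loader, criterion,
                              optimizer, max_epochs=50, patience=5):
    """Train until validation loss fails to improve for `patience` epochs."""
    best_loss = float("inf")
    best_state = copy.deepcopy(model.state_dict())
    epochs_without_improvement = 0

    for epoch in range(max_epochs):
        model.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

        val_loss = evaluate(model, val_loader, criterion)
        if val_loss < best_loss:
            # New best epoch: snapshot the weights.
            best_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    # Restore the weights from the best validation epoch.
    model.load_state_dict(best_state)
    return best_loss
```

A call might look like train_with_early_stopping(model, train_loader, val_loader, criterion, optimizer), where val_loader wraps a held-out folder or a random split of the training set.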
