PyTorch model for regression problem with 4 images per sample with time gap between them

I am using a dataset in which each sample consists of 4 images taken at known delays from one another, and each set of 4 images has a single numeric target (regression, not classification). I built the model below, but it does not give good results at all. Any advice?

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        # The 4 images of a sample are stacked along the channel axis
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(p=0.25)
        
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(16)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout(p=0.25)
        
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.pool3 = nn.MaxPool2d(kernel_size=5, stride=2)
        self.dropout3 = nn.Dropout(p=0.25)
        
        self.flatten = nn.Flatten()
        # 28800 = 32 channels * 30 * 30 after pool3 (e.g. for 256x256 inputs)
        self.fc1 = nn.Linear(28800, 512)
        self.dropout4 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, 1)  # Single output

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = self.dropout1(x)
        
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = self.dropout2(x)
        
        x = torch.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        x = self.dropout3(x)
        
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout4(x)
        
        x = self.fc2(x)  # Output layer, no activation function for regression
        return x
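
For reference, the hard-coded 28800 in fc1 ties the model to one input resolution. The sketch below reproduces that number using the standard floor-based output-size formula for convolution and pooling layers. The 256x256 input size is an assumption inferred from the layer sizes, not something stated above, and the helper names out_size and flattened_features are hypothetical.

```python
def out_size(size, kernel, stride, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(side, channels=32):
    """Number of features entering fc1 for a square input of the given side."""
    layers = [
        (3, 1, 1), (2, 2, 0),  # conv1, pool1
        (3, 1, 1), (2, 2, 0),  # conv2, pool2
        (3, 1, 1), (5, 2, 0),  # conv3, pool3 (kernel 5, stride 2)
    ]
    for kernel, stride, padding in layers:
        side = out_size(side, kernel, stride, padding)
    # channels is the out_channels of the last conv layer
    return channels * side * side


print(flattened_features(256))  # 28800, matching fc1 (assumed 256x256 input)
print(flattened_features(224))  # 21632, so fc1 would need resizing
```

Any other input size changes the flattened length, so fc1 would have to be adjusted to match.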

Also, the target value spans many orders of magnitude: it is often very small and sometimes much larger, ranging from around 1e-9 to 1e2. I applied a log scale to the target to reduce this effect and improve learning, but I am not sure how much it helped.
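
To make the log scaling concrete, here is a minimal sketch of the kind of transform described, assuming base-10 logs and strictly positive targets. The function names and sample values are hypothetical, not taken from the actual dataset.

```python
import math


def to_log_target(y):
    """Map a strictly positive target to log10 space (assumes y > 0)."""
    return math.log10(y)


def from_log_target(z):
    """Inverse transform: map a log10-space prediction back to original units."""
    return 10.0 ** z


# Hypothetical targets covering the range mentioned above (1e-9 to 1e2)
targets = [1e-9, 3e-4, 2.5, 1e2]
log_targets = [to_log_target(y) for y in targets]

# In log space the targets span roughly -9 to 2 instead of 11 orders of magnitude
print(min(log_targets), max(log_targets))

# Round trip back to the original scale
recovered = [from_log_target(z) for z in log_targets]
```

With this setup the model is trained on the log-space values, and its predictions need the inverse transform before errors are reported in the original units.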
