Multiclass UNet with n-dimensional satellite images

  Kiến thức lập trình

I’m trying to use a UNet in Pytorch to extract prediction masks from multidimensional (8 band) satellite images. I’m having trouble getting the prediction masks to look somewhat expected/coherent. I’m not sure if the issue is the way my training data is formatted, my training code, or the code I’m using to make predictions. My suspicion is that it is the way my training data is being fed to the model. I have 8 band satellite images and single band masks with values ranging 0-n number of classes with 0 being background and 1-n being target labels like this:

enter image description here

With the image shape being (8, 512, 512) and the mask shape being (512, 512) in the case of the single channel example, (512, 512, 8) in the OHE case, and (512, 512, 3) in the stacked case.

Some masks may contain all class labels, some may only have a couple or be background labels only. I’ve tried using these single channel masks, I’ve also converted them into 3 channel masks with the first channel being all the labels for a given image, and I’ve also tried one hot encoding them such that each mask is 0-n dimensions and each channel a different label with binary 0-1 for background/target.

In every case, no matter which way I format the training data, the results from training end up being either all black, all white, or some grid effect like this:

enter image description here

Is there an ideal way this data should formatted for training/prediction or is there something I’m doing incorrectly which is resulting in these bad prediction masks?

For the sake of not posting hundreds of lines of code, here are some generalized snippets of what I used to try stacking/OHE the masks, train, and predict:

Mask manipulation:

import numpy as np
from PIL import Image
mask = np.array(Image.open(mask_path))
if stack:
    zeros = np.zeros((mask.shape[0], mask.shape[1]))
    mask = np.transpose(np.array([mask, zeros, zeros]), (1, 2, 0)).astype(np.uint8)
if onehot:
    one_hot_mask = np.zeros((mask.shape[0], mask.shape[1], self.num_classes))
    label_values = list(np.unique(mask))
    for i in range(0, self.num_classes):
        if i not in label_values:
            one_hot_mask[:, :, i] = 0
        else:
            if stack:
                one_hot_mask[:, :, i][mask[:, :, 0] == i] = 1
            else:
                one_hot_mask[:, :, i][mask[:, :] == i] = 1

Train

import torch
import torch.nn as nn
import torch.optim as optim

num_bands = 8
num_classes = 8 # it just so happened this dataset had the same number of classes as input bands/channels, this isn't always the case however
epochs=5
learning_rate = 0.001
weight_decay = 0

model = UNet(n_channels=num_bands, n_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss() if num_classes > 1 else nn.BCEWithLogitsLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for epoch in range(epochs):
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (data, targets) in loop:
        data = data.float().to(device)
        targets = targets.long().to(device)
        predictions = model(data)
        loss = loss_fn(predictions, targets)
        optimizer.zero_grad()
        loop.set_postfix(loss=loss.item())
    checkpoint_name = os.path.join(model_dir, f"model_{epoch}.pt")
    torch.save(model.state_dict(), checkpoint_name)

Predict

import numpy as np
from skimage import io
image = np.array(io.imread(image_path))

tensor = ToTensor()(image)
batch_t = torch.unsqueeze(tensor, 0).to(device)
preds = model(batch_t)
softmax = torch.nn.Softmax(dim=1)
preds = torch.argmax(softmax(model(preds)),axis=1).cpu()
preds = np.array(preds[0,:,:])
plt.imshow(preds, cmap='tab20')

LEAVE A COMMENT