Explainable AI (XAI) in remote sensing classification tasks
There is no doubt that deep learning (DL) has a great future ahead of it when it comes to supporting spatial analysis and remote sensing. Automating many complex tasks can shorten the time between a spatial phenomenon occurring and the action taken by the people who live in or govern the affected area. As the number of applications grows, so does the criticism of neural networks in remote sensing. It shouldn't be surprising. Remote sensing imagery captures the state of a geographical space whose appearance and observed phenomena can be explained in geographical, biological, chemical, physical, and social contexts. Such an explanation, however, is considered only partially possible when a neural network is involved, because the network remains a black box.
The inability to interpret the outcome of the DL model is far from comfortable when it comes to activities that involve the environment, infrastructure, and people.
What can be done to understand a neural network or a machine learning model in general? How can we get an insight into the complex and intricate process of DL model decision making? In simple words, it isn't easy. Only some machine learning models support explainability by design. Random forest, linear regression, and XGBoost are some examples. For vanilla neural networks the situation is much more complicated, but an explanation is still possible. Fortunately, XAI (explainable artificial intelligence) can help here.
In this blog post, you will find an example of applying a basic XAI technique (SHAP) to interpret the results of a remote sensing imagery classifier.
Example
We will use AID, a scene classification dataset, to train a deep convolutional neural network (DCNN). The dataset contains 30 different scene classes with around 200 to 400 samples of size 600 px x 600 px in each class. For this example, we will resize the images to 256 px x 256 px. Below you can find a set of images randomly sampled from the AID dataset.
AID: A scene classification dataset sample.
Although the set is small (10k scenes), it is possible to train a DCNN on it. The AID classifier consists of two essential parts. The first is a stack of convolutional layers used for spatial feature extraction. The other is the classification head built from multiple fully connected layers. The last layer has 30 neurons corresponding to the number of our classes. Nothing fancy, but it will do the job.
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Net                                      [32, 30]                  --
├─Conv2d: 1-1                            [32, 32, 252, 252]        2,432
├─MaxPool2d: 1-2                         [32, 32, 126, 126]        --
├─Conv2d: 1-3                            [32, 64, 124, 124]        18,496
├─MaxPool2d: 1-4                         [32, 64, 62, 62]          --
├─Conv2d: 1-5                            [32, 128, 60, 60]         73,856
├─MaxPool2d: 1-6                         [32, 128, 30, 30]         --
├─Linear: 1-7                            [32, 256]                 29,491,456
├─Linear: 1-10                           [32, 128]                 32,896
├─Linear: 1-11                           [32, 30]                  3,870
==========================================================================================
Total params: 29,623,006
Trainable params: 29,623,006
Non-trainable params: 0
Total mult-adds (G): 23.50
==========================================================================================
Input size (MB): 25.17
Forward/backward pass size (MB): 890.22
Params size (MB): 118.49
Estimated Total Size (MB): 1033.87
==========================================================================================
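The numbers in the summary can be reproduced by hand. The sketch below assumes unpadded, stride-1 convolutions followed by 2x2 max pooling, with a 5x5 kernel in the first convolution and 3x3 kernels in the other two. These kernel sizes are not stated in the post; they are inferred from the output shapes and parameter counts in the summary.

```python
def conv_out(size, kernel):
    """Spatial size after an unpadded, stride-1 convolution."""
    return size - kernel + 1


def conv_params(c_in, c_out, kernel):
    """Kernel weights plus one bias per output channel."""
    return c_in * kernel * kernel * c_out + c_out


def linear_params(n_in, n_out):
    """Weight matrix plus one bias per output neuron."""
    return n_in * n_out + n_out


size, channels, total = 256, 3, 0

# (output channels, kernel size) per convolution; inferred, not from the post
for c_out, kernel in [(32, 5), (64, 3), (128, 3)]:
    total += conv_params(channels, c_out, kernel)
    size = conv_out(size, kernel) // 2  # 2x2 max pooling halves the size
    channels = c_out

flat = channels * size * size  # features entering the classification head
for n_in, n_out in [(flat, 256), (256, 128), (128, 30)]:
    total += linear_params(n_in, n_out)

print(size, flat, total)  # prints: 30 115200 29623006
```

Almost all parameters sit in the first fully connected layer, which receives 115,200 flattened features.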
An accuracy of 0.53 and an F1-score of 0.517 on the validation subset (a 10% split) are acceptable for discriminating 30 classes, but such results are far from usable in a production or scientific scenario. Fortunately, our goal here is not to tune the model.
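For reference, both metrics can be computed from lists of true and predicted labels as sketched below. The post does not say which F1 averaging was used, so macro averaging is an assumption here, and the labels are toy data rather than results from the AID classifier.

```python
def accuracy(y_true, y_pred):
    """Share of samples whose predicted label matches the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores (assumed averaging)."""
    scores = []
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy labels, for illustration only
y_true = ["forest", "forest", "meadow", "stadium", "meadow", "stadium"]
y_pred = ["forest", "meadow", "meadow", "stadium", "meadow", "forest"]

print(round(accuracy(y_true, y_pred), 3), round(macro_f1(y_true, y_pred), 3))
```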
Multiple misclassifications occur, and some of the instances are hard to distinguish. To improve the model, or at least to describe its limitations, we must know what errors we are dealing with and what causes them.
Before applying XAI to the model, we should look at the confusion matrix.
Confusion matrix
Not bad, not terrible. The network can learn important information from the RGB images. Some cases are easy to solve: forests, meadows, and stadiums have distinct spatial compositions. On the other hand, some complex cases would cause problems even for a seasoned specialist in remote sensing. Of course, we could work on the issue further and get a better classifier. But let's stop and analyze what we already have.
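To see which classes are mixed up most often, the off-diagonal cells of the confusion matrix can be ranked. This is a minimal sketch with toy labels; the helper name most_confused is hypothetical and the class pairs are examples, not results from the model.

```python
from collections import Counter


def most_confused(y_true, y_pred, top=3):
    """Return the most frequent (true, predicted) pairs among the errors."""
    errors = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)
    return errors.most_common(top)


# Toy labels, for illustration only
y_true = ["airport", "airport", "railway station", "farmland", "meadow", "airport"]
y_pred = ["railway station", "railway station", "airport", "meadow", "meadow", "airport"]

for (true_cls, pred_cls), count in most_confused(y_true, y_pred):
    print(f"{true_cls} -> {pred_cls}: {count}")
```

Pairs at the top of such a ranking are good candidates for a closer look with SHAP.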
XAI
From the point of view of aerial photograph and satellite image interpretation, it is essential to know which spatial features of a given scene the neural network took into account when assigning a class. This information is necessary for the person responsible for the acceptance tests. With experience in analyzing remote sensing imagery and knowledge of the geographic space recorded in it, the expert can show the data science and machine learning teams the strengths and weaknesses of the DL model.
From my experience, utilizing the SHAP library is the quickest and most reliable way of attaching XAI capabilities to a deep learning project. Just run pip install shap or mamba install shap, and you're good to go. In case you use conda, grab this paper and read about interpreting model predictions while your package manager solves the environment ;)
The SHAP framework identifies the class of additive feature importance methods (which includes six previous methods) and shows there is a unique solution in this class that adheres to desirable properties [Lundberg; Lee; 2017].
What we want to achieve using SHAP is to plot the visualization of our test samples with information regarding the areas of the analyzed scenes that contributed the most to the neural network output. Using such an approach, we could gain additional insight into our deep learning model and try to understand its operations.
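To make the idea of additive feature importance concrete, the sketch below computes exact Shapley values for a toy model by enumerating every feature subset, with absent features replaced by a baseline value. The model f, the input, and the baseline are made up for illustration. This is not how SHAP handles images: enumeration is infeasible for thousands of pixels, which is why approximations such as partition SHAP are used.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x; absent features take baseline values."""
    n = len(x)

    def value(subset):
        # Features in the subset keep their real value, the rest the baseline
        return f([x[i] if i in subset else baseline[i] for i in range(n)])

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


def f(z):
    # Toy model: one independent feature and one interaction term
    return 2.0 * z[0] + z[1] * z[2]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(f, x, baseline)

# Additivity: the attributions sum to f(x) - f(baseline)
print(phi, sum(phi), f(x) - f(baseline))
```

The interaction term is split evenly between the second and third features, and the attributions add up to the difference between the model output and the baseline output.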
To use the library, we must first select the scenes that will be explained. You should avoid selecting the whole test dataset because this will dramatically increase computation time. I usually start with a random sample and eventually narrow the selection to interesting cases. Please note that using SHAP with PyTorch can be tricky, mainly because PyTorch tensors use the channels-first format, while the SHAP image masker used here expects channels-last arrays. There are multiple ways to handle it. We can transpose the NumPy array to channels-last and transpose it back inside the prediction function.
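import numpy as np
import torch

# val_ds - torch Dataset yielding (image, label) pairs with channels-first images
n_samples = 10
xai_subset = torch.utils.data.Subset(val_ds, np.random.choice(len(val_ds), n_samples, replace=False))
# Convert each image from (C, H, W) to (H, W, C) to match the masker shape and predict()
xai_arr = np.array([x[0].numpy().transpose(1, 2, 0) for x in xai_subset])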
Next, we need to define the prediction function. The function will utilize our previously trained DCNN model.
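# model - trained model
def predict(x: np.ndarray) -> torch.Tensor:
    # SHAP passes channels-last batches (N, H, W, C); the model expects (N, C, H, W)
    x = x.transpose(0, 3, 1, 2)
    x = torch.Tensor(x)
    x = x.cuda()
    # Inference only, so gradients are not needed
    with torch.no_grad():
        return model.cuda()(x)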
When we are done, we can construct an Explainer. We will use an image masker to run partition SHAP. The max_evals parameter can be adjusted and influences the resolution of the explanation. A higher value yields a higher-resolution explanation but increases computation time. At the same time, we will limit the output to only the four most probable classes.
import shap

# ds.classes - AID dataset class names
top_k_classes = 4
# Masked regions are replaced with a blurred version of the image
masker_blur = shap.maskers.Image('blur(256,256)', (256, 256, 3))
explainer = shap.Explainer(predict, masker_blur, output_names=ds.classes)
shap_values = explainer(xai_arr, max_evals=50000, batch_size=128, outputs=shap.Explanation.argsort.flip[:top_k_classes])
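The outputs argument keeps, for every scene, only the classes with the highest model output. The plain-Python sketch below mirrors that idea (argsort, flip, take the first k) on made-up scores. The helper name top_k_indices is hypothetical and is not part of SHAP.

```python
def top_k_indices(scores, k):
    """Indices of the k largest scores, most probable first."""
    ascending = sorted(range(len(scores)), key=lambda i: scores[i])  # argsort
    return ascending[::-1][:k]  # flip, then keep the first k


# Made-up model outputs for five classes
scores = [0.05, 0.40, 0.10, 0.30, 0.15]
print(top_k_indices(scores, 4))  # prints: [1, 3, 4, 2]
```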
Finally, we can plot the visualization. For convenience, we add ground-truth labels to the output image. Five columns will be generated. The first one is the input scene. The others are the predicted classes, ordered from most to least probable. Red areas indicate an increase in the probability of a given class, and blue areas indicate a decrease.
xai_labels = np.array([ds.classes[x[1]] for x in xai_subset])
shap.image_plot(shap_values, true_labels=xai_labels.tolist())
SHAP image plot.
There are some issues with the model, but now we can start to determine their causes. Some classes are challenging to recognize. Perhaps the residential area classes should be merged into one. Other scenes may need a boost in spectral resolution. Adding a thermal channel could help distinguish airports from railway stations, and near-infrared could help recognize farmlands. We could also check the meadow images for issues related to imagery acquisition or provide additional data from other seasons to make the model more robust.
After applying SHAP, you are left with NumPy arrays, so further processing is straightforward. Visualizing multichannel images or temporal data is more demanding but doable. If you want to work in your favorite GIS software, SHAP values can be converted to georeferenced rasters.
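One dependency-free way to move SHAP values into GIS software is the Esri ASCII grid format, which common GIS packages can read. The sketch below is only an illustration: the function name, the corner coordinates, and the cell size are hypothetical, and the 2D grid stands for the SHAP values of one class already summed over the image channels. The format does not store the coordinate reference system, so that has to be supplied separately.

```python
def to_ascii_grid(values, xllcorner, yllcorner, cellsize, nodata=-9999):
    """Serialize a 2D list of per-pixel values as an Esri ASCII grid."""
    nrows, ncols = len(values), len(values[0])
    header = [
        f"ncols {ncols}",
        f"nrows {nrows}",
        f"xllcorner {xllcorner}",  # x of the lower-left corner
        f"yllcorner {yllcorner}",  # y of the lower-left corner
        f"cellsize {cellsize}",
        f"NODATA_value {nodata}",
    ]
    # Rows are written top to bottom, which matches image row order
    rows = [" ".join(f"{v:.6f}" for v in row) for row in values]
    return "\n".join(header + rows) + "\n"


# Tiny made-up grid; a real one would be 256 x 256 for the scenes above
shap_grid = [[0.012, -0.004, 0.0], [0.031, 0.002, -0.015]]
text = to_ascii_grid(shap_grid, xllcorner=500000.0, yllcorner=5600000.0, cellsize=0.5)
print(text)
```

Saving the returned text to a file with the .asc extension gives a raster layer that can be styled with a diverging red-blue ramp, similar to the SHAP image plot.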
Industrial parking, bridges, SHAP value
Summary
Cooperation between various parties is essential if the final DL solution is to be valid and useful. Such collaboration depends on being able to explain the model and interpret its outcome. XAI techniques can act as adapters between engineers and scientists, giving them a common basis for discussion and a feedback loop that improves outcomes after each iteration.