Rapid Network Adaptation:
Learning to Adapt Neural Networks
Using Test-Time Feedback


ICCV 2023


Teresa Yeo, Oğuzhan Fatih Kar, Zahra Sodagar, Amir Zamir




Paper

The paper and supplementary material.

PDF

Models

Download the trained models and baselines.

Pretrained Models

Code

The code for training and testing models and baselines.

Get started

Overview

An overview video (20 mins) explaining our method is given below. Please turn on your speaker for narration.


Adaptation method

Description of our method for performing test-time adaptation.

Test-time signals

Examples of test-time signals that we employed.

Sample results

Results on adaptation with RNA vs TTO and on different tasks.



Quick Summary

Neural networks are unreliable under distribution shifts. Examples of such shifts include blur due to camera motion, object occlusions, and changes in weather and lighting conditions. Dealing with such shifts is difficult as they are numerous and unpredictable. Therefore, training-time strategies that attempt to take anticipatory measures for every possible shift (e.g., augmenting the training data or changing the architecture with corresponding robustness inductive biases) have inherent limitations. This is the main motivation behind test-time adaptation methods, which instead aim to adapt to such shifts as they occur. In other words, these methods choose adaptation over anticipation. In this work, we propose a test-time adaptation framework that aims to perform an efficient adaptation of a main network using a feedback signal.

Adaptive vs. non-adaptive neural network pipelines. To be robust, non-adaptive methods rely on training-time interventions that anticipate and counter the distribution shifts that will occur at test-time (e.g., via data augmentation). Upon encountering an out-of-distribution input that was not anticipated, their predictions may therefore collapse.
Adaptive methods instead create a closed loop and use an adaptation signal at test-time. The adaptation signal is a quantity that can be computed at test-time from the environment. \(h_\phi\) acts as a "controller": it takes in an error feedback, computed from the adaptation signal and the model predictions, and adapts \(f_\theta\) accordingly. It can be implemented as (i) a standard optimizer (e.g., using SGD) or (ii) a neural network. The former is equivalent to test-time optimization (TTO), while the latter amortizes the optimization process by training a controller network to adapt \(f_\theta\). In this work, we study the latter approach and demonstrate its efficiency and flexibility.







An adaptive system is one that can respond to changes in its environment. Concretely, it is a system that can acquire information to characterize such changes, e.g., via an adaptation signal that provides an error feedback, and make modifications that would result in a reduction of this error.

The methods for performing the adaptation of the system range from gradient-based updates, e.g., using SGD to fine-tune the parameters (Sun et al., Wang et al., Gandelsman et al.), to the more efficient semi-amortized (Zintgraf et al., Triantafillou et al.) and amortized approaches (Vinyals et al., Oreshkin et al., Requeima et al.). As amortized methods train a controller network to substitute for the explicit optimization process, they only require forward passes at test-time and are thus computationally efficient. Gradient-based approaches, e.g., TTO, can be powerful adaptation methods when the test-time signal is robust and well-suited for the task. However, they are inefficient, risk overfitting, and require carefully tuned optimization hyperparameters (Boudiaf et al.). We discuss this in more detail here. In this work, we focus on an amortization-based approach that we call Rapid Network Adaptation, or RNA for short.

Below we show how we adapt a model \(f_\theta\), where \(x\) is the input image and \(f_\theta(x)\) the corresponding prediction. We first freeze the parameters of \(f_\theta\) and insert several FiLM layers into it. We then train \(h_\phi\) to take in \(z\), the adaptation signal, and \(f_\theta(x)\), and predict the parameters of these FiLM layers. This results in an adapted model \(f_{\hat{\theta}}\) and improved predictions \(f_{\hat{\theta}}(x)\). We will show that adaptation with \(h_\phi\) results in a closed-loop system that is flexible and able to generalize to unseen shifts.

Architecture of RNA.
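To make this concrete, here is a minimal PyTorch sketch of the setup. It is an illustration under our own assumptions (module and parameter names, a small convolutional encoder for \(h_\phi\), per-channel FiLM parameters), not the exact released implementation:

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise linear modulation inserted into the frozen f_theta:
    scales and shifts intermediate activations with parameters that the
    controller h_phi predicts at test-time."""
    def __init__(self, num_channels):
        super().__init__()
        # Start as the identity; the controller overwrites these buffers.
        self.register_buffer("gamma", torch.ones(num_channels))
        self.register_buffer("beta", torch.zeros(num_channels))

    def forward(self, feats):  # feats: (B, C, H, W)
        C = feats.shape[1]
        return self.gamma.view(-1, C, 1, 1) * feats + self.beta.view(-1, C, 1, 1)

class Controller(nn.Module):
    """h_phi: maps the error feedback (the adaptation signal z concatenated
    with the unadapted prediction f_theta(x)) to per-layer FiLM parameters."""
    def __init__(self, in_channels, film_channels):
        super().__init__()
        self.film_channels = film_channels
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.head = nn.Linear(64, 2 * sum(film_channels))

    def forward(self, z, pred):
        feedback = torch.cat([z, pred], dim=1)  # the error feedback
        params = self.head(self.encoder(feedback))
        out, i = [], 0
        for c in self.film_channels:  # split into per-layer (gamma, beta)
            gamma = 1 + params[:, i:i + c]  # predict residuals around identity
            beta = params[:, i + c:i + 2 * c]
            out.append((gamma, beta))
            i += 2 * c
        return out
```

Only \(h_\phi\) is trained; \(f_\theta\) stays frozen, as described above.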




While developing adaptation signals is not the main focus of this study and is independent of the RNA method, we still need to choose some for experimentation. Existing test-time adaptation signals, or proxies, in the literature include prediction entropy (Wang et al.), spatial autoencoding (Gandelsman et al.), and self-supervised tasks like rotation prediction (Sun et al.), contrastive (Liu et al.), or clustering (Boudiaf et al.) objectives. The more aligned the adaptation signal is with the target task, the better the performance on the target task (Sun et al., Liu et al.). More importantly, a poor signal can cause the adaptation to fail silently (Boudiaf et al., Gandelsman et al.).

The plot below shows how the loss on the target task changes as different proxy losses from the literature, e.g., entropy and consistency between different middle domains, are minimized. In all cases the proxy loss decreases; however, the improvement in the target loss varies. Thus, successful optimization of existing proxy losses does not necessarily lead to better performance on the target task.


Adaptation using different signals. Not all improvements in the proxy loss translate into improved performance on the target task. We show the results of adapting a pre-trained depth estimation model to a defocus blur corruption by optimizing different adaptation signals: prediction entropy, a self-supervised task (Sobel edge prediction error), and sparse depth obtained from SFM. The plots show how the \(\ell_1\) target error with respect to ground-truth depth (green, left axis) changes as the proxy losses (blue, right axis) are optimized (shaded regions represent the 95% confidence intervals across multiple runs of SGD with different learning rates). Only adaptation with the sparse depth (SFM) proxy leads to a reduction of the target error. This signifies the importance of employing proper signals in an adaptation framework.
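As a concrete reference for one of these proxies, here is a minimal sketch of the prediction-entropy objective in the style of TENT (the function name is ours):

```python
import torch

def entropy_proxy(logits):
    """Prediction-entropy proxy loss (TENT-style): the mean Shannon entropy
    of the softmax outputs. It needs no labels, but, as the plot above shows,
    minimizing it does not necessarily reduce the target error."""
    log_probs = logits.log_softmax(dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1).mean()
```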


We show some examples of test-time adaptation signals for several geometric and semantic tasks below. Our focus is not on providing an extensive list of adaptation signals, but rather on using practical ones for experimenting with RNA, as well as demonstrating the benefits of using signals that are rooted in the known structure of the world and the task at hand. For example, geometric computer vision tasks naturally follow multi-view geometry constraints, making these constraints a proper candidate for approximating the test-time error and, consequently, an informative adaptation signal.
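For depth, for instance, this amounts to supervising the prediction only at the sparse, possibly noisy depth values recovered by SFM. A minimal sketch under our own naming:

```python
import torch

def sparse_depth_loss(pred_depth, sfm_depth, valid_mask):
    """l1 error computed only at pixels where SFM provides (noisy) depth.
    pred_depth, sfm_depth: (B, 1, H, W); valid_mask: boolean, same shape."""
    return (pred_depth - sfm_depth).abs()[valid_mask].mean()
```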


Examples of employed test-time adaptation signals. We use a range of adaptation signals in our experiments. These are practical to obtain and yield better performance compared to other proxies. In the left plot, for depth and optical flow estimation, we use sparse depth and optical flow via SFM. In the middle, for classification, for each test image we perform \(k\)-NN retrieval to get \(k\) training images. Each of these retrieved images has a one-hot label associated with it; thus, combining them gives a coarse label that we use as our adaptation signal. Finally, for semantic segmentation, after performing \(k\)-NN retrieval as for classification, we obtain a pseudo-labelled segmentation mask for each retrieved image. The features of each patch in the test image and the retrieved images are then matched, and the top matches are used as sparse supervision.
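As an illustration of the classification signal, the sketch below (our own naming) builds the coarse label by averaging the one-hot labels of the \(k\) retrieved training images:

```python
import torch
import torch.nn.functional as F

def knn_coarse_label(test_feat, train_feats, train_labels, k=5, num_classes=1000):
    """Average the one-hot labels of the k nearest training images
    (cosine similarity in feature space) into a single coarse soft label.
    test_feat: (D,); train_feats: (N, D); train_labels: (N,) class ids."""
    sims = F.cosine_similarity(test_feat.unsqueeze(0), train_feats, dim=1)
    _, idx = sims.topk(k)
    one_hot = F.one_hot(train_labels[idx], num_classes).float()  # (k, C)
    return one_hot.mean(dim=0)  # mass spread over the retrieved classes
```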


To perform adaptation at test-time, we first compute the adaptation signal as described above. The computed signal and the prediction from the model before adaptation, \(f_\theta\), are concatenated to form the error feedback. This error feedback is then passed as input to \(h_\phi\) (see the figure here). These adaptation signals are practical for real-world use, but they are also imperfect, i.e., the sparse depth points do not correspond exactly to the ground-truth values. Thus, to perform controlled experiments and separate the performance of RNA from that of the adaptation signals, we also provide experiments using ideal adaptation signals, e.g., masked ground truth. In the real world, these ideal signals can come from sensors like radar.
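Putting the pieces together, adaptation then amounts to forward passes only. The sketch below builds on the FiLM/controller sketch above; the interface is ours, not the released code:

```python
import torch

@torch.no_grad()
def rna_adapt(f_theta, h_phi, film_layers, x, z):
    """One RNA adaptation step: forward passes only, no gradient computation.
    z is the computed adaptation signal, made concatenable with the
    prediction (e.g., a sparse-depth map in image form)."""
    pred = f_theta(x)                          # prediction before adaptation
    film_params = h_phi(z, pred)               # error feedback -> FiLM params
    for layer, (gamma, beta) in zip(film_layers, film_params):
        layer.gamma, layer.beta = gamma, beta  # write the predicted modulation
    return f_theta(x)                          # adapted prediction
```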

Here are some qualitative results. Zoom in to see the fine-grained details. See the paper for full details.

Key takeaways:

  •    We show that RNA is able to amortize the optimization process, making it orders of magnitude faster than TTO. See the video here.
  •    It is flexible: it can be applied to different architectures (see the results here) and used with different adaptation signals (see Table 1 in the supplementary).
  •    This allows RNA to outperform the baselines on different distribution shifts (common corruptions, 3D common corruptions, cross-dataset shifts), tasks (depth, optical flow, dense 3D reconstruction, semantic segmentation, image classification), and datasets (Taskonomy, Replica, ImageNet, COCO, ScanNet, Hypersim). See the following section for results.



Adapting with RNA vs TTO

Here is a summary of our observations from adapting with RNA vs TTO. TTO represents closed-loop adaptation using the adaptation signal but without benefiting from any amortization (the adaptation process is fixed to standard SGD; a minimal sketch of this baseline is given after the list below). These observations hold across different tasks.

  •    RNA is efficient. RNA only requires forward passes at test-time; thus, it is orders of magnitude faster than TTO while attaining comparable performance. See the video here.
  •    RNA's predictions are sharper than TTO's. RNA benefits from learning to adapt the network: it can learn the relationship between the noisy, incomplete adaptation signal and the target objective, e.g., dense depth prediction. This is in contrast to TTO, which uses SGD in a vanilla way. See the figure here.
  •    RNA generalizes to unseen shifts. Although RNA was trained only on clean data, it is able to outperform TTO at low severities. However, the performance gap against TTO narrows at high severities, as TTO is exposed to the corruptions at test-time. This is demonstrated in all the subsequent results.
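For reference, the TTO baseline amounts to plain test-time SGD on the adaptation signal. A bare-bones sketch (hyperparameters and names are illustrative, not the exact experimental setup):

```python
import torch

def tto_adapt(f_theta, x, z, signal_loss, steps=50, lr=1e-4):
    """TTO baseline: iteratively fine-tune f_theta on the adaptation signal
    with plain SGD. Every step costs a forward and a backward pass, which
    is what makes it much slower than RNA's forward-only adaptation."""
    opt = torch.optim.SGD(f_theta.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = signal_loss(f_theta(x), z)  # e.g., l1 on sparse SFM depth points
        loss.backward()
        opt.step()
    with torch.no_grad():
        return f_theta(x)
```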

The video below demonstrates the efficiency of RNA. It shows an image corrupted with Gaussian noise. The test-time signal is noisy sparse depth from SFM, overlaid on the input image. The predictions at iteration 0 are the same for all methods, as this is before any adaptation. Note that RNA attains an improved prediction after a single forward pass. The top-right plot shows how the \(\ell_1\) error changes with each iteration: RNA, shown in green, significantly reduces the error.



Adapting with increasing supervision

The video below demonstrates the performance of RNA with increasing supervision. It shows an image corrupted with Gaussian noise. The test-time signal is click annotations, overlaid on the image (2nd row, 1st column). RNA attains improved predictions with only a few annotations.
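For completeness, a click signal of this kind can be converted into a sparse segmentation target in a few lines. A minimal sketch with our own conventions (unclicked pixels are excluded from the loss via an ignore index):

```python
import numpy as np

def clicks_to_sparse_target(clicks, height, width, ignore_index=255):
    """Build a sparse segmentation target from click annotations.
    clicks: iterable of (row, col, class_id) tuples; every other pixel is
    skipped by the loss through ignore_index."""
    target = np.full((height, width), ignore_index, dtype=np.int64)
    for row, col, class_id in clicks:
        target[row, col] = class_id
    return target
```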



Results on Different Tasks

We now show evaluations for various target tasks and adaptation signals.


Let's first look at qualitative results of RNA vs. the baselines for semantic segmentation on random query images from COCO-CC (left) and depth on images from ScanNet, Taskonomy-3DCC, and Replica-CC (right). The predictions with the adaptation signals described above are shown in the last two rows. They are noticeably more accurate than the baselines'. Comparing TTO and RNA: RNA's predictions are more accurate for segmentation and sharper than TTO's for depth (see the ellipses), while being significantly faster.

Adaptation results for semantic segmentation and depth. For semantic segmentation, we use 15 pixel annotations per class. For Taskonomy-3DCC, we use sparse depth with 0.05% valid pixels (30 pixels per image). For ScanNet and Replica-CC, the adaptation signal is sparse depth measurements from SFM with similar sparsity ratios to Taskonomy-3DCC.


We also demonstrate the effectiveness of RNA on dense 3D reconstruction. The goal is to reconstruct a 3D pointcloud of an apartment given a sequence of corrupted images of it. The depth predictions from the pre-adaptation baseline (2nd column) are poor and result in a pointcloud with large artifacts and frequent discontinuities in the scene geometry. To perform adaptation, we compute the noisy sparse depth from SFM and use it to adapt the depth model. The predictions from the adapted models are then backprojected to attain a 3D pointcloud. Both RNA and TTO can significantly correct such errors and recover a 3D-consistent pointcloud, and RNA achieves this orders of magnitude faster than TTO.

Adaptation results on 3D reconstruction. Camera poses and 3D keypoints are first obtained from SFM. They are then used to adapt monocular depth predictions for each image, which are then backprojected into a 3D pointcloud.
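The backprojection step is standard pinhole geometry. A minimal sketch, assuming a (3, 3) intrinsics matrix K and a (4, 4) camera-to-world pose from SFM per frame:

```python
import numpy as np

def backproject(depth, K, cam_to_world):
    """Lift a dense depth map into a world-space pointcloud.
    depth: (H, W) predicted depth; K: (3, 3) intrinsics;
    cam_to_world: (4, 4) camera pose recovered by SFM."""
    H, W = depth.shape
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    pixels = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3)
    rays = pixels @ np.linalg.inv(K).T          # camera-space viewing rays
    points_cam = rays * depth.reshape(-1, 1)    # scale each ray by its depth
    ones = np.ones((points_cam.shape[0], 1))
    points_h = np.concatenate([points_cam, ones], axis=1)
    return (points_h @ cam_to_world.T)[:, :3]   # (H*W, 3) world-space points
```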


We also have supporting results on ImageNet classification. The table below shows the results from using 45 coarse labels on ImageNet-{C,3DCC,V2}. This corresponds to 22x coarser supervision than the 1000 classes we evaluate on. See Section 4.1 of the paper for how these coarse labels are computed.
TENT shows notable improvements under corruptions for classification, unlike for semantic segmentation and depth. Using coarse supervision results in even better performance, a further 5 pp reduction in error. Furthermore, on uncorrupted (clean) data and ImageNet-V2, RNA gives roughly a 10 pp improvement over TTO. Thus, coarse supervision provides a useful adaptation signal while requiring much less effort than full annotation. We also have results on adaptation using coarse labels computed from DINO pre-trained features; see Table 3 of the paper.

Adaptation Signal         Method                    Clean   IN-C   IN-3DCC   IN-V2   Rel. Runtime
-                         Pre-adaptation Baseline   23.9    61.7   55.0      37.2    1.0
Entropy                   TENT                      24.7    46.2   47.1      37.1    5.5
Coarse labels (WordNet)   Densification             95.5    95.5   95.5      95.5    -
Coarse labels (WordNet)   TTO (Online)              24.7    40.6   42.9      36.8    5.7
Coarse labels (WordNet)   RNA (frozen \(f\))        16.7    41.2   40.4      25.5    1.4

Quantitative adaptation results on the ImageNet (IN) classification task. We report the average error (%) for the 1000-way classification task over all corruptions and severities.



Additional results




RNA is not specific to the choice of architecture of \(f\). In the table below, we show results for RNA applied to the Dense Prediction Transformer (DPT) (Ranftl et al.) for depth estimation on the Taskonomy dataset (left) and to ConvNeXt (Liu et al.) for ImageNet classification (right). In both cases, RNA achieves better performance and runtime than TTO.

Task (Arch.)              Depth (DPT)                     Classification (ConvNeXt)
Shift                     Clean   CC    Rel. Runtime      Clean   IN-C   Rel. Runtime
Pre-adaptation Baseline   2.2     3.8   1.0               18.1    43.0   1.0
TTO (Online)              1.8     2.6   13.9              17.8    41.4   11.0
RNA (frozen \(f\))        1.1     1.6   1.0               14.3    38.0   1.1

RNA works across different architectures. Lower is better. \(\ell_1\) errors for depth estimation are multiplied by 100 for readability.



We also have additional results showing the following:

  •    RNA still outperforms after controlling for the number of parameters. See Table 2 in the supplementary.
  •    RNA can also work with adaptation signals used in the literature. See Table 1 in the supplementary.
  •    RNA admits different implementations: e.g., besides predicting the parameters of FiLM layers, we also implemented RNA as a HyperNetwork. See Sec. 2.2 in the supplementary.



Paper

Rapid Network Adaptation: Learning to Adapt Neural Networks Using Test-Time Feedback.
Yeo, Kar, Sodagar, Zamir.
ICCV 2023






Team

Teresa Yeo

EPFL

Oğuzhan Fatih Kar

EPFL

Zahra Sodagar

EPFL, Sharif University

Amir Zamir

EPFL