How to make your foundation model equivariant

Header image: the Segment Anything Model (SAM) producing incorrect results on its own, and correct results with a canonicalization network

Authors: Arnab Kumar Mondal, Siba Smarak Panigrahi, and Sai Rajeswar

Deep learning has grown tremendously over the past decade. Still, as we strive for a more nuanced understanding of these models and for further performance gains, a question emerges: How do we ensure our models understand data transformations?

Enter equivariance, which helps a network behave consistently when its input data is transformed. But with the rise of large pretrained models, a follow-up question arises: How do we make them equivariant without changing their architecture or retraining the foundation model from scratch with data augmentation?

What is equivariance?

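An equivariant network [1, 2, 3] is a deep neural network whose output responds in a consistent, predictable way when its input undergoes a transformation such as rotation, scaling, or translation: transforming the input transforms the output correspondingly. In simpler terms, if we rotate an image of a cat, an equivariant segmentation network would return the same cat mask rotated by the same amount, and an equivariant classifier would still recognize it as a cat. (When the output doesn’t change at all, as with a class label, this special case is called invariance.)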
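A small NumPy check makes the distinction concrete, using 90-degree rotations of a 2D array via `np.rot90`. The three functions are hypothetical toy examples, not parts of any real model: a pointwise threshold (equivariant), a global sum (invariant), and a threshold with a position-dependent ramp (neither).

```python
import numpy as np


def rotate(x: np.ndarray, k: int = 1) -> np.ndarray:
    """Rotate a 2D image counterclockwise by k * 90 degrees."""
    return np.rot90(x, k)


def threshold_mask(x: np.ndarray) -> np.ndarray:
    # Pointwise: every pixel is treated the same wherever it sits,
    # so rotating before or after gives the same mask (equivariant).
    return (x > 0.5).astype(np.float32)


def total_brightness(x: np.ndarray) -> float:
    # Global sum: rotation only reorders the pixels, so the output
    # does not change at all (invariant).
    return float(x.sum())


def position_biased_mask(x: np.ndarray) -> np.ndarray:
    # Adds a ramp that grows from left to right before thresholding,
    # so the result depends on pixel position: not rotation equivariant.
    ramp = np.linspace(0.0, 1.0, x.shape[1])[None, :]
    return ((x + ramp) > 0.5).astype(np.float32)


rng = np.random.default_rng(0)
image = rng.random((8, 8))
flat = np.zeros((8, 8))

equivariant = np.array_equal(threshold_mask(rotate(image)), rotate(threshold_mask(image)))
invariant = np.isclose(total_brightness(rotate(image)), total_brightness(image))
biased = np.array_equal(position_biased_mask(rotate(flat)), rotate(position_biased_mask(flat)))
print(equivariant, invariant, biased)  # True True False
```

On the all-zero image, the biased function marks the right half of the columns before rotation but the top half of the rows after it, so the two orders disagree.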

The benefit is that such networks can make more accurate and robust predictions and need fewer training samples. This is great in theory but difficult to put into practice, especially for large foundation models: designing their equivariant counterparts isn’t trivial, and retraining them from scratch is expensive.

This problem is pressing because foundation models [4] aren’t naturally equivariant and often don’t handle transformations well. (See Figures 1a and 1b.)


Figure 1a. Text extraction from an image of OpenAI’s GPT-4 launch tweet (correct)


Figure 1b. Text extraction from an inverted image of OpenAI’s GPT-4 launch tweet (incorrect)

Canonicalization: Decoupling equivariance from architecture

A recent alternative proposed by Kaba et al. [5] suggests that, instead of changing the main network, we can first transform the input data into a standard format, also known as its canonical form. The primary network then operates on this standardized input, which keeps its behavior consistent.

This process involves two main networks: the canonicalization network, which standardizes the input, and the prediction network, which predicts based on the standardized input. In this particular formulation, achieving equivariance requires only that the canonicalization process itself is equivariant.
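Why is an equivariant canonicalizer enough? Write h for the canonicalization network, assumed here to output a group element such as a rotation, and p for the prediction network, with no constraint placed on p. In a minimal derivation, following the canonicalization formulation of Kaba et al. [5] as we summarize it, the model predicts on the canonicalized input h(x)^{-1} x and maps the output back by h(x). If h(g x) = g h(x) for every transformation g in the group G, the combined model f is equivariant:

```latex
\begin{aligned}
f(x)   &= h(x)\, p\!\left(h(x)^{-1} x\right) \\
f(g x) &= h(g x)\, p\!\left(h(g x)^{-1} g x\right)
        = g\, h(x)\, p\!\left(h(x)^{-1} g^{-1} g x\right)
        = g\, h(x)\, p\!\left(h(x)^{-1} x\right)
        = g\, f(x)
\end{aligned}
```

For invariant tasks such as classification, the outer h(x) is dropped, f(x) = p(h(x)^{-1} x), and the same steps give f(g x) = f(x).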

The beauty of this approach is that the canonicalization network decouples the equivariance requirement from the architecture of the core prediction network. This means we’re free to use any powerful, large pretrained neural network for the main prediction task.
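To make this concrete, here is a minimal sketch for the group of 90-degree rotations, assuming square grayscale images. A fixed template-matching scorer (hypothetical) stands in for the learned, equivariant canonicalization network, and `pretrained_mask_predictor` (also hypothetical) stands in for a frozen model such as SAM; neither is the implementation from our paper. Because segmentation masks rotate with the image, the mask predicted on the canonical input is rotated back by the canonicalizing rotation, i.e., f(x) = h(x) · p(h(x)^{-1} · x).

```python
import numpy as np

NUM_ROTATIONS = 4  # the group C4: rotations by 0, 90, 180, 270 degrees


def orientation_distribution(x: np.ndarray, template: np.ndarray) -> np.ndarray:
    """Hypothetical canonicalizer: a distribution over the four rotations.

    Score k measures how well the image, rotated back by k * 90 degrees,
    matches a fixed template. Rotating the input by m steps cyclically shifts
    the scores by m, so the most likely rotation shifts by m too (equivariance).
    """
    scores = np.array([np.sum(np.rot90(x, -k) * template) for k in range(NUM_ROTATIONS)])
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    return weights / weights.sum()


def canonicalize(x: np.ndarray, template: np.ndarray) -> tuple[int, np.ndarray]:
    # At inference, take the most likely rotation k and undo it.
    k = int(np.argmax(orientation_distribution(x, template)))
    return k, np.rot90(x, -k)


def pretrained_mask_predictor(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a pretrained, non-equivariant model:
    # the left-to-right ramp makes its output depend on orientation.
    ramp = np.linspace(0.0, 1.0, x.shape[1])[None, :]
    return ((x + ramp) > 1.0).astype(np.float32)


def equivariant_predict(x: np.ndarray, template: np.ndarray) -> np.ndarray:
    # Predict on the canonical input, then rotate the predicted mask
    # back to the input's original orientation.
    k, x_canonical = canonicalize(x, template)
    return np.rot90(pretrained_mask_predictor(x_canonical), k)


rng = np.random.default_rng(0)
template = rng.standard_normal((16, 16))  # plays the role of learned weights
image = rng.random((16, 16))

for m in range(NUM_ROTATIONS):
    rotated = np.rot90(image, m)
    same = np.array_equal(equivariant_predict(rotated, template),
                          np.rot90(equivariant_predict(image, template), m))
    print(m, same)
```

Nothing about `pretrained_mask_predictor` had to change. The template decides which orientation counts as canonical: a different template keeps the pipeline equivariant but might hand the predictor upside-down images, which is exactly the mismatch discussed next.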

Sound straightforward? It has a hitch. The main challenge is ensuring the canonicalization network “plays nice” with the prediction network. This becomes more important when the prediction network is pretrained on a certain dataset.

For instance, if the canonicalization network transforms all images to be upside-down, but our prediction network wasn't trained on upside-down images, the whole system falls apart. So, it's vital that these two networks are in sync.

Learning to predict the correct orientation

The canonicalization function must not just transform the data; it must do so with awareness of how the prediction model was originally trained. The key is ensuring that the transformed (or standardized) data aligns with what the pretrained prediction model expects.

Put more precisely, we want the canonicalization function to map inputs in out-of-distribution orientations back into the distribution of orientations that the pretrained prediction network has already seen during training.
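A worked version of this goal, assuming for simplicity that the pretrained network was trained on images x drawn from a dataset D that share one upright orientation, and that the canonicalization function h is equivariant: if h returns the identity rotation e on those training images, then for any rotated copy g x, equivariance gives h(g x) = g h(x) = g, and undoing that rotation recovers exactly the upright image the network already knows.

```latex
h(x) = e \ \text{ for } x \sim \mathcal{D}
\quad \Longrightarrow \quad
h(g x)^{-1} (g x) = \left(g\, h(x)\right)^{-1} g x = h(x)^{-1} g^{-1} g x = h(x)^{-1} x = x
\qquad \forall g \in G
```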


Figure 2. Training and inference with the canonicalization prior: The canonicalization function outputs a distribution over image orientations used to canonicalize the input image. Additionally, during training, this predicted distribution is regularized to match the orientations seen in the dataset.

Enter the canonicalization prior

In simple terms, the canonicalization prior is a guiding force that pushes our canonicalization function to produce the kind of output the pretrained prediction network expects. It builds on the idea that our data itself hints at the “typical” transformations it undergoes.

By encoding this into a prior, we can guide our canonicalization function to produce transformed data that's not just standardized, but also aligned with what the prediction network was trained on.
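The sketch below assumes one natural form for this prior: a KL divergence between a prior distribution over the four 90-degree rotations, reflecting the orientations seen in the training data, and the distribution the canonicalizer predicts, added to the task loss with a weight `beta`. All names here are hypothetical. With all prior mass on the identity rotation, as for a dataset of upright images, the KL term reduces to the negative log-probability the canonicalizer assigns to "no rotation."

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    weights = np.exp(scores - np.max(scores))  # numerically stable
    return weights / weights.sum()


def canonicalization_prior_loss(scores: np.ndarray, prior: np.ndarray, eps: float = 1e-12) -> float:
    """KL(prior || predicted) between a prior over rotations and the
    canonicalizer's predicted distribution (an assumed form of the prior)."""
    predicted = softmax(scores)
    support = prior > 0  # terms with zero prior mass contribute nothing
    return float(np.sum(prior[support] * (np.log(prior[support]) - np.log(predicted[support] + eps))))


def training_loss(task_loss: float, scores: np.ndarray, prior: np.ndarray, beta: float = 1.0) -> float:
    # beta is a hypothetical weight trading off the task loss against the prior.
    return task_loss + beta * canonicalization_prior_loss(scores, prior)


# Upright training data: all prior mass on the 0-degree rotation (identity).
identity_prior = np.array([1.0, 0.0, 0.0, 0.0])

undecided = canonicalization_prior_loss(np.zeros(4), identity_prior)  # about log(4), i.e. 1.386
confident = canonicalization_prior_loss(np.array([10.0, 0.0, 0.0, 0.0]), identity_prior)  # close to 0
print(undecided, confident)
```

An undecided canonicalizer pays about log 4 in loss, while one that confidently keeps upright training images upright pays almost nothing; minimizing this term steers the canonical forms toward orientations the pretrained network has seen.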

Although the underlying math is intricate, the whole process boils down to ensuring that the large pretrained prediction network only ever sees in-distribution samples. The result is a robust model that handles varied transformations of the input data and produces consistent predictions across them.

We show that this idea scales to large foundation models like the Segment Anything Model (SAM) [6], making them robust to rotations with only a nominal increase in parameter count and inference time.


Figure 3. Predicted masks from SAM [6], showing both the original model and our proposed equivariant adaptation, for 90-degree counterclockwise-rotated input images taken from the COCO 2017 dataset [7]. Our method makes SAM equivariant to the group of 90-degree rotations while requiring only 0.3% more parameters and increasing inference time by just 7.3%.

In the ever-evolving world of AI and deep learning, it’s critical to ensure models are robust and aware of symmetries. By learning to transform our input data so that it’s in the appropriate orientation for pretrained models, we can create large-scale models that are powerful and aware of data transformations, bringing us a step closer to AI systems that understand the world as we do.

As research into scaling continues, the fusion of large foundation models with equivariant adaptation techniques such as this one has the potential to be a fundamental approach to enhancing the consistency and reliability of AI systems.

Current issues and future directions

Equivariant canonicalization network constraints: Our current methodology leans heavily on the equivariant canonicalization network, which confines the range of transformations we can adapt to. Alternative solutions, such as the optimization techniques hinted at in Kaba et al. [5], could obviate the need for a dedicated equivariant network. However, the computational efficiency of such solutions remains unknown.

Incorporating continuous rotations: We’ve faced challenges when integrating the canonicalization prior with steerable [8] canonicalization networks for images. Overcoming them would let our image framework handle continuous rotations, not just discrete ones, making it more versatile.

Expressivity versus computational cost: The overall performance of our model is intrinsically tied to the canonicalization function's expressivity. Exploration of computationally efficient architectures for canonicalization functions that can maximize performance gains is another open research direction.

Domain-driven canonicalization adaptation: Another promising direction is learning to canonicalize using specific samples from target domains. This could potentially circumvent the need for prior knowledge of transformation groups, bridging the realms of equivariant deep learning and domain adaptation.

References

  1. Taco Cohen and Max Welling. Group equivariant convolutional networks. In Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pages 2990–2999, New York, New York, USA, 20–22 Jun 2016. PMLR.
  2. Daniel Worrall and Max Welling. Deep scale-spaces: Equivariance over scale. Advances in Neural Information Processing Systems, 32, 2019.
  3. Michael M Bronstein, Joan Bruna, Taco Cohen, and Petar Veličković. Geometric deep learning: Grids, groups, graphs, geodesics, and gauges. arXiv preprint arXiv:2104.13478, 2021.
  4. Rishi Bommasani, Drew A Hudson, Ehsan Adeli, Russ Altman, Simran Arora, Sydney von Arx, Michael S Bernstein, Jeannette Bohg, Antoine Bosselut, Emma Brunskill, et al. On the opportunities and risks of foundation models. arXiv preprint arXiv:2108.07258, 2021.
  5. Sékou-Oumar Kaba, Arnab Kumar Mondal, Yan Zhang, Yoshua Bengio, and Siamak Ravanbakhsh. Equivariance with learned canonicalization functions. In 40th International Conference on Machine Learning, 2023.
  6. Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollár, and Ross Girshick. Segment anything. In International Conference on Computer Vision, 2023.
  7. Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C Lawrence Zitnick. Microsoft COCO: Common objects in context. In Computer Vision – ECCV 2014, 2014.
  8. Maurice Weiler and Gabriele Cesa. General E(2)-equivariant steerable CNNs. Advances in Neural Information Processing Systems, 32, 2019.

Citation

More details can be found in our NeurIPS 2023 paper, “Equivariant Adaptation of Large Pre-Trained Models.” For citations, please use the following:

@inproceedings{mondal2023equivariant,
title={Equivariant Adaptation of Large Pre-Trained Models},
author={Arnab Kumar Mondal and Siba Smarak Panigrahi and S{\'e}kou-Oumar Kaba and Sai Rajeswar and Siamak Ravanbakhsh},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://arxiv.org/pdf/2310.01647.pdf}
}

Find out more about ServiceNow Research.