title: Fishr: Invariant Gradient Variances for Out-of-distribution Generalization
authors: Rame, Alexandre; Dancette, Corentin; Cord, Matthieu
date: 2021-09-07

Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been a growing surge of interest in learning simultaneously from multiple training domains -- while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under controlled evaluation protocols. In this paper, we introduce a new regularization -- named Fishr -- that enforces domain invariance in the space of the gradients of the loss: specifically, the domain-level variances of gradients are matched across training domains. Our approach is based on the close relations between the gradient covariance, the Fisher Information and the Hessian of the loss: in particular, we show that Fishr eventually aligns the domain-level loss landscapes locally around the final weights. Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. Notably, Fishr improves the state of the art on the DomainBed benchmark and performs consistently better than Empirical Risk Minimization. The code is released at https://github.com/alexrame/fishr.

The success of deep neural networks in supervised learning (Krizhevsky et al., 2012) relies on the crucial assumption that the train and test data distributions are identical. In particular, the tendency of networks to rely on simple features (Kalimeris et al., 2019; Valle-Perez et al., 2019; Geirhos et al., 2020) is generally a desirable behavior reflecting Occam's razor. However, in case of distribution shift, this simplicity bias deteriorates performance since more complex features are needed (Tenenbaum, 2018; Shah et al., 2020). For example, in the fight against Covid-19, most of the deep learning methods developed to detect coronavirus from chest scans were shown useless for clinical use (DeGrave et al., 2021; Roberts et al., 2021): indeed, networks exploited simple biases in the training datasets, such as patients' age or body position, rather than 'truly' analyzing medical pathologies. To better generalize under distribution shifts, most works such as Blanchard et al. (2011) or Muandet et al. (2013) assume that the training data is divided into different training domains in which there is a constant underlying causal mechanism (Peters et al., 2016). To remove the domain-dependent explanations, different invariance criteria across those training domains have been proposed. Some works (Ganin et al., 2016; Sun et al., 2016; Sun & Saenko, 2016) enforce similar feature distributions; others (Arjovsky et al., 2019; Krueger et al., 2021) force the classifier to be simultaneously optimal across all domains. Yet, despite the popularity of this research topic, none of these methods perform significantly better than classical Empirical Risk Minimization (ERM) when applied with controlled model selection and restricted hyperparameter search (Gulrajani & Lopez-Paz, 2021; Ye et al., 2021b). These failures motivate the need for new ideas. To foster the emergence of a shared mechanism with consistent generalization properties, our intuition is that learning should progress consistently, with similarly diverse feedback across domains.
Besides, the learning procedure of deep neural networks is dictated by the distribution of the gradients with respect to the network weights (Yin et al., 2018; Sankararaman et al., 2020), usually backpropagated in the network during gradient descent. Thus, we seek distributional invariance across domains in the gradient space: domain-level gradients should be similar, not only in average direction, but most importantly in statistics such as covariance and dispersion. We will show how this regularization during the learning of θ improves the out-of-distribution generalization properties by aligning locally the domain-level loss landscapes at convergence. In this paper, we propose a gradient-based regularization for out-of-distribution generalization in classification -- summarized in Fig. 1. We match the domain-level gradient covariances, i.e., the second moment of the gradient distributions. In contrast, previous gradient-based works such as Fish (Shi et al., 2021) only match the domain-level gradient means, i.e., the first moment. Moreover, our strategy is also motivated by the close relations between the gradient covariance (Faghri et al., 2020), the Fisher Information Matrix (Amari, 1998) and the Hessian. This explains the name of our work, Fishr, refining Fish and related to the Fisher Information. Notably, we will study how Fishr forces the model to have similar domain-level Hessians and, as a consequence, promotes consistent explanations -- based on the inconsistency formalism from Parascandolo et al. (2021). To scale our approach and reduce the computational overhead, we justify an approximation that only considers the diagonal of the gradient covariances in the last layer's weights. Ultimately, our regularization comes down to matching across domains the domain-level variances of gradients in the classifier. This is possible at low cost with the BackPACK (Dangel et al., 2020) package. We summarize our contributions as follows:
• We introduce the Fishr regularization that brings closer the domain-level gradient variances.
• Based on the relation between the gradient covariance, the Fisher Information and the Hessian, we show that Fishr matches domain-level Hessians and improves generalization by reducing inconsistencies across domains.
• We propose a simple and scalable implementation.
Empirically, we first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST (Arjovsky et al., 2019). Then, we show that Fishr performs best on the DomainBed benchmark (Gulrajani & Lopez-Paz, 2021) with the 'Oracle' model selection method and ranks third with the 'Training' model selection when compared with state-of-the-art counterparts. Critically, Fishr is the only method to perform better (on VLCS, OfficeHome, TerraIncognita and DomainNet) or similarly (on PACS) than ERM with both selection methods on all 'real' datasets. We first describe our task and provide the notations used throughout the paper. Then we review important related works to understand how Fishr stands within a rich literature. We study out-of-distribution (OOD) generalization for classification. Our model is a deep neural network (DNN) $f_\theta$ (parametrized by θ) made of a deep feature extractor $\Phi_\phi$ on which we plug a dense linear classifier $w_\omega$: $f_\theta = w_\omega \circ \Phi_\phi$ and $\theta = (\phi, \omega)$. In training, we have access to different domains $E$: for each domain $e \in E$, the dataset $D_e = \{(x_e^i, y_e^i)\}_{i=1}^{n_e}$ contains $n_e$ i.i.d. (input, label) samples drawn from a domain-dependent probability distribution.
Combined together, the datasets $\{D_e\}_{e \in E}$ are of size $n = \sum_{e \in E} n_e$. Our goal is to learn weights θ so that $f_\theta$ predicts well on a new test domain, unseen in training. As described in Koh et al. (2020) and Ye et al. (2021b), the most common distribution shifts are diversity shifts -- where the training and test distributions comprise data from related but distinct domains -- or correlation shifts -- where the distribution of the covariates at test time differs from the one during training. To generalize well despite these distribution shifts, $f_\theta$ should ideally capture an invariant mechanism across training domains. Following standard notations, $\|M\|_F^2$ denotes the squared Frobenius norm of matrix $M$; $\|v\|_2^2$ denotes the squared Euclidean norm of vector $v$; $\mathbf{1}$ is a column vector with all elements equal to 1. The standard Empirical Risk Minimization (ERM) (Vapnik, 1999) framework simply minimizes the average empirical risk over all training domains, i.e., $\frac{1}{n} \sum_{e \in E} \sum_{i=1}^{n_e} \ell\big(f_\theta(x_e^i), y_e^i\big)$, where $\ell$ is the loss, usually the negative log-likelihood. Many approaches try to exploit some external source of knowledge (Xie et al., 2021), in particular the domain information. As a side note, these partitions may be inferred if not provided (Creager et al., 2021). Some works explore data augmentations to mix samples from different domains (Wang et al., 2020; Wu et al., 2020) and others re-weight the training samples to favor underrepresented groups (Sagawa et al., 2020a; Zhang et al., 2021). Yet, most recent works promote invariance via a regularization criterion and only differ by the choice of the statistics to be matched across training domains. They enforce agreement either (1) in features, (2) in predictors or (3) in gradients. First, some approaches aim at extracting domain-invariant features and were extensively studied for unsupervised domain adaptation. The features are usually aligned with adversarial methods (Ganin et al., 2016; Gong et al., 2016; Li et al., 2018b) or with kernel methods (Muandet et al., 2013; Long et al., 2014). Yet, the simple covariance matching in CORAL (Sun et al., 2016; Sun & Saenko, 2016) performs best on various tasks for OOD generalization (Gulrajani & Lopez-Paz, 2021). Denoting $Z_e^{ij}$ the $j$-th dimension of the features extracted by $\Phi_\phi$ for the $i$-th sample of domain $e$, CORAL matches the covariances of these features across domains, and not only their means as in Deep Domain Confusion (DDC) (Tzeng et al., 2014). Yet, Johansson et al. (2019) and Zhao et al. (2019) show that these approaches are insufficient to guarantee good generalization. Motivated by arguments from causality (Pearl, 2009) and the idea that statistical dependencies are epiphenomena of an underlying causal structure, Invariant Risk Minimization (IRM) (Arjovsky et al., 2019) argues that the predictor should be invariant (Peters et al., 2016; Rojas-Carulla et al., 2018), i.e., simultaneously optimal across all domains. Among many suggested improvements (Chang et al., 2020; Idnani & Kao, 2020; Teney et al., 2020; Ahmed et al., 2021), Risk Extrapolation (V-REx) (Krueger et al., 2021) argues that training risks from different domains should be similar and thus penalizes $|R_A - R_B|^2$ when $E = \{A, B\}$. These ideas have been applied in semi-supervised learning (Li et al., 2021). Yet, recent works point out pitfalls of IRM (Javed et al., 2020; Guo et al., 2021; Kamath et al., 2021), which does not provably work with non-linear data (Rosenfeld et al., 2021) and could not improve over ERM when hyperparameter selection is restricted (Koh et al., 2020; Gulrajani & Lopez-Paz, 2021; Ye et al., 2021b).
A recent line of work promotes agreement across domains between gradients with respect to θ. This strategy helps batches from different tasks to cooperate, and was employed for multi-task (Du et al., 2018; Yu et al., 2020), continual (Lopez-Paz & Ranzato, 2017), meta (Finn et al., 2017; Zhang et al., 2020) and reinforcement learning. In OOD generalization, several works (Koyama & Yamaguchi, 2020; Parascandolo et al., 2021; Shi et al., 2021) try to find minima of the loss shared across domains. Specifically, these works tackle the domain-level expected gradients $g_e$ (the gradient of the loss averaged over the samples of domain $e$): when $E = \{A, B\}$, IGA (Koyama & Yamaguchi, 2020) minimizes $\|g_A - g_B\|_2^2$; Fish (Shi et al., 2021) increases $g_A \cdot g_B$; AND-mask (Parascandolo et al., 2021) and others (Mansilla et al., 2021; Shahtalebi et al., 2021) update weights only when $g_A$ and $g_B$ point in the same direction. The main limitation of these gradient-based methods is the per-domain batch averaging of gradients, which removes more granular statistics; in particular, this averaging removes the information from pairwise interactions between gradients of samples within the same domain. In contrast, our new regularization for OOD generalization keeps extra information from individual gradients and matches across domains the domain-level gradient variances. In a nutshell, our work is to IGA (Koyama & Yamaguchi, 2020) and Fish (Shi et al., 2021) in gradient space what CORAL (Sun et al., 2016; Sun & Saenko, 2016) is to DDC (Tzeng et al., 2014) in feature space.

The individual gradient $g_e^i = \nabla_\theta \ell\big(f_\theta(x_e^i), y_e^i\big)$ is the first-order derivative of the loss for the $i$-th data example $(x_e^i, y_e^i)$ from domain $e \in E$ with respect to the weights θ. Previous methods have matched the gradient means $g_e = \frac{1}{n_e} \sum_{i=1}^{n_e} g_e^i$, of size $|\theta|$, for each domain $e \in E$, usually computed for gradient descent during the learning procedure. Leveraging the full matrix $G_e = [g_e^i]_{i=1}^{n_e}$ of size $n_e \times |\theta|$, we compute the gradient covariance matrix $C_e$ of size $|\theta| \times |\theta|$:

$$C_e = \frac{1}{n_e} \sum_{i=1}^{n_e} \big(g_e^i - g_e\big)\big(g_e^i - g_e\big)^\top. \quad (2)$$

To reduce the distribution shifts in the network $f_\theta$ across domains, we bring the domain-level gradient covariances closer. Hence, our Fishr regularization is:

$$\mathcal{L}_{\text{Fishr}}(\theta) = \frac{1}{|E|} \sum_{e \in E} \big\| C_e - C \big\|_F^2,$$

the square of the Frobenius distance between the covariance matrices from the different domains $e \in E$ and the mean covariance matrix $C = \frac{1}{|E|} \sum_{e \in E} C_e$. Balanced with a hyperparameter coefficient $\lambda > 0$, this Fishr penalty complements the original ERM objective, i.e., the empirical training risks:

$$\mathcal{L}(\theta) = \frac{1}{n} \sum_{e \in E} \sum_{i=1}^{n_e} \ell\big(f_\theta(x_e^i), y_e^i\big) + \lambda\, \mathcal{L}_{\text{Fishr}}(\theta).$$

In practice, $\mathcal{L}$ is minimized via SGD with $|E|$ batches simultaneously -- one from each domain.

Motivations We choose to tackle gradient covariances for two main reasons. The first reason is that this strategy aligns gradient distributions across domains -- as discussed and motivated in Section 3.2. Second and most importantly, our approach is driven by the links between the gradient covariance, the Fisher Information and the Hessian: we show in Section 3.3 that Fishr aligns the domain-level Hessians and the domain-level loss landscapes at convergence. The first motivation for Fishr is the independence between the gradient and the domain random variables. This is achieved by matching the covariance of the empirical gradient distributions $\{g_e^i\}_{i=1}^{n_e}$ across training domains $e \in E$. Indeed, covariance is an efficient and well-suited statistic to align distributions. This was recently highlighted by the success of the covariance-based CORAL (Sun et al., 2016) on the DomainBed benchmark: matching covariances performed better than adversarial methods to align feature distributions.
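To make the regularization concrete, here is a minimal PyTorch sketch of the full-covariance penalty (Eq. 2 and the Fishr loss above), assuming the per-sample gradient matrices $G_e$ (one row per example) have already been collected; the function names and the toy numbers are ours, not the released implementation.

```python
import torch

def domain_covariance(G: torch.Tensor) -> torch.Tensor:
    """Centered gradient covariance C_e (Eq. 2) from the n_e x p matrix G_e."""
    G_centered = G - G.mean(dim=0, keepdim=True)
    return G_centered.T @ G_centered / G.shape[0]

def fishr_penalty(per_domain_grads: list) -> torch.Tensor:
    """Average squared Frobenius distance between each C_e and their mean C."""
    covs = [domain_covariance(G) for G in per_domain_grads]
    mean_cov = torch.stack(covs).mean(dim=0)
    return torch.stack([((C - mean_cov) ** 2).sum() for C in covs]).mean()

# Toy usage: two domains, 5 parameters, per-sample gradients given as rows.
G_A = torch.randn(32, 5)        # 32 examples from domain A
G_B = torch.randn(64, 5)        # 64 examples from domain B
erm_risk = torch.tensor(0.7)    # placeholder for the averaged training risk
lam = 100.0                     # regularization strength lambda
total_loss = erm_risk + lam * fishr_penalty([G_A, G_B])
```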
The success of CORAL motivates our use of gradient covariances; more complex strategies to align gradient distributions are left for future work. We now provide four -- perhaps intuitive -- reasons to enforce distributional invariance in gradients rather than in features. First and foremost, having similar domain-level gradient distributions is critical so that the DNN has shared properties across domains. Indeed, gradient disagreements and pairwise relations are key to the optimization procedure of DNNs: for instance, gradient confusion slows down convergence (Sankararaman et al., 2020) even though gradient diversity improves generalization (Yin et al., 2018). As a side note, the gradient mean can capture the average learning direction but cannot capture these refined statistics. Second, gradients are more expressive and richer than features. Specifically, gradients were shown to better cluster semantically close inputs (Fort et al., 2019; He & Su, 2020). When comparing the features extracted for two inputs (Johnson et al., 2016), a small difference in activation may be multiplied by large subsequent weights and lead to distant predictions. On the contrary, when comparing gradients, each activation is weighted by its true importance for the prediction (Charpiat et al., 2019). Third, gradients take into account the label $Y$, which appears as an argument of the loss $\ell$. Hence, gradient-based approaches are 'label-aware' by design. In contrast, seminal feature-based methods were shown to fail in case of label shift, i.e., when $P(Y|E) \neq P(Y)$, because they do not consider $Y$ (Johansson et al., 2019; Zhao et al., 2019). Lastly, matching gradient distributions also matches training risks, as motivated in V-REx (Krueger et al., 2021) for OOD generalization. Indeed, gradient amplitudes are directly weighted by the loss values. This is further justified and empirically validated in Appendix B.2.1, where we show that Fishr induces $|R_A - R_B|^2 \to 0$ when $E = \{A, B\}$. Nonetheless, it is worth noting that the main drawback of using gradients is the computational overhead: this will be tackled in Section 4.

In Section 3.3.1, we argue (based on previous empirical works) and show (in Table 1) that Fishr matches the domain-level Hessians. Then, we justify in Section 3.3.2 why having similar Hessians across domains reduces inconsistencies in the loss landscape and improves generalization. The Hessian matrix $H = \sum_{i=1}^{n} \nabla^2_\theta \ell\big(f_\theta(x^i), y^i\big)$ captures the second-order derivatives of the objective and is of key importance for deep learning methods (Ghorbani et al., 2019). Yet, $H$ is computationally demanding and cannot be tackled directly in general. Recent methods (Izmailov et al., 2018; Foret et al., 2021) tackle the Hessian indirectly by modifying the learning procedure. In contrast, Fishr uses $C$ to regularize $H$: the gradient covariance, computable efficiently with a unique backpropagation, serves as a proxy for the Hessian. Said differently, we use the second moment of the first-order derivatives to regularize the second-order derivatives. This approach is based on the close relations between the gradient covariance $C$ and the empirical Fisher Information Matrix $\tilde{F} = \sum_{i=1}^{n} \nabla_\theta \log p_\theta(y^i|x^i)\, \nabla_\theta \log p_\theta(y^i|x^i)^\top$, where $p_\theta(\cdot|x)$ is the density predicted by $f_\theta$ on input $x$ (and $P_\theta(\cdot|X)$ the model distribution). Indeed, when $\ell$ is the negative log-likelihood, $\tilde{F}$ is the unnormalized uncentered covariance matrix of the gradients. Thus, $C$ and $\tilde{F}$ are equivalent (up to the multiplicative constant $n$) at any first-order stationary point: in short, $C$ and $\tilde{F}$ are approximately proportional.
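This near-equivalence rests on a simple identity -- the uncentered second moment of the gradients equals the centered covariance plus the outer product of the mean gradient -- which can be checked numerically in a few lines (our own toy sketch, not from the paper's code):

```python
import torch

torch.manual_seed(0)
n, p = 256, 4
G = torch.randn(n, p) * 0.1                     # per-sample gradients, one row each
g_mean = G.mean(dim=0)

second_moment = G.T @ G / n                     # ~ F_tilde / n (uncentered)
covariance = (G - g_mean).T @ (G - g_mean) / n  # C (centered)

# Exact identity: second moment = C + outer(mean gradient, mean gradient).
gap = second_moment - covariance
assert torch.allclose(gap, torch.outer(g_mean, g_mean), atol=1e-6)
# At a first-order stationary point the mean gradient vanishes, so the gap
# vanishes too and C coincides with F_tilde up to the constant n.
```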
Moreover, $\tilde{F}$ is commonly used to estimate $H$: e.g., in compression (Frantar et al., 2021; Liu et al., 2021) and optimization (Dangel et al., 2021) tasks. Indeed, $\tilde{F}$ and $H$ share the same structure and are similar up to a scalar factor (Thomas et al., 2020). This similarity has been discussed in prior work and highlighted, even at early stages of training (before overfitting), in Fig. 1 and Appendix S3 of Singh & Alistarh (2020). Overall, this suggests that $C$ and $H$ are closely related. In our multi-domain framework, we define the domain-level matrices with the subscript $e$. Table 1 empirically confirms that matching $\{C_e\}_{e \in E}$ with Fishr forces the domain-level Hessians $\{H_e\}_{e \in E}$ to be aligned at convergence. This will be further validated in Fig. 3 and in Appendix B.2.2, on the matrices' diagonals in the classifier layer for computational reasons (see Section 4). Even so, we acknowledge that approximating $H$ by $\tilde{F}$ is not fully justified theoretically. Only the 'true' Fisher Information Matrix $F$ (Amari, 1998) approximates $H$ with provably bounded errors under mild assumptions (see Theorem 1 in Kunstner et al. (2019) or Appendix 1 in Jastrzebski et al. (2018)). The difference between $F$ and $\tilde{F}$ is that $F$ uses the model distribution $P_\theta(\cdot|X)$ whereas $\tilde{F}$ uses the data distribution $P(Y|X)$. One would hope that when overfitting occurs, $P_\theta(\cdot|X) \to P(Y|X)$ and then $F \to \tilde{F}$. Though, this requires strong assumptions such as $\chi^2$ convergence (see Proposition 1 in Thomas et al. (2020)). Thus, despite its name, $\tilde{F}$ is not a theoretically justified estimation of $F$ (Kunstner et al., 2019). Yet, using the domain-level $\{F_e\}_{e \in E}$ ignores the label $Y$ and demands additional costly backpropagations (one per class): this is best left for future work. In this paper, we trade off theoretical guarantees for computational efficiency and consider $C$ and $\tilde{F}$. Notably, Appendix B.2.5 shows that matching $\{C_e\}_{e \in E}$ or $\{\tilde{F}_e\}_{e \in E}$ -- i.e., centering or not the matrices -- perform similarly.

Figure 2: Loss landscapes around an inconsistent $\theta^*$ at convergence. This inconsistency is due to conflicting domain-level loss landscapes, and is visible in the disagreements across the domain-level covariances.

To understand why having similar domain-level Hessians at convergence improves generalization, we leverage the inconsistency formalism developed in AND-mask (Parascandolo et al., 2021). They argue that "patchwork solutions sewing together different strategies" for different domains may not generalize well: good weights should be optimal on all domains and "hard to vary" (Deutsch, 2011). They formalize this intuition with an inconsistency score: roughly, it measures how much the risk of a domain $B$ can change over the weights $\theta$ reachable from $\theta^*$ along a path in weight space on which the risk $R_A$ of another domain $A$ remains in an $\epsilon > 0$ interval around $R_A(\theta^*)$. This inconsistency increases with conflicting geometries in the loss landscapes around $\theta^*$, as in Fig. 2: i.e., when another 'close' solution $\theta$ is equivalent to the current solution $\theta^*$ in a domain $A$ but yields different losses in $B$. Moreover, the Hessian approximates the local curvature of the loss landscape around $\theta^*$. Assuming for simplicity that the domain-level Hessians are co-diagonalizable, inconsistency increases when an eigenvalue is small in $H_A$ but large in $H_B$. Indeed, as shown in Appendix A.1 with a simple second-order Taylor expansion, a small weight perturbation in the direction of the associated eigenvector would change the loss slightly in domain $A$ but drastically in $B$. In conclusion, "inconsistency is lowest when shapes [...] are similar" (Parascandolo et al., 2021), i.e., when $H_A = H_B$.
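The eigenvalue argument can be made tangible with a toy example (ours, purely illustrative): two quadratic domain risks share the minimum $\theta^* = 0$ but have different diagonal Hessians, and a perturbation along the direction where the eigenvalues disagree leaves domain A almost unchanged while drastically changing domain B.

```python
import torch

# Two quadratic domain risks R_e(theta) = 0.5 * theta^T H_e theta, both minimized at theta* = 0.
H_A = torch.diag(torch.tensor([10.0, 0.01]))  # sharp along dim 0, almost flat along dim 1
H_B = torch.diag(torch.tensor([10.0, 5.0]))   # sharp along both dimensions

def risk(H, theta):
    return 0.5 * theta @ H @ theta

# Perturb theta* along the eigenvector where the eigenvalues disagree (dim 1).
delta = torch.tensor([0.0, 1.0])
print(risk(H_A, delta))  # 0.005 -> domain A barely changes: theta* + delta looks equivalent
print(risk(H_B, delta))  # 2.5   -> domain B changes drastically: the solution is inconsistent
# Matching H_A and H_B -- what Fishr promotes through the gradient covariances -- removes this gap.
```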
AND-mask minimizes the inconsistency by zeroing out gradients with inconsistent directions across domains. However, this masking strategy introduces dead zones (Shahtalebi et al., 2021) in the weights where the model could get stuck, ignores gradient magnitudes and empirically performs poorly on the 'real' datasets from DomainBed. Instead, Fishr reduces inconsistencies in the loss landscapes across domains by matching the domain-level Hessians, using the domain-level gradient covariances as a proxy. Finally, we refer the reader to Appendix A.2, where we leverage the Neural Tangent Kernel (NTK) (Jacot et al., 2018) theory to further motivate the gradient covariance matching during the optimization process -- and not only at convergence. In brief, as $F$ and the NTK matrices share the same non-zero eigenvalues, similar $\{C_e\}_{e \in E}$ during training reduce the simplicity bias by preventing the learning of different domain-dependent shortcuts at different training speeds: this favors a shared mechanism that predicts the same thing for the same reasons across domains.

We now describe three implementation choices that make Fishr easily applicable at almost the same computational cost as ERM. First, to reduce the memory overhead, we only consider the gradients with respect to the weights $\omega$ of the final linear classifier $w_\omega$. It is worth noting that these gradients in $\omega$ depend on $\Phi_\phi$. We ignore the gradients in the feature extractor $\Phi_\phi$ because low-level layers may need to adapt to domain-dependent peculiarities: on the contrary, invariance in high-level features is more critical. This intuition -- strengthened in Appendix C.2.2 -- is classical in unsupervised domain adaptation (Ganin et al., 2016) and for computing influence functions (Ye et al., 2021a; Pezeshkpour et al., 2021). This is also consistent with concurrent approaches in OOD generalization that seek invariant predictors (Arjovsky et al., 2019; Krueger et al., 2021). Moreover, if we used all weights $\theta = \omega \oplus \phi$ to compute $C$, the invariance in $w_\omega$ would get overshadowed by $\Phi_\phi$ due to $|\omega| \ll |\phi|$. Second, we only use the gradient variance -- i.e., the diagonal of the covariance matrix $\mathrm{Diag}(C)$ -- to scale down the number of targeted components from $|\omega|^2$ to $|\omega|$. This is similar to AND-mask (Parascandolo et al., 2021) and to pruning methods (Lecun et al., 1990; Theis et al., 2018). Since we highlighted empirical similarities between $H$ and $C$, we ignore the off-diagonal parts of $C$. This is also motivated by the critical importance of $\mathrm{Tr}(H)$ (Yao et al., 2020) and $\mathrm{Tr}(F)$ (Jastrzebski et al., 2021) for analyzing the generalization properties of DNNs. In Appendix B.2.4, we empirically observe that ignoring or keeping the off-diagonal parts performs similarly on Colored MNIST. Future work could consider refined interactions between weights using K-FAC approximations (Heskes, 2000; Martens & Grosse, 2015). In summary, we match the variances of gradients for each weight in the classifier:

$$\mathcal{L}(\theta) = \frac{1}{n} \sum_{e \in E} \sum_{i=1}^{n_e} \ell\big(f_\theta(x_e^i), y_e^i\big) + \lambda \frac{1}{|E|} \sum_{e \in E} \sum_{\pi \in \omega} \big(v_e^\pi - v^\pi\big)^2, \quad (4)$$

where the scalar $v^\pi = \frac{1}{|E|} \sum_{e \in E} v_e^\pi$ averages the domain-level variances per classifier weight $\pi$, and $v_e^\pi$ -- the $\pi$-th diagonal entry of $C_e$ -- is the variance over the samples of domain $e$ of the gradients of the loss with respect to $\pi$. Third, PyTorch (Paszke et al., 2019) is optimized only to compute the batch-averaged gradients. Fortunately, the recent BackPACK (Dangel et al., 2020) package keeps additional information from first- and second-order derivatives. Notably, it supports efficient computation of individual gradients, sample per sample, from a batch at almost no time overhead. Moreover, the DiagHessian method in BackPACK computes the Hessian diagonals, and is used for analysis in Table 1 and Fig. 3: yet, the "Hessian is an order of magnitude more computationally intensive" than individual gradients (see Fig. 9 in Dangel et al. (2020)) and is not yet practical directly as a training objective. In conclusion, we first compute the individual derivatives of ERM in $w_\omega$ using BackPACK to get $\{\mathrm{Diag}(C_e)\}_{e \in E}$, and second backpropagate the loss from Eq. 4 in the whole network (without keeping the individual gradients). Fishr is simple to implement (see the pseudo-code in Appendix D) and incurs low computational costs. For example, on the PACS dataset (7 classes and $|\omega| = 14{,}343$) with a ResNet-50 and batch size 32, Fishr induces an overhead of +0.2% in memory and +2.7% in training time (with a Tesla V100) compared to ERM; on the larger-scale DomainNet (345 classes and $|\omega| = 706{,}905$), the overhead is +7.0% in memory and +6.5% in training time. As a side note, keeping the full covariance of size $|\omega|^2 \approx 5 \times 10^{8}$ on DomainNet would not have been possible.

We first tackle Colored MNIST in the IRM (Arjovsky et al., 2019) setup and then the DomainBed benchmark (Gulrajani & Lopez-Paz, 2021). To facilitate reproducibility, the code is available at https://github.com/alexrame/fishr. The task in Colored MNIST is to predict whether the digit is below or above 5. Moreover, the labels are flipped with 25% probability (except in Appendix B.2.3). Critically, the digits' colors spuriously correlate with the labels: the correlation strength varies across the two training domains $E = \{90\%, 80\%\}$. To test whether the model has learned to ignore the color, this correlation is reversed at test time. In brief, a biased model that only considers the color would have 10% test accuracy, whereas an oracle model that perfectly predicts the shape would have 75%. As previously done in V-REx (Krueger et al., 2021), we strictly follow the IRM implementation and just replace the IRM penalty by our Fishr penalty. This means that we use the exact same MLP and hyperparameters, notably the same two-stage scheduling selected in IRM for the regularization strength λ, which is low until epoch 190 and then jumps to a large value: more details in Appendix B.1. Table 2 reports the Top-1 classification accuracy averaged over 10 runs with standard deviation. In test, Fishr reaches 69.5%, and 70.2% when digits are grayscale. This highlights Fishr's effectiveness even without approach-dependent hyperparameter tuning, but should not be considered as a proof of Fishr's superiority over other approaches, precisely because of the absence of hyperparameter search. (Figure 3, partial caption: [...] but significantly increases test accuracy (blue) as the network learns to predict the digit's shape.) The main advantage of this synthetic dataset is the possibility of empirically validating some theoretical insights. Notably, the training dynamics in Fig. 3 show that the domain-level Hessians get closer once the Fishr gradient variance matching loss is activated after step 190. Consequently, this sharply increases test accuracy. This confirms insights from Section 3.3.1. Additional experiments can be found in Appendix B.2. Yet, the main drawback of Colored MNIST is that it is insufficient to ensure generalization on real-world datasets. Overall, it should be considered as a first proof of concept. Fishr relies on three hyperparameters. First, the λ coefficient controls the regularization strength: with λ = 0 we recover ERM, while a high λ may cause underfitting. Second, the warmup iteration defines the step at which we activate the regularization. This warmup strategy is taken from previous works such as IRM (Arjovsky et al., 2019), V-REx (Krueger et al., 2021) or Spectral Decoupling (Pezeshki et al., 2020).
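As a rough illustration of this scalable variant, the sketch below implements the diagonal variance matching of Eq. 4 together with the warmup scheduling. Instead of BackPACK, it uses the closed form of per-sample cross-entropy gradients for a linear classifier so that it stays self-contained; names such as `diag_variance_penalty`, as well as the default `lam` and `warmup` values, are ours rather than the released code's.

```python
import torch
import torch.nn.functional as F

def per_sample_classifier_grads(features, labels, classifier):
    """Per-sample gradients of the cross-entropy w.r.t. the last linear layer.
    For logits z = W h + b: d loss_i / dW = (softmax(z_i) - onehot(y_i)) h_i^T."""
    logits = features @ classifier.weight.T + classifier.bias
    probs = F.softmax(logits, dim=1)
    err = probs - F.one_hot(labels, probs.shape[1]).float()   # (n, num_classes)
    grad_w = err.unsqueeze(2) * features.unsqueeze(1)         # (n, classes, feat_dim)
    return torch.cat([grad_w.flatten(1), err], dim=1)         # (n, |omega|)

def diag_variance_penalty(grads_per_domain):
    """Match the per-weight gradient variances across domains (diagonal of C_e)."""
    variances = torch.stack([G.var(dim=0, unbiased=False) for G in grads_per_domain])
    mean_var = variances.mean(dim=0)
    return ((variances - mean_var) ** 2).sum(dim=1).mean()

def training_step(model, classifier, batches, optimizer, step, lam=1000.0, warmup=1500):
    """One step over |E| batches, one per domain; the penalty is added after warmup."""
    risks, grads = [], []
    for x, y in batches:
        h = model(x)                         # features from the extractor Phi
        risks.append(F.cross_entropy(classifier(h), y))
        grads.append(per_sample_classifier_grads(h, y, classifier))
    loss = torch.stack(risks).mean()
    if step >= warmup:
        loss = loss + lam * diag_variance_penalty(grads)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Toy usage with random data (hypothetical shapes): 2 domains, 7 classes.
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU())
classifier = torch.nn.Linear(16, 7)
optimizer = torch.optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=1e-3)
batches = [(torch.randn(8, 32), torch.randint(0, 7, (8,))) for _ in range(2)]
print(training_step(model, classifier, batches, optimizer, step=2000))
```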
Before the warmup step, the DNN is trained with ERM to learn predictive features. After that step, the Fishr regularization encourages the DNN to have invariant gradient variances. Lastly, the domain-level gradient variances are more accurate when estimated over more data points. Rather than increasing the batch size, we follow Le Roux et al. (2011) and leverage an exponential moving average to compute stable gradient variances. Therefore our third hyperparameter is the coefficient γ that controls the update speed: at step $t$, we match the diagonals of $\bar{C}_e^t = \gamma \bar{C}_e^{t-1} + (1-\gamma) C_e^t$ rather than those of $C_e^t$ from Eq. 2. The closer γ is to 1, the smoother the variance is along training. $\bar{C}_e^{t-1}$ from the previous step $t-1$ is 'detached' from the computational graph. Similar strategies have already been used for OOD generalization in Nam et al. (2020); Blanchard et al. (2021); Zhang et al. (2021). The memory overhead is $|E| \times |\omega|$ and was already taken into account in the computational comparisons at the end of Section 4. Finally, we study by ablation the importance of this warmup strategy and of γ in Appendices C.2.1 and C.2.2. To limit access to the test domain, the framework enforces that all methods are trained with only 20 different configurations of hyperparameters and for the same number of steps, without early stopping. Results are averaged over three trials. The DomainBed experimental setup is further described in Appendix C.1; the hyperparameter distributions are analyzed in Appendix C.2.3; results are detailed per dataset in Appendix C.3. In the tables, we highlight the first and second best results. As performances depend heavily on the hyperparameter choice, the model selection strategy is critical. Table 3 summarizes the results on DomainBed using the 'Oracle' model selection: the validation set follows the same distribution as the test domain. ERM remains a strong baseline and all previous methods are far from the best score on at least one dataset. Moreover, 'invariant predictors' and 'gradient masking' approaches perform poorly on 'real' datasets. Contrarily, Fishr is the only method to systematically perform better than ERM on all 'real' datasets: the differences are over standard errors on VLCS. In Table 4, the validation set is formed by collecting 20% of each training domain. With this 'Training' model selection, Fishr performs better than ERM on all 'real' datasets (over standard errors for OfficeHome and DomainNet), except for PACS where the two reach 85.5%. On average, Fishr (67.1%) finishes third and is above most methods such as V-REx (65.6%).

Limitations Although Fishr remains stronger than ERM in the 'Training' setup, the improvements are smaller than in 'Oracle'. Indeed, besides the arguably low number of hyperparameter trials (20), the 'Training' setup suffers from underspecification: "predictors with equivalently strong held-out performance in the training domain [...] can behave very differently" in test (D'Amour et al., 2020). To reduce underspecification, future benchmarks may consider the training calibration (Wald et al., 2021) during the model selection rather than relying only on the training accuracy.

In this paper, we addressed the task of out-of-distribution generalization for classification in computer vision. Motivated by the empirical success of CORAL and the inconsistency formalism from Parascandolo et al. (2021), we derive a new and simple regularization -- Fishr -- that matches the gradient variances across domains as a proxy for matching domain-level Hessians.
Fishr reaches state-of-the-art performance on DomainBed when samples from the test domain are available for model selection. Our empirical experiments suggest that Fishr would consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. More generally, the criterion of domain invariance in gradients opens up new perspectives: for example, future work could consider adversarial strategies (Goodfellow et al., 2014) to align gradient distributions. This work was granted access to the HPC resources of IDRIS under the allocation A0100612449 made by GENCI. We acknowledge the financial support by the ANR agency within the chair VISA-DEEP (ANR-20-CHIA-0022-01). These appendices follow a similar order as their related sections in the main paper.

The second-order Taylor expansion of $R_e$ around $\theta^* = 0$ (with a change of variable) gives, for $e \in \{A, B\}$: $R_e(\theta) \approx R_e(\theta^*) + \theta^\top \nabla_\theta R_e(\theta^*) + \frac{1}{2}\, \theta^\top H_e\, \theta$. We assume simultaneous convergence, i.e., $\theta^*$ is a local minimum across all training domains, so the first-order terms vanish. The variation of the risk in domain $B$ around $\theta^*$ can then be bounded, via the triangle inequality, by the sum of two terms: the risk gap at convergence $|R_B(\theta^*) - R_A(\theta^*)|$ and a quadratic term driven by the Hessian difference, of the form $\frac{1}{2}\,|\theta^\top (H_B - H_A)\, \theta|$. The first term $|R_B(\theta^*) - R_A(\theta^*)|$ was simply assumed small at convergence in AND-mask (Parascandolo et al., 2021). We further justify this approximation for Fishr by recalling that the empirical risk difference across domains is the V-REx (Krueger et al., 2021) criterion: thus, as argued in Section 3.2 and as shown in Appendix B.2.1, Fishr forces this first term to be low at convergence. Following Parascandolo et al. (2021), the second term is more easily understood when the Hessians are diagonal: $H_e = \mathrm{diag}(\lambda_1^e, \cdots, \lambda_{|\theta|}^e)$ with $\lambda_i^e > 0$ for all $i \in \{1, \ldots, |\theta|\}$. In this case, the maximum of the quadratic term decreases when $H_A$ and $H_B$ have similar eigenvalues. In conclusion, Fishr reduces inconsistency by matching (1) domain-level empirical risks and (2) domain-level Hessians across the training domains.

In this section we motivate the matching of gradient covariances with new arguments from the Neural Tangent Kernel (NTK) (Jacot et al., 2018) theory. As a reminder, the NTK $K \in \mathbb{R}^{n \times n}$ is the Gram matrix with entries $K[i, j] = \nabla_\theta f_\theta(x^i)^\top \nabla_\theta f_\theta(x^j)$ that measure the gradient similarity at two different input points $x^i$ and $x^j$. This kernel dictates the training dynamics of the DNN and remains fixed in the infinite width limit. Most importantly, as stated in Yang & Salman (2019), "the simplicity bias of a wide neural network can be read off quickly from the spectrum of K: if the largest eigenvalue [$\lambda_{\max}$] of K accounts for most of Tr(K), then a typical random network looks like a function from the top eigenspace of K": this holds for ReLU networks. In summary, gradient descent mostly happens in a tiny subspace (Gur-Ari et al., 2018) whose directions are defined by the main eigenvectors of $K$. Moreover, the learning speed is dictated by $\lambda_{\max}$, which can be used to estimate a condition for a learning rate η to converge: $\eta < 2/\lambda_{\max}$ (Karakida et al., 2019). In a multi-domain framework, having similar spectral decompositions across $\{K_e\}_{e \in E}$ during the optimization process would improve OOD generalization for two reasons. Directly matching $K_e$ would require assuming that each domain coincides and contains the same samples; for example, with different pose angles (Ghifary et al., 2015). To avoid such a strong assumption, we leverage the fact that the 'true' Fisher Information Matrix $F$ and the NTK $K$ share the same non-zero eigenvalues, since $F$ is dual to $K$.

Compared to the original MNIST, Colored MNIST has two main differences.
First, the 0-4 and 5-9 digits are each collapsed into a single class, with a 25% chance of label flipping. Second, digits are either colored red or green, with a strong correlation between label and color in training. However, this correlation is reversed at test time. Specifically, in training, the model has access to two domains $E = \{90\%, 80\%\}$: in the first domain, green digits have a 90% chance of being in 5-9; in the second, this chance goes down to 80%. In test, green digits have a 10% chance of being in 5-9. Due to this modification in correlation, a model should ideally ignore the color information and only rely on the digits' shape: this would obtain a 75% test accuracy. In the experimental setup from IRM, the network is a 3-layer MLP with ReLU activations, optimized with Adam (Kingma & Ba, 2014). IRM selected the following hyperparameters by random search over 50 trials: hidden dimension of 390, $\ell_2$ regularizer weight of 0.00110794568, learning rate of 0.0004898536566546834, penalty anneal iters (or warmup iter) of 190, penalty weight (λ) of 91257.18613115903, 501 epochs and batch size 25,000 (half of the dataset size). We strictly keep the same hyperparameter values in our proof of concept in Section 5.1. Our code is almost unchanged from https://github.com/facebookresearch/InvariantRiskMinimization.

Figure 4: Risk dynamics on Colored MNIST with Fishr. At epoch 190, λ steps up and then the domain-level empirical risks $R_{90\%}$ and $R_{80\%}$ get closer.

We argue in Section 3.2 that gradient amplitudes are directly related to the loss values. Indeed, the constant multiplier rule states that multiplying the loss by a constant will also multiply the gradients by the same constant. Thus, forcing gradients to be similar should bring the domain-level empirical training risks closer. Fig. 4 empirically confirms this behavior.

Building on the relations with the empirical Fisher Information and the Hessian (Thomas et al., 2020), we argue in Section 3.3.1 that the gradient covariance $C$ can be used as a proxy to regularize the Hessian $H$ -- even though the proper approximation bounds are out of the scope of this paper. This was empirically validated at convergence in Table 1 and during training in Fig. 3. For computational reasons, we compute the Hessian diagonals with the DiagHessian method from BackPACK in the classifier. This appendix further analyzes the Hessian during training. Fig. 5 illustrates the dynamics for Fishr: following the scheduling previously described in Appendix B.1, λ jumping to a high value at epoch 190 activates the regularization. After this epoch, the domain-level Hessians are not only close in Frobenius distance, but also have similar norms and directions. On the contrary, when using only ERM in Fig. 6, the distance between domain-level Hessians keeps increasing with the number of epochs. As a side note, flatter loss landscapes in ERM -- as reflected by the Hessian norms in orange -- do not correlate with improved generalization (Dinh et al., 2017).

To further validate that Fishr can tackle distribution shifts, we investigate Colored MNIST without the 25% label flipping. In Table 5, the label is then fully predictable from the digit shape. Using the hyperparameters defined previously in Appendix B.1, IRM (82.2%) performs worse than ERM (91.8%) while V-REx and Fishr perform better (95.3%): Fishr works even without label noise.

We explained in Section 4 that we ignore the off-diagonal parts of the covariance to reduce the memory overhead. For the sake of completeness, the second line in Table 6 shows results with the full covariance matrix.
Overall, the results are similar to (or slightly worse than) those obtained when using only the diagonal: the slight difference may be explained by the approaches' different suitability to the hyperparameters (which were optimized for IRM). In conclusion, this preliminary experiment suggests that targeting the diagonal components is the most critical. We hope future works will further investigate this diagonal approximation or provide new methods to reduce the computational costs. In Section 3.3.1, we argue that the gradient centered covariance $C$ and the empirical Fisher Information Matrix (or uncentered covariance) $\tilde{F}$ are highly related and equivalent when the DNN is at convergence and the gradient means are zero. Thus, we could have matched the diagonals of the domain-level $\{\tilde{F}_e\}_{e \in E}$ across domains, i.e., not centered the variance. Empirically, comparing the first and third lines in Table 6 shows that centering or not centering the variance is almost equivalent.

We now further detail our experiments on the DomainBed benchmark. Scores for most baselines are taken from the DomainBed (Gulrajani & Lopez-Paz, 2021) github, at commit 0x7df6f06. Scores for AND-mask and SAND-mask are taken from the SAND-mask paper (Shahtalebi et al., 2021). For Fish (Shi et al., 2021), averaged 'Training' scores are taken from the arXiv paper and averaged 'Oracle' scores are from direct messages with the authors: however, the per-dataset results are not available. Scores for IGA (Koyama & Yamaguchi, 2020) are not yet available and are very computationally expensive to obtain: yet, for the sake of completeness, we analyze IGA in Appendix C.2.2. Missing scores will be included when available. For a fair comparison, we have included Fishr as a new algorithm in the DomainBed benchmark https://github.com/facebookresearch/DomainBed. Thus, the same procedure was applied for all methods. In brief, for each domain, a random hyperparameter search of 20 trials over a joint distribution, described in Table 7, is performed. Note that we discuss the choice of these distributions in Appendix C.2.3. The learning rate, the batch size (except for ARM), the weight decay and the dropout distributions are shared across all methods -- all trained with Adam (Kingma & Ba, 2014). All hyperparameter distributions for all methods can be found at https://github.com/facebookresearch/DomainBed/blob/master/domainbed/hparams_registry.py. The data from each domain is split into 80% (used for training and testing) and 20% (used as validation for hyperparameter selection) splits. This random process is repeated with 3 different seeds: the reported numbers are the means and the standard errors over these 3 seeds. We focus on the two 'Oracle' and 'Training' model selection methods and have not run the 'Leave-one-domain-out cross-validation' for computational reasons.

We clarify a subtle point concerning the hyperparameter γ that controls the update speed of the covariance: $\bar{C}_e^t = \gamma \bar{C}_e^{t-1} + (1-\gamma) C_e^t$ at step $t$. We remind that $\bar{C}_e^{t-1}$ from the previous step $t-1$ is 'detached' from the computational graph. Thus, when $\mathcal{L}$ from Eq. 4 is differentiated during SGD, the gradients going through $C_e^t$ are multiplied by $(1-\gamma)$. To compensate for this and decorrelate the impact of γ and of λ (which controls the regularization strength), we match $\frac{1}{1-\gamma} \bar{C}_e^t$. Finally, with this $(1-\gamma)$ correction, the strength of the gradients backpropagated in the network is independent of γ. Here we list all concurrent approaches.
• Fish: Gradient Matching for Domain Generalization (Shi et al., 2021)

Neural network architectures used for each dataset are shown in Table 8a. Table 8b describes the convolutional neural network architecture used for the MNIST experiments: note that this is not the same MLP (described in Appendix B.1) as in our proof of concept in Section 5.1. The 'ResNet-50' network is pretrained on ImageNet, has a dropout layer before the newly added dense layer and is fine-tuned on the new datasets with frozen batch normalization layers. Table 8b (partial) lists the MNIST ConvNet layers: GroupNorm (8 groups), Conv2D (in=64, out=128, stride=2), ReLU, GroupNorm (8 groups), Conv2D (in=128, out=128), ReLU, GroupNorm (8 groups), [...].

Following Le Roux et al. (2011), we use an exponential moving average (ema) parameterized by γ to compute the gradient variances in DomainBed: the closer γ is to 1, the longer a batch will impact the variance at later steps. We now further analyze the impact of this strategy, which is not specific to Fishr and was used previously in other works (Nam et al., 2020; Blanchard et al., 2021; Zhang et al., 2021) for OOD generalization. Notably, this ema strategy could be applied to better estimate domain-level empirical risks in V-REx (Krueger et al., 2021). For a fair comparison, we introduce a new approach -- V-REx with ema -- that penalizes $|\bar{R}_A^t - \bar{R}_B^t|^2$ at step $t$, where $\bar{R}_e^t = \gamma \bar{R}_e^{t-1} + (1-\gamma) R_e^t$ when $E = \{A, B\}$. Thus, we compare V-REx and Fishr, with γ = 0 or with γ ∼ Uniform(0.9, 0.99) (as described in Table 7). On the synthetic Colored MNIST in Table 9, the ema is critical for Fishr -- notably when training on $E = \{90\%, 80\%\}$ and the dataset 10% is in test (from 34.0% to 58.9% in 'Oracle'). V-REx also benefits from ema. On the 'real' dataset OfficeHome in Table 10, the ema is less beneficial (from 67.5% to 68.2% in 'Oracle' for Fishr). Notably, it worsens V-REx. Overall, Fishr -- with and without ema -- outperforms V-REx on OfficeHome. We speculate that ema mainly helps when the batch size is not sufficiently large to detect 'slight' correlation shifts in the training datasets: e.g., when the batch size is $\sim 2^{\mathrm{Uniform}(3, 9)}$ and the training datasets are $E = \{90\%, 80\%\}$ in Colored MNIST. We remind that when the batch size was 25,000 in the Colored MNIST setup from IRM, Fishr reached 69.5% (without ema) in Table 2 from Section 5.1. On the contrary, when the shift is more prominent, as in OfficeHome, the ema may be less necessary. Most importantly, Fishr -- with and without ema -- improves over ERM on these datasets.

As a reminder from Section 2, IGA (Koyama & Yamaguchi, 2020) is an unpublished gradient-based approach that matches gradient means across domains, i.e., minimizes $\|g_A - g_B\|_2^2$ when $E = \{A, B\}$, where $g_e = \frac{1}{n_e} \sum_{i=1}^{n_e} \nabla_\theta \ell\big(f_\theta(x_e^i), y_e^i\big)$. Scores for IGA are not publicly available and thus were not included in Section 5.2. Moreover, IGA is very costly and impractical: IGA is approximately $|E| + 1$ times longer to train than ERM. Yet, we ran the DomainBed implementation of IGA on one 'synthetic' and one 'real' dataset. Table 11 shows that IGA has little effect on Colored MNIST (58.0% vs. 57.8% for ERM in 'Oracle'). Moreover, on OfficeHome in Table 12, IGA hinders learning (56.9% vs. 66.4% for ERM in 'Oracle'). In brief, "IGA [...] could completely fail when generalizing to unseen domains", as stated in Fish (Shi et al., 2021).
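For contrast with the variance matching studied in this paper, below is a minimal sketch (hypothetical helper names, not IGA's official code) of the gradient-mean penalty in its multi-domain form, together with the exponential moving average applied to the domain-level means.

```python
import torch

def domain_mean_grads(per_domain_grads):
    """Average the per-sample gradients within each domain: g_e = mean_i g_e^i."""
    return [G.mean(dim=0) for G in per_domain_grads]

def iga_penalty(per_domain_grads):
    """Match the domain-level gradient means: sum_e ||g_e - g_bar||^2.
    For two domains this equals ||g_A - g_B||^2 / 2, i.e., IGA's penalty up to a constant."""
    means = torch.stack(domain_mean_grads(per_domain_grads))
    return ((means - means.mean(dim=0)) ** 2).sum()

class EmaMeans:
    """Exponential moving average of the domain-level gradient means (gamma ~ 0.9-0.99)."""
    def __init__(self, num_domains, dim, gamma=0.95):
        self.gamma = gamma
        self.state = torch.zeros(num_domains, dim)
    def update(self, means):
        # Detach the previous state, as done for the Fishr moving averages.
        self.state = self.gamma * self.state.detach() + (1 - self.gamma) * means
        return self.state
```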
In the rest of this section, we include IGA in the Fishr codebase so that both methods leverage the same implementation choices: this enables fairer comparisons between gradient mean matching and gradient covariance matching. These experiments provide further insights regarding Fishr's main components: specifically, enforcing invariance (1) only in the classifier's weights ω, (2) after a warmup period and (3) with an exponential moving average. First, Fishr only considers gradient covariances in the classifier's weights ω. Similarly, we try to apply IGA's gradient mean matching only in $w_\omega$ rather than in $f_\theta$. This new method works significantly better (67.2% when $g_e = \frac{1}{n_e} \sum_{i=1}^{n_e} \nabla_\omega \ell\big(f_\theta(x_e^i), y_e^i\big)$ vs. 56.9% when $g_e = \frac{1}{n_e} \sum_{i=1}^{n_e} \nabla_\theta \ell\big(f_\theta(x_e^i), y_e^i\big)$ for 'Oracle' OfficeHome in Table 12) while reducing the computational overhead. This further motivates enforcing the invariance in the classifier rather than in the low-level layers (which need to adapt to shifts in pixels, for instance). We have done this analysis on IGA and not on Fishr because keeping individual gradients from the whole network $f_\theta$ in GPU memory was not possible with our hardware. Third, the estimation of gradient variances was improved with an exponential moving average (see Section 5.2 and Appendix C.2.1). We now use this strategy with the domain-level gradient means for IGA in ω: $\bar{g}_e^t = \gamma \bar{g}_e^{t-1} + (1-\gamma) g_e^t$. This improves IGA (from 67.0% to 67.2% in 'Oracle' on OfficeHome): yet, these scores remain consistently worse than Fishr's (from 67.5% to 68.2%). In conclusion, this complements the experiments in Section 5.2, which showed that tackling the gradient covariance does better than tackling the gradient mean: indeed, Fishr performed better than Fish (Shi et al., 2021), AND-mask (Parascandolo et al., 2021) and SAND-mask (Shahtalebi et al., 2021). As a final note, Fishr + IGA -- i.e., matching simultaneously the gradient means (the first moment) and covariances (the second moment) -- performs best. Future works may further analyze the complementarity of these gradient-based methods.

This section is a preliminary introduction to a meta-discussion, not about the methodology to select the best hyperparameters, but about the methodology to select the hyperparameter distributions in DomainBed. This question has not been discussed in previous works (as far as we know). After a few initial iterations on the main idea of the paper, we had to select the distributions from which to sample our three hyperparameters, as described in Table 7. First, to select the ema γ distribution, we knew that the authors of Le Roux et al. (2011) had not noticed "any significant difference in validation errors" for different values higher than 0.9. Moreover, γ should remain strictly lower than 1. Thus, sampling from Uniform(0.9, 0.99) seemed appropriate. Second, sampling the number of warmup iterations uniformly along training from Uniform(0, 5000) seemed the most natural and neutral choice. Lastly, the choice of the λ distribution was more complex. As a reminder, a low λ inactivates the regularization while an extremely high λ may destabilize the training. In Table 13, we investigate two distributions: $\lambda \sim 10^{\mathrm{Uniform}(1, 4)}$ (eventually chosen for Fishr) and $\lambda \sim 10^{\mathrm{Uniform}(1, 5)}$. First, we observe that the results are mostly similar: this confirms that Fishr is consistently better than ERM (where λ = 0), and on average is the best approach with the 'Oracle' model selection and among the best approaches with the 'Training' model selection.
Second, the existence of consistent differences in results suggests that the best hyperparameter distribution depends on the dataset at hand and that the performance gap depends on the selection method. Although out of the scope of this paper, we believe reporting these results is important for transparency (along with publishing our code), and may motivate the need for new protocols -- for example with Bayesian hyperparameter search (Turner et al., 2021) -- that future benchmarks may introduce. The tables below detail the results for each dataset with the 'Oracle' and 'Training' model selection methods.

Pseudo-code (Appendix D, partial). Inputs: the datasets $\{D_e\}$ for domains $e \in E$, regularization weight λ, warmup iteration $i_{\text{warmup}}$, exponential moving average speed γ, batch size $b_s$, optimizer $g$, learning rate $l_r$.
/* Training Procedure */
1. Initialize moving averages: $\forall e \in E$, $v_e^{\text{mean}} \leftarrow 0$
2. for iter from 1 to #iters do
   /* Step 1: standard ERM procedure */
   [...]
   Backpropagate gradients: $\theta \leftarrow g(\text{gradient} = \nabla_\theta \mathcal{L}(\theta), \text{learning rate} = l_r)$

References (titles only):
• Ellipsoidal trust region methods and the marginal value of Hessian information for neural network training
• Moment matching for multi-source domain adaptation
• Causal inference by using invariant prediction: identification and confidence intervals
• Gradient starvation: A learning proclivity in neural networks
• An empirical comparison of instance attribution methods for NLP
• Common pitfalls and recommendations for using machine learning to detect and prognosticate for COVID-19 using chest radiographs and CT scans
• Invariant models for causal transfer learning
• The risks of invariant risk minimization
• Distributionally robust neural networks
• An investigation of why overparameterization exacerbates spurious correlations
• The impact of neural network overparameterization on gradient confusion and stochastic gradient descent
• The pitfalls of simplicity bias in neural networks
• SAND-mask: An enhanced gradient masking strategy for the discovery of invariances in domain generalization
• Gradient matching for domain generalization
• WoodFisher: Efficient second-order approximation for neural network compression
• Deep CORAL: Correlation alignment for deep domain adaptation
• Return of frustratingly easy domain adaptation
• Building machines that learn and think like people
• Unshuffling data for improved generalization
• Faster gaze prediction with dense networks and Fisher pruning
• On the interplay between noise and curvature and its effect on optimization and generalization
• Bayesian optimization is superior to random search for machine learning hyperparameter tuning: Analysis of the black-box optimization challenge 2020
• Deep domain confusion: Maximizing for domain invariance
• Deep learning generalizes because the parameter-function map is biased towards simple functions
• An overview of statistical learning theory
• Deep hashing network for unsupervised domain adaptation
• On calibration and out-of-domain generalization
• Heterogeneous domain generalization via domain mixup
• Dual mixup regularized learning for adversarial domain adaptation
• In-N-Out: Pre-training and self-training using auxiliary information for out-of-distribution robustness
• Improve unsupervised domain adaptation with mixup training
• A fine-grained spectral perspective on neural networks
• PyHessian: Neural networks through the lens of the Hessian
• Out-of-distribution generalization analysis via influence function
• OoD-Bench: Benchmarking and understanding out-of-distribution generalization datasets and algorithms
• Gradient diversity: a key ingredient for scalable distributed learning
• Gradient surgery for multi-task learning
• Adaptive risk minimization: A meta-learning approach for tackling group distribution shift
• Deep stable learning for out-of-distribution generalization
• Learning novel policies for tasks
• On learning invariant representations for domain adaptation

DomainBed includes seven multi-domain computer vision classification datasets.
• Colored MNIST: as described previously in Appendix B.1, domain $d \in \{90\%, 80\%, 10\%\}$ contains a disjoint set of colored digits; the correlation strengths between color and label vary across domains. The dataset contains 70,000 examples of dimension (2, 28, 28) and 2 classes.
• Rotated MNIST (Ghifary et al., 2015) is a variant of MNIST where domain $d \in \{0, 15, 30, 45, 60, 75\}$ contains digits rotated by $d$ degrees.
• PACS (Li et al., 2017) includes domains $d \in$ {art, cartoons, photos, sketches}, with 9,991 examples and 7 classes.
• OfficeHome (Venkateswara et al., 2017) includes domains $d \in$ {art, clipart, product, real}, with 15,588 examples of dimension (3, 224, 224) and 65 classes.
• TerraIncognita (Beery et al., 2018) contains photographs of wild animals taken by camera traps at locations $d \in$ {L100, L38, L43, L46}.
• DomainNet (Peng et al., 2019) has six domains $d \in$ {clipart, infograph, painting, quickdraw, real, sketch}.