title: Kronecker Factorization for Preventing Catastrophic Forgetting in Large-scale Medical Entity Linking
authors: McInerney, Denis Jered; Kong, Luyang; Arumae, Kristjan; Wallace, Byron; Bhatia, Parminder
date: 2021-11-11

Abstract. Multi-task learning is useful in NLP because it is often practically desirable to have a single model that works across a range of tasks. In the medical domain, sequential training on tasks may sometimes be the only way to train models, either because access to the original (potentially sensitive) data is no longer available, or simply owing to the computational costs inherent to joint retraining. A major issue inherent to sequential learning, however, is catastrophic forgetting, i.e., a substantial drop in accuracy on prior tasks when a model is updated for a new task. Elastic Weight Consolidation is a recently proposed method to address this issue, but scaling this approach to the modern large models used in practice requires making strong independence assumptions about model parameters, limiting its effectiveness. In this work, we apply Kronecker Factorization--a recent approach that relaxes independence assumptions--to prevent catastrophic forgetting in convolutional and Transformer-based neural networks at scale. We show the effectiveness of this technique on the important and illustrative task of medical entity linking across three datasets, demonstrating the capability of the technique to make efficient updates to existing methods as new medical data becomes available. On average, the proposed method reduces catastrophic forgetting by 51% when using a BERT-based model, compared to a 27% reduction using standard Elastic Weight Consolidation, while maintaining spatial complexity proportional to the number of model parameters.

Creating a single model that performs well across multiple domains is often desirable, especially in production systems. Relying on multiple (task-specific) systems necessitates storing and managing corresponding collections of parameters. Multi-task models Caruana [1997] obviate this need by performing well on inputs from all tasks, simplifying deployment. In the medical domain especially, new data is constantly becoming available, and it is necessary to keep models up to date with this deluge. In medical entity linking--where the goal is to link mentions in clinical text to corresponding entities in an ontology--the underlying ontologies are frequently updated, and new terms are put into use quickly. For example, in the past year or so many codes related to COVID-19 Guan et al. [2020] were added to the International Classification of Diseases (ICD) lexicon; updating models to incorporate new codes without losing performance on older knowledge is therefore of particular importance. The language in the medical domain also brings additional challenges because of the many acronyms, synonyms, and ambiguous terms used in both clinical and biomedical corpora. These characteristics make it difficult even for humans to choose the correct entity among the top candidates in an ontology.

To train a multi-task model, one would ideally jointly train on data drawn from all tasks. However, when new tasks are introduced, this would require re-training on the combined data, which is inefficient and sometimes practically impossible.
For example, there are cases--particularly when dealing with medical data--where data access is lost, precluding joint training over a combined set of old and new data. This occurs, e.g., when license agreements or contracts with data providers expire; time-limited data partnerships are common in industry settings. In these situations, continuous learning, or training on task-specific data sequentially, is the only option to maintain a single model across tasks.

Previous work on continuous learning has focused on mitigating catastrophic forgetting (CF) McCloskey and Cohen [1989], Ratcliff [1990], a problem that arises in sequential training where performance on earlier tasks drops when the model is trained on additional tasks. Experience Replay de Masson d'Autume et al. [2019] maintains performance when training on new tasks by "replaying" examples saved from older tasks. Unfortunately, in a setting in which access to previous data is not possible, this method cannot be used. Elastic Weight Consolidation (EWC) Kirkpatrick et al. [2017] is an alternative, constraint-based technique that regularizes parameters such that they are encouraged to maintain the optimal weights learned for prior tasks. By placing a prior involving the Hessian from previous tasks over network parameters, EWC affords flexibility with respect to changing parameters in different dimensions. Critically, EWC does not require continued access to data from 'past' tasks once key statistics are computed over a task's data. However, to scale to even relatively small neural networks, EWC must assume independence between all parameters. This assumption allows one to drop off-diagonal terms in the Fisher Information Matrix (FIM), which is used to approximate the Hessian; calculating the full matrix would be intractable.

Recent work Ritter et al. [2018b] has proposed Kronecker Factorization (KF)--which the optimization community uses to compute Hessians in neural networks--to perform a version of EWC with a relaxed independence assumption on networks of linear layers operating on small vision datasets. Our contribution here is the extension of EWC and KF to the large-scale neural models now common in NLP. In particular, modern NLP tends to rely on large-scale models with hundreds of millions of parameters Devlin et al. [2019], Liu et al. [2019], Radford et al. [2019], and CF is problematic across many of its sub-domains. Though EWC has been successfully applied to NLP in recent work (Section 7), we demonstrate that there is room for substantial gains. In particular, we have observed that the independence assumption over parameters significantly and negatively affects EWC's ability to mitigate CF, as compared to what is achieved using the full covariance matrix of parameters.

As far as we are aware, this work is the first application of the Kronecker Factorization method--which relaxes the assumption of independence between parameters--for continuous learning in large-scale networks. Though we do not model all elements of the Fisher Information Matrix, we approximate block diagonals corresponding to layers, which is less damaging than assuming completely independent parameters. Specifically, we apply Kronecker Factorization in two large-scale neural networks for entity linking: (1) a convolutional and (2) a Transformer-based architecture Vaswani et al. [2017]. Our primary contributions are as follows: (1) We combine and extend prior work in EWC and Kronecker Factorization to modern large-scale NLP models.
(2) We use Kronecker Factorization to train on multiple biomedical ontology entity linking tasks sequentially, without access to previous data, and show that it significantly outperforms baseline methods.

Here we compare recently proposed baselines, describe an extension of EWC, and demonstrate how to scale this to modern, large NLP models. We differentiate between methods that require access to prior task data during training and those that do not.

Learning rate control was proposed in ULMFiT Howard and Ruder [2018] as "discriminative fine-tuning", motivated by the intuition that different layers contain distinct information and should accordingly be fine-tuned at different rates. In particular, we control the extent of fine-tuning by altering the learning rate for different layers according to $\eta_{\ell} = \eta_{\ell-1}\,\gamma$, where $\eta_{\ell}$ is the learning rate for layer $\ell$ and $\gamma$ is a constant hyper-parameter. Though the authors apply this method to fine-tuning networks when the pre-training task is not one of the downstream tasks, they suggest that it can also prevent CF, so it serves as a good baseline for our other methods.

Experience Replay de Masson d'Autume et al. [2019] is a method for preventing CF in continuous learning by "replaying" examples from a memory buffer: after every m training batches on a task, we "replay" n batches from prior tasks. This requires access to all prior task data while adapting to a new task, whereas our goal is to design an approach that does not require this data. Still, we compare our methods to experience replay to assess the relative performance we can achieve when we maintain only statistics and model weights for prior tasks.

Elastic Weight Consolidation (EWC) Kirkpatrick et al. [2017] mitigates CF using an approximation of a posterior on the optimal weights for prior tasks in order to constrain learning on new tasks:

$\mathcal{L}(\theta_t) = \mathcal{L}_t(\theta_t) - \log p(\theta_t \mid D_1, \ldots, D_{t-1})$   (1)

for task $t$, where $\mathcal{L}_t = -\log p(D_t \mid \theta_t)$ is the loss for that task. In particular, the posterior is approximated as a Gaussian Mixture Model in which each mode corresponds to a prior task. Specifically,

$-\log p(\theta_t \mid D_1, \ldots, D_{t-1}) \approx \sum_i \frac{\lambda}{2} (\theta_t - \mu_i)^\top \Lambda_i^{-1} (\theta_t - \mu_i),$   (2)

where $i$ indexes over previous tasks, $\mu_i$ and $\Lambda_i$ are the mean and covariance for the parameters of task $i$, and $\lambda$ is a hyper-parameter that controls the strength of the regularization (or, from another view, the variance of the posterior). The mean represents the optimal weights for the task, approximated using the parameters directly after training on that task ($\mu_i = \theta_i^*$). We estimate the inverse of the covariance matrix with the Fisher Information Matrix (FIM) Frieden [2004], Spall [2005]:

$\Lambda_i^{-1} \approx F_i = \mathbb{E}\!\left[\frac{\partial \mathcal{L}_i}{\partial \theta}\,\frac{\partial \mathcal{L}_i}{\partial \theta}^{\!\top}\right].$   (3)

In practice, for purposes of computational efficiency, we use only the diagonal of the FIM Kirkpatrick et al. [2017] to compute the regularization term, in turn implying that

$-\log p(\theta_t \mid D_1, \ldots, D_{t-1}) \approx \sum_i \sum_j \frac{\lambda}{2}\, F_{i,jj}\,(\theta_{t,j} - \mu_{i,j})^2,$   (4)

where $j$ indexes over the parameters. Unlike experience replay, this method has the benefit of not requiring access to prior task data during training, provided that the FIM is computed ahead of time. This also holds for our main method, Kronecker Factorization. Though EWC and Kronecker Factorization both require keeping a set of weights and a FIM from after training on each task, these can be discarded when performing inference.
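To make the diagonal variant concrete, the following sketch (PyTorch; an illustrative sketch rather than the authors' code, with hypothetical helper names diag_fisher and ewc_penalty) estimates the diagonal FIM from squared gradients and adds the quadratic penalty of Equation 4 to the loss on a new task.

```python
import torch

def diag_fisher(model, loader, loss_fn, device="cpu"):
    """Diagonal FIM estimate: expected squared gradient of the task loss
    (Equation 3, keeping only the diagonal). Mini-batch gradients are used
    here as a cheap stand-in for per-example gradients."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    model.eval()
    num_batches = 0
    for inputs, labels in loader:
        model.zero_grad()
        loss_fn(model(inputs.to(device)), labels.to(device)).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        num_batches += 1
    return {n: f / max(num_batches, 1) for n, f in fisher.items()}

def ewc_penalty(model, prior_tasks, lam):
    """Sum over prior tasks i and parameters j of lambda/2 * F_{i,jj} *
    (theta_j - mu_{i,j})^2 (Equation 4). `prior_tasks` holds (mu, fisher)
    dictionaries saved right after training on each previous task."""
    device = next(model.parameters()).device
    penalty = torch.zeros((), device=device)
    for mu, fisher in prior_tasks:
        for n, p in model.named_parameters():
            if n in fisher:
                penalty = penalty + (fisher[n] * (p - mu[n]) ** 2).sum()
    return 0.5 * lam * penalty

# During training on task t:
# loss = task_loss + ewc_penalty(model, prior_tasks, lam)
```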
Kronecker Factorization. A downside to using a diagonal FIM is that this assumes independence between all model parameters. However, $\Lambda_i^{-1}$ is an $n \times n$ matrix, where $n$ is the number of parameters in the model, so computing the full matrix for contemporary neural networks with millions of parameters is computationally infeasible. Ritter et al. [2018b] relax this independence assumption while still maintaining low memory requirements by using a convenient Kronecker Factorization of the FIM for the parameters within a layer. For a linear layer without bias with weights $W \in \mathbb{R}^{p \times q}$, input $x \in \mathbb{R}^{p \times 1}$, and output $y = W^{\top} x$ with $y \in \mathbb{R}^{q \times 1}$, the derivative with respect to the weights can be written as $\frac{\partial L}{\partial W} = x \left(\frac{\partial L}{\partial y}\right)^{\!\top}$ (we elide indexing over the task and the instance for clarity). Therefore, using Equation 3 and vectorizing the weights, the inverse of the full covariance for the layer can be factored as

$\Lambda^{-1} = \mathbb{E}\!\left[\mathrm{vec}\!\left(\tfrac{\partial L}{\partial W}\right)\mathrm{vec}\!\left(\tfrac{\partial L}{\partial W}\right)^{\!\top}\right] = \mathbb{E}\!\left[\left(x x^{\top}\right) \otimes \left(\tfrac{\partial L}{\partial y}\tfrac{\partial L}{\partial y}^{\!\top}\right)\right].$   (5)

Then, making an assumption of independence between the two Kronecker factors $x x^{\top}$ and $\frac{\partial L}{\partial y}\frac{\partial L}{\partial y}^{\top}$, we can approximate the full covariance for the linear layer as

$\Lambda^{-1} \approx \mathbb{E}\!\left[x x^{\top}\right] \otimes \mathbb{E}\!\left[\tfrac{\partial L}{\partial y}\tfrac{\partial L}{\partial y}^{\!\top}\right].$   (6)

We can extend this to layers with bias by appending a constant to the input and absorbing the bias into an augmented weight matrix: $\tilde{x} = [x; 1]$ and $\tilde{W} = [W; b^{\top}]$.

Figure 1 caption: The diagonal approximation assumes that $\theta_0$ and $\theta_1$ are independent, so the corresponding Gaussian ellipse has no tilt. The Kronecker-Factored approximation to the FIM may not exactly align with the true local Gaussian over the optimal parameters for task A, but it can still model interactions between parameters, so it has a similar tilt to the true Gaussian. Using this KFC approximation, the estimated parameters are much closer to the optimal than when using the diagonal FIM.

Figure 1 illustrates the motivation for including off-diagonal terms in the FIM. In particular, in the optimal parameter zones for tasks A and B in the figure, estimating off-diagonal terms is beneficial in achieving near-optimal parameters for both tasks.

Spatial Complexity. The approximation to the full FIM in Equation 6 allows us to store only the factors in place of the full FIM. The spatial complexity of the full FIM is $P^2$, where $P$ is the number of parameters in the network. In general, neural networks have a structure in which the number of parameters is $O(d^2 l)$, where the intermediate dimensions are all $O(d)$ and $l$ is the number of layers. This means that the number of elements in the full FIM is $O(d^4 l^2)$. Assuming layer-wise independence yields a block-diagonal FIM, which has $O(d^4 l)$ non-zero elements. Finally, because Kronecker Factorization reduces the memory to $p^2 + q^2$ elements instead of $p^2 q^2$ for an individual linear layer with input dimension $p$ and output dimension $q$, the full memory usage is $O(d^2 l)$. Given the slightly stronger assumption that $P = \Theta(d^2 l)$ rather than just $O(d^2 l)$, which holds for the networks in this paper and most others, this method, like EWC, requires memory linear in the number of parameters. We demonstrate the cost of the full FIM by reducing the hidden size until the model becomes practically feasible to train: we reduced the hidden dimensionality of our smaller convolutional model from 768 to 10 in order to compute the full FIM, and increasing this dimensionality to just 20 makes both the full FIM and the layer-wise independent block-diagonal FIM too large to hold in memory. Kronecker Factorization is therefore essential to reduce asymptotic memory requirements.
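As an illustration of Equation 6 and the resulting penalty, the following sketch (PyTorch; our own illustrative code under the assumptions stated in the comments, not the authors' implementation, and helper names such as kf_factors and kf_penalty are ours) accumulates the two Kronecker factors for a single linear layer and evaluates the quadratic penalty without ever materializing the full layer FIM.

```python
import torch
import torch.nn as nn

def kf_factors(model, layer: nn.Linear, loader, loss_fn, device="cpu"):
    """Accumulate the Kronecker factors of Equation 6 for one linear layer:
    A = E[x x^T] over (bias-augmented) layer inputs and G = E[g g^T] over
    gradients w.r.t. the layer output. Mini-batch task-loss gradients are used
    as a cheap stand-in for per-example gradients."""
    p, q = layer.in_features, layer.out_features
    A = torch.zeros(p + 1, p + 1, device=device)
    G = torch.zeros(q, q, device=device)
    cache, n = {}, 0
    h_fwd = layer.register_forward_hook(
        lambda mod, inp, out: cache.update(x=inp[0].detach()))
    h_bwd = layer.register_full_backward_hook(
        lambda mod, gin, gout: cache.update(g=gout[0].detach()))
    model.eval()
    for inputs, labels in loader:
        model.zero_grad()
        loss_fn(model(inputs.to(device)), labels.to(device)).backward()
        x = cache["x"].reshape(-1, p)
        x = torch.cat([x, torch.ones(x.size(0), 1, device=x.device)], dim=1)
        g = cache["g"].reshape(-1, q)
        A += x.t() @ x
        G += g.t() @ g
        n += x.size(0)
    h_fwd.remove(); h_bwd.remove()
    return A / n, G / n

def kf_penalty(layer: nn.Linear, mu_weight, mu_bias, A, G, lam):
    """lambda/2 * vec(delta)^T (A kron G) vec(delta) for the bias-augmented
    weight difference delta, computed as sum((G @ delta @ A) * delta).
    PyTorch stores the weight as (out, in), i.e. the transpose of W in the text."""
    delta = torch.cat(
        [layer.weight - mu_weight, (layer.bias - mu_bias).unsqueeze(1)], dim=1)
    return 0.5 * lam * ((G @ delta @ A) * delta).sum()

# Convolutional layers can be treated the same way by unfolding their inputs
# into patches (torch.nn.functional.unfold), so that each spatial position acts
# as one more input to a shared linear layer, as described in the text below.
```

Storing A and G requires only $(p+1)^2 + q^2$ numbers per layer, in line with the spatial complexity argument above.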
Scaling to Large Models. We use two models to perform entity linking: (1) a CNN-based model and (2) a Transformer-based model. In order to implement Kronecker Factorization on the CNN, we must extend it to convolutional layers using the derivations in Grosse and Martens [2016]. A CNN can be thought of as a variation of a linear layer in which the linear layer is applied to multiple inputs in one forward pass of the model. In this case, the expectation is not only over each instance in the dataset, but also over each time the linear layer is used with different inputs; the Kronecker factors are then calculated from these different inputs paired with their outputs' gradients (in the case of RNNs, these different inputs are the different time-steps). In order to take this expectation, we assume that the gradients with respect to these inputs are independent even within a single instance. We also make this assumption for models with shared weights in different linear layers. Since the only layers in Transformer models are linear layers and layer norms, we ignore layer norms in our EWC regularization term.

Entity linking is a longstanding task in NLP that entails matching mentions--spans of text in natural language--with the entities to which they refer in an ontology. In this paper, we assume that each instance consists of a trimmed-down set of candidates and a mention in its context. Specific to our motivating setting of bio-medical text, there exists a large set of diverse ontologies in the medical domain, and it is common to want to link entities in clinical texts to these. In such applications it is not uncommon to lose access to (sensitive) training datasets, but in many cases we would like to maintain model performance on the tasks that these represent. We use all possible permutations of the MedMentions, MedNorm, and 3DNotes datasets to test how well different methods prevent CF. For dataset statistics, we refer to the appendix.

MedMentions Mohan and Li [2019] is a publicly available corpus for medical concept normalization. It contains 4,392 abstracts from PubMed (https://www.ncbi.nlm.nih.gov/pmc/), with bio-medical entities annotated with Unified Medical Language System (UMLS) concepts. We follow the approach to candidate generation described in Murty et al. [2018]: we retain only the top nine most similar entities (excluding the ground truth entity) as negative candidates, and the ground truth entity is included as the positive candidate, forming a set of 10 candidates for each mention.

MedNorm Belousov et al. [2019] is a corpus of 27,979 descriptions mapped to two medical ontologies (MedDRA and SNOMED-CT), sourced from five publicly available datasets across biomedical and social media domains. We use a subset of these--CADEC Karimi et al. [2015], TAC Demner-Fushman et al. [2018], and TwiMed Alvaro et al. [2017]--and employ a BM25 Robertson and Walker [1994] based retrieval approach to generate the top entities for each mention, which serve as the candidates for the ranking model.

3DNotes. We also use a de-identified corpus of dictated doctor's notes (3DNotes), similar to corpora used in prior work. These are annotated with medical entities related to signs, symptoms, and diseases. The entities are mapped to the 10th version of the International Statistical Classification of Diseases and Related Health Problems (ICD-10), which is part of UMLS. The annotation guidelines are similar to the i2b2 challenge Uzuner et al. [2011] guidelines for the problem entity. We use 20 words on the left and right as mention context.
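As a concrete sketch of this kind of candidate generation (an illustration only, combining the BM25 retrieval and the 10-way candidate-set construction described above; the rank_bm25 package and the helper below are our choices, not the authors' pipeline):

```python
# Illustrative candidate generation: BM25-ranked ontology entries plus the gold
# entity, forming a 10-way candidate set per mention (not the authors' code).
from rank_bm25 import BM25Okapi

def build_candidates(mention, gold_id, ontology, k=10):
    """ontology: list of (concept_id, concept_name) pairs."""
    names = [name.lower().split() for _, name in ontology]
    bm25 = BM25Okapi(names)
    scores = bm25.get_scores(mention.lower().split())
    ranked = sorted(range(len(ontology)), key=lambda i: scores[i], reverse=True)
    # Top-(k-1) retrieved entities, excluding the gold entity, act as negatives;
    # the gold entity is added back as the positive candidate.
    negatives = [ontology[i][0] for i in ranked if ontology[i][0] != gold_id][:k - 1]
    return [gold_id] + negatives

# Example (hypothetical concept IDs):
# build_candidates("heart attack", "ID_MI", [("ID_MI", "Myocardial Infarction"), ...])
```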
We perform initial experiments using a downsampled dataset (comprising ten percent of the data) and one ordering over two datasets (MedMentions, 3DNotes) to compare Learning Rate Control, Experience Replay, L2 regularization, and EWC. Of these, we take EWC as the baseline against which to compare the proposed KFC extension; we do not perform further explicit comparisons to Experience Replay, as we are interested in approaches that do not assume access to data from all tasks. Specifically, we compare traditional EWC (using the diagonal FIM) to EWC using Kronecker Factorization on the full data and all possible permutations of all three datasets. We select optimal λs using the development sets in each dataset, and report results on the test set for those λs. In Section A of the appendix, we further discuss the hyper-parameter search, which gives a sense of the trade-off between low and high regularization coefficients. We perform all experiments using a CNN-based model with 26M parameters and a Transformer-based model with 46M parameters initialized from the first three layers of BERT Devlin et al. [2019], training for 10 epochs on each dataset. We perform each tier of training using an NVIDIA V100 GPU with 16GB of memory.

Our initial comparison of baseline methods using a subset of the full data (Table 1) shows that, of the methods that do not require access to data from prior tasks (i.e., all save for Experience Replay), EWC performs the best. In all tables, entity linking accuracy is defined as the percentage of mentions for which the correct candidate is ranked the highest. We then evaluate individually trained baseline models for each of the datasets; results are shown in Table 2. The BERT model has the highest accuracy across the board. This table also shows that MedNorm is the easiest dataset on which to perform well. For comparison, we also train our models on the combined data of all three datasets and report the results in Table 2. We find performance similar to individual training, which means that there does exist a near-optimal set of parameters that performs well across all three datasets.

Generally speaking, we expect and verify that training on tasks sequentially will decrease performance on prior tasks due to CF, and we measure this drop by calculating the difference in accuracy between the individually trained models in Table 2 and those obtained after training on a second and third task, shown in the percentage-change columns of Tables 3 and 4.

Tables 3 and 4 show the robustness of Kronecker Factorization in preventing CF on prior tasks. In almost all permutations of the datasets, for both the CNN and BERT models, using the Kronecker Factorization of the FIM outperforms standard EWC. These results also show that neither method prevents the model from adapting to new tasks, a potential problem called intransigence Chaudhry et al. [2018] that is common for constraint-based methods. We avoid this problem by tuning the regularization strength λ (Section A). To visualize the benefits of estimating a FIM with non-zero off-diagonal terms, and the degree to which the Kronecker Factorization can accurately estimate the full Fisher in practice, we present Figure 2 (showing the Kronecker-factored FIM alongside the full FIM). It shows that there are indeed large off-diagonal terms in the full FIM and that some of those are estimated accurately by the Kronecker-Factored FIM.

Overall, EWC mitigates CF when compared with the baseline of no regularization, and KFC performs better than EWC, without much hindrance of adaptation to the new task. However, there are cases where one or both methods perform worse than the baseline at CF mitigation. For instance, in BERT training for the permutation MedNorm → MedMentions, EWC actually does worse than the baseline, and KFC performs about equivalently to the baseline. Likewise, in the BERT permutation 3DNotes → MedNorm → MedMentions, KFC performs worse than the baseline on 3DNotes (Task A) and EWC performs equivalently to it.
In both of the tier-3 training permutations, both regularization strategies perform worse than the baseline on Task A. These problems appear to occur mostly with BERT (and not with the CNN), and occur when MedNorm is the first task. MedNorm is the smallest dataset, so we expect CF to be worst for it, because the number of training steps taken on it is small compared to the number taken on the other datasets after it. The fact that these problems occur mostly with BERT may be due to the larger number of parameters or to more complex parameter inter-dependence resulting in a harder-to-approximate FIM; however, this is only a conjecture at present. We observe that in some cases performance on the new task is completely unchanged or slightly improved. This occurs mostly when the new task is MedNorm, because its comparatively small size makes it easier to adapt quickly to the new data regardless of the regularization; in other words, the amount of parameter shift needed for this dataset is minimal.

Taking a closer look at Figure 2, we can get a better idea of how Kronecker Factorization may be helping. The Kronecker-Factored block-diagonal FIM shows the last two layers of the CNN architecture: (1) a 10 × 10 linear layer, and (2) a 10 × 1 linear layer. The first layer corresponds to an 11 × 11 grid of small blocks, some of which share commonalities. The first 10 × 10 blocks represent the interactions of one column of the weight matrix with another or with itself, and the last row and column represent the interaction of the bias vector with each of the columns of the weight matrix and with itself. Though there are some major differences between the Kronecker-Factored and full FIMs, there is a pattern of similar small blocks that appears in both. In addition, the 6th and 7th rows and columns of small blocks have noticeably smaller magnitudes in both the Kronecker-Factored and full FIMs.

Figure 2 also reveals potential pitfalls of Kronecker Factorization. In the full FIM, one of the more interesting features of the small blocks representing column-to-column interactions within the weight matrix is that they tend to have strong diagonals, meaning that elements in the same row of the weight matrix interact strongly. In addition, the blocks on the diagonal, representing within-column interactions, have elevated intensity, indicating that elements in the same column of the weight matrix also interact strongly. Neither of these phenomena appears in the Kronecker-Factored approximation. This provides clues into how we might further sparsify the intra-layer interactions, which may present an alternative or an improvement to Kronecker Factorization. Another feature of Figure 2 is the high magnitude of elements representing inter-layer interactions, which are not modeled in the Kronecker-Factored FIM. This shows that there is still a significant gap between this approximation and using the full Fisher.

Kronecker Factorization for Neural Networks. Martens and Grosse [2015] first introduce Kronecker Factorization as an approximation of blocks of the FIM of neural networks. They use this to perform second-order optimization on networks of linear layers, and then extend this to optimizing convolutional architectures Grosse and Martens [2016].
More recently, Ritter et al. [2018a] show that the Fisher approximation can be used as a posterior on network weights, and then expand EWC to use off-diagonal elements of the FIM in its regularization term with this approximation Ritter et al. [2018b]. This last paper mainly focuses on small vision datasets.

Continuous Learning in NLP. NLP seems to be particularly susceptible to CF Howard and Ruder [2018]. Recent work has therefore focused on developing continuous learning techniques in NLP to mitigate this issue Moeed et al. [2020], Pilault et al. [2020]. There has also been work that applies previous techniques to specific domains (e.g., machine translation Thompson et al. [2019], sentiment analysis Madasu and Rao [2020], and reading comprehension Xu et al. [2020]). Many of these methods focus on avoiding over-fitting to new tasks during fine-tuning, whereas we focus on maintaining high performance on old tasks. None of these use Kronecker Factorization, which has not yet been scaled to prevent CF in large NLP models. CF mitigation is particularly important in clinical NLP given that many clinical datasets are quite different from generic domains and from each other. Arumae et al. [2020] explore CF in language modeling when transferring between the generic, clinical, and biomedical domains and compare learning rate control, experience replay, and EWC.

We have demonstrated the effectiveness of Kronecker Factorization (KFC) for preventing catastrophic forgetting in modern large-scale neural architectures commonly used in NLP, improving on the results of Elastic Weight Consolidation. We showed that KFC can be used to create a unified model over multiple domains of medical entity linking, with good performance across tasks after a continuous (sequential) learning process. We highlighted strengths and weaknesses of the approximation adopted in KFC, pointing to potential future directions. Future work might consider alternatives to the block-diagonal structure of the covariance matrix used in KFC. Another promising line of inquiry concerns reducing the difficulty of selecting the λ hyper-parameter, which controls the strength of the regularization, without requiring a robust grid search.

Appendix A: Hyper-parameter Search. The early experiments on the smaller amounts of data (see Section 4) gave us a sense of the magnitude of the optimal λs, so for the 2nd tier of training on the full data we perform a grid search over the following λs for both the original diagonal EWC and the Kronecker-Factored EWC: 1e1, 1e3, 1e5, 1e7, 1e9. The plots in the appendix demonstrate that the lowest of these λs corresponds to almost no CF mitigation compared with no regularization, and the highest corresponds to almost no CF but very poor performance on the 2nd dataset; these λs therefore span the space well. For each of the 6 dataset permutations at the 2nd tier of training, we pick the optimal λ and further use it in the 3rd tier of training. In the 3rd tier of training there are actually two hyper-parameters to pick, one for the regularization term corresponding to each of the first two datasets. Because the first dataset had a λ tuned for it during the 2nd tier of training, we use that best λ during the 3rd tier of training for the regularization term corresponding to the first dataset. For the λ corresponding to the regularization term for the second dataset, we average the best λs from the 2nd-tier training permutations that started with that dataset, hypothesizing that this is a good approximation of the optimal λ for this term. We then perform a small grid search around each of these values, testing values 100 times smaller and 100 times greater than the estimated value, and pick the best of these as optimal. In picking the optimal λs for the 2nd and 3rd tiers of training, we choose the λs that yield the smallest drop in accuracy on the regularization term's corresponding dataset without too large a drop in accuracy on the current training dataset. We give the optimal λs in Table 6.
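A minimal sketch of this tier-3 λ-estimation heuristic (hypothetical helper and dictionary names; the averaging and the two-orders-of-magnitude grid follow the procedure described above):

```python
# Hypothetical sketch of the lambda-selection heuristic for tier-3 training.
# best_tier2_lam[(a, b)] is the best lambda found for protecting dataset a
# while training on dataset b in the corresponding tier-2 run (a -> b).
from statistics import mean

def tier3_lambdas(first, second, best_tier2_lam, datasets):
    # The first dataset's regularizer reuses the lambda tuned in tier 2.
    lam_first = best_tier2_lam[(first, second)]
    # The second dataset's regularizer: average the best lambdas of tier-2 runs
    # that started with that dataset, then search 100x below and above.
    estimate = mean(best_tier2_lam[(second, other)]
                    for other in datasets if other != second)
    return lam_first, [estimate / 100, estimate, estimate * 100]
```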
References

Twimed: Twitter and PubMed comparable corpus of drugs, diseases, symptoms, and their relations. JMIR Public Health and Surveillance.
An empirical investigation towards efficient multi-domain language model pre-training.
MedNorm: A corpus and embeddings for cross-terminology medical concept normalisation.
Multitask learning. Machine Learning.
Riemannian walk for incremental learning: Understanding forgetting and intransigence.
Efficient lifelong learning with A-GEM.
Recall and learn: Fine-tuning deep pretrained language models with less forgetting.
Episodic memory in lifelong language learning.
A dataset of 200 structured product labels annotated for adverse drug reactions. Scientific Data.
BERT: Pre-training of deep bidirectional transformers for language understanding.
Science from Fisher Information: A Unification. Cambridge University Press.
A Kronecker-factored approximate Fisher matrix for convolution layers.
Clinical characteristics of coronavirus disease 2019 in China.
Universal language model fine-tuning for text classification.
CADEC: A corpus of adverse drug event annotations.
Overcoming catastrophic forgetting in neural networks.
A robustly optimized BERT pretraining approach.
Gradient episodic memory for continual learning.
Sequential domain adaptation through elastic weight consolidation for sentiment analysis.
Optimizing neural networks with Kronecker-factored approximate curvature.
Catastrophic interference in connectionist networks: The sequential learning problem.
An evaluation of progressive neural networks for transfer learning in natural language processing.
MedMentions: A large biomedical corpus annotated with UMLS concepts. CoRR, abs/1902.09476.
Hierarchical losses and new resources for fine-grained entity typing and linking.
Conditionally adaptive multi-task learning: Improving transfer learning in NLP using fewer parameters & less data.
Language models are unsupervised multitask learners.
Connectionist models of recognition memory: Constraints imposed by learning and forgetting functions.
A scalable Laplace approximation for neural networks.
Online structured Laplace approximations for overcoming catastrophic forgetting.
Some simple effective approximations to the 2-Poisson model for probabilistic weighted retrieval.
Progress & compress: A scalable framework for continual learning.
Monte Carlo computation of the Fisher information matrix in nonstandard settings.
Memory-based parameter adaptation.
Overcoming catastrophic forgetting during domain adaptation of neural machine translation.
i2b2/VA challenge on concepts, assertions, and relations in clinical text.
Attention is all you need.
Sentence embedding alignment for lifelong relation extraction.
Forget me not: Reducing catastrophic forgetting for domain adaptation in reading comprehension.
Learning and evaluating general linguistic intelligence. ArXiv, abs/1901.11373.
Continual learning through synaptic intelligence.
LATTE: Latent type modeling for biomedical entity linking.