key: cord-0593848-76geqklh
authors: Chen, Bingyang; Chen, Tao; Zeng, Xingjie; Zhang, Weishan; Lu, Qinghua; Hou, Zhaoxiang; Zhou, Jiehan; Helal, Sumi
title: Feature-context driven Federated Meta-Learning for Rare Disease Prediction
date: 2021-12-29
journal: nan
DOI: nan
sha: bc6f4c27587288ec2c3bcce80fd9df6105926960
doc_id: 593848
cord_uid: 76geqklh

Millions of patients around the world suffer from rare diseases. However, far fewer samples are available for rare diseases than for common diseases. In addition, because medical data are sensitive, hospitals are usually reluctant to share patient information for data fusion, citing privacy concerns. These challenges make it difficult for traditional AI models to extract rare disease features for the purpose of disease prediction. In this paper, we overcome this limitation by proposing a novel approach for rare disease prediction based on federated meta-learning. To improve the prediction accuracy of rare diseases, we design an attention-based meta-learning (ATML) approach which dynamically adjusts the attention to different tasks according to the measured training effect of base learners. Additionally, a dynamic-weight based fusion strategy is proposed to further improve the accuracy of federated learning, which dynamically selects clients based on the accuracy of each local model. Experiments show that with as few as five shots, our approach outperforms the original federated meta-learning algorithm in accuracy and speed. Compared with each hospital's local model, the proposed model's average prediction accuracy increased by 13.28%.

Accurate clinical prediction models can help clinical decision makers identify potential risk at an early stage, so that appropriate actions can be taken in time along the pathway of care delivery. However, traditional disease prediction models cannot ensure prediction accuracy with only a small amount of medical data, as is the case for a rare disease.
This problem is significant, as about 7,000 rare diseases afflict millions of people worldwide [1]. Furthermore, because of data privacy concerns, hospitals and patients remain reluctant to share health information. Only a small sample of data is therefore available for predicting rare diseases, under the constraint that each hospital stores its diagnostic records in its own data vault (server), as shown in Fig. 1. This scarce and constrained environment makes it difficult to predict rare diseases with neural networks, which require far more data for training.

There are two main approaches to disease prediction: 1) time-series regression prediction [2] [3] [4] and 2) classifier-based prediction [5] [6] [7]. The first approach models and analyzes the temporal dependencies between medical features to predict the risk probability of a disease. It maps medical features into N-dimensional vectors by leveraging encoding techniques from natural language processing; a time-series network then computes the corresponding hidden states from each input medical feature to make predictions [2]. Classifier-based prediction, in contrast, trains a supervised classifier on medical datasets. However, disease prediction often requires large amounts of medical data to train such models. A data feature fusion approach [5] has been proposed to improve the prediction of heart disease categories; it enriches data features by fusing electronic medical data with sensor data. This approach fuses multi-source data only within a single hospital, which does not fundamentally address the problems of multiple small samples and data privacy. In addition, imbalance among disease categories is common in most medical scenarios [6], which leads to low prediction accuracy for rare diseases.
In short, machine learning approaches face two main challenges when applied to rare disease prediction: 1) class imbalance makes it difficult to ensure the accuracy and generalizability of the model for unseen diseases; and 2) the rarity of the diseases, together with privacy concerns, leaves too little medical data to train models. The complexity of medical data further hampers disease identification.

Human physicians, however, can recognize a rare disease through clinical experience. A model can likewise distinguish unseen classes if it is capable of "learning to learn", which is known as meta-learning. Specifically, MAML (Model-Agnostic Meta-Learning) [8] algorithms can train a network with small samples of common categories to recognize other, unseen categories. By moving models rather than data, federated learning [9] realizes data fusion while protecting information security and privacy, which further improves model performance. Integrating meta-learning and federated learning results in a federated meta-learning framework [10], which effectively improves the model's classification accuracy on unseen classes.

To address these challenges, and inspired by the framework in [10], we utilize federated meta-learning to predict rare diseases and apply two additional techniques to bolster its performance. We propose a novel dynamic-weight based fusion strategy and an attention mechanism based on federated meta-learning (DWA-FML) for rare disease prediction. In particular, we propose an attention-based meta-learning (ATML) approach built on the MAML architecture, which pays more attention to tasks with weaker training effects in order to improve the identification of unseen diseases.
Moreover, to achieve effective data fusion while preserving data privacy, we design a dynamic-weight based fusion strategy within federated learning. It enables the meta-learning model to learn the characteristics of more categories of common diseases, which further improves the prediction accuracy for rare diseases. Specifically, each hospital (client) evaluates its local model and uploads it to a local server only when the newly evaluated model performs better than the previous version of the global model. After local models are uploaded to each hospital's local server, a central server, external to the hospitals, accesses the local servers to calculate the accuracy of each local model, which is used as the weight for updating the global model. Data privacy is completely preserved in this approach, because sensitive local medical data are neither exposed nor made accessible to the central server. Experimental results show that our proposed approach achieves higher accuracy in less time (requiring only a few shots) than the original federated meta-learning approach. In fact, ATML achieves significantly higher performance than other meta-learning models on both a medical image dataset and our dataset.

The primary contributions of this paper are as follows:
1) A novel machine learning approach for rare disease prediction. To the best of our knowledge, this is the first work to study rare disease prediction based on federated meta-learning.
2) An attention-based meta-learning approach to improve the performance of the prediction model for rare diseases. As we show through experimental evaluations, our approach performs better than other meta-learning approaches.
3) A dynamic-weight based fusion strategy that enables privacy-preserving federation of the meta-learning architecture and meets participating hospitals' data privacy policies.
4) A comprehensive and detailed evaluation of the proposed approach using rare disease datasets from actual hospitals.
Results show that our federated meta-learning model outperforms the original federated meta-learning in terms of accuracy and speed of prediction.

The remainder of the paper is organized as follows. Section II discusses related work. Section III presents the approach for rare disease prediction in detail. Section IV evaluates the approach. Section V concludes the paper and points out important future work.

Meta-learning for disease prediction. Meta-learning was proposed to address the weak adaptation of traditional neural networks to new tasks [11]; it includes metric learning [12], [13] and model-agnostic meta-learning (MAML) [8]. The former has been used to learn a kernel-based regularization approach for blood glucose prediction [14]. The latter can identify unseen classes with only a few samples of common classes and has shown promising results [15], so it may be applied to medical data analysis. Specifically, in training, a support set consisting of a small number of samples from different classes is used to train the network, and a query set is then used to evaluate the loss of the model trained on the support set. During testing, the meta-learner can quickly adapt to new tasks with a few fine-tuning steps. Because medical samples are limited, meta-learning has been employed for clinical risk prediction [7]: a model-agnostic gradient descent framework trains a meta-learner on a set of prediction tasks whose target clinical risks are highly relevant. However, the potential relationships between source and target domains are usually unknown. Moreover, there is little data in the source domain, and sometimes no suitable source domain can be found at all. The ability to identify unseen classes has also been used for rare skin disease diagnosis [16], which utilizes image augmentation techniques and a difficulty-aware meta-learning algorithm that scales the loss to improve disease classification performance.
The idea of a difficulty-aware meta-learning algorithm is worth adopting and can be further improved.

Federated learning for disease diagnosis. Federated learning trains a shared global model on a central server using only updates of local models that are stored and maintained in a federation of local sites such as hospitals or other institutions. Sensitive data in local institutions are used to perform model learning [17] but are neither accessible nor visible to the global model and its central server. Specifically, each hospital (client) uses its own clinical data to train its local model, without any data sharing between hospitals. Federated learning has in fact become the most effective approach to privacy-preserving data fusion in health care [18]. The trained local models are aggregated in a central server with weights that reflect each hospital's contribution to the global model update, measured by the size of the data used to update the hospital's local model [9]. The updated model is then dispatched to each client for the next round of training. A federated uncertainty-aware learning method has been proposed to improve preterm birth prediction by reducing the contribution of models with high uncertainty to the aggregated model [19], but the best uncertainty evaluation criterion is difficult to determine. A dynamic fusion strategy has been applied in federated learning for COVID-19 detection [20]: in each episode, a client whose model performs better than its previous version is selected for average fusion. However, average fusion may not be the most effective fusion approach.

Federated Meta-Learning. The federated meta-learning framework was first proposed by Huawei in 2018 [10]. It focuses on transmitting a parameterized algorithm between mobile devices and central servers for faster convergence and higher accuracy.
Unlike in federated learning, each client uses the parameters distributed by the server to initialize and train its local model. The local models of different clients are uploaded to a central server to update the global model. Considering that federated learning focuses only on a common output for all users, MIT [21] proposes a theoretical understanding of federated meta-learning for personalized learning. Because of data privacy and the scarcity of credit card fraud samples, federated meta-learning has been used to detect credit card fraud with strong detection performance [22]. However, the metric-learning effects of federated meta-learning may differ across tasks.

In this section, we elaborate on the proposed federated meta-learning-based approach for rare disease prediction. We first present attention-based meta-learning (ATML) for rare disease identification. Then we describe how a dynamic-weight based fusion strategy integrates the ATML model into federated meta-learning. Each client uses the meta-learning approach to train models, so "local meta-models" is used in the following description to refer to local models.

The system model of the proposed federated meta-learning approach is shown in Fig. 2. TABLE I summarizes the strengths of both federated learning and meta-learning; combining them is an attractive way to resolve the problems mentioned above. In federated disease prediction, suppose there are $n$ hospitals $h_i$ ($i = 1, 2, \dots, n$) acting as clients. At each iteration step $t$ ($t = 1, 2, \dots, T$) on the server, the local meta-models $\theta_h$ from the hospitals are aggregated to update the global model $\theta$. Let the accuracies of the global model and of the local meta-model on the test data be $A^t$ and $A_h^t$, respectively, in the $t$-th iteration step. A learning rate controls the training speed of a model. Each hospital $h$ employs the meta-learning approach for rare disease prediction on its own private medical dataset $D_h = \{x_h, y_h\}$ ($h = 1, 2, \dots$), where $x_h$ is a disease sample and $y_h$ is the corresponding label.

In meta-model training, tasks are randomly sampled from the set of common diseases, each consisting of a support set and a query set. Each local meta-model is initialized with the parameters $\theta$ from the server and then trained on the support set. The loss $L(f(x_h), y_h)$ is evaluated as the error of the base-model (the base-learner in meta-learning). The learning rate of the inner loop is denoted $\alpha$ and that of the outer loop $\beta$. In meta-testing, rare diseases form the test sets, which also consist of support sets and query sets. The support set is used to fine-tune the trained local meta-model $\theta_h$, whose performance is then assessed on the query set. Hence, we exploit federated meta-learning to predict unseen diseases with the learning experience gained from common diseases, which can improve classification accuracy while protecting data privacy.

The specific process of our approach is shown in Algorithm 1 and can be summarized in seven steps.
Step 1: In federated learning, at each iteration step, each local meta-model (hospital) whose accuracy is better than that of the global model (server) from the previous round is dynamically selected and uploaded to the server.
Step 2: During model fusion, dynamic-weight fusion is performed according to the accuracy of the uploaded local meta-models to update the global model (lines 9-11).
Step 3: The parameters are downloaded from the central server to each hospital for model initialization, in order to train specific tasks.
Step 4: Instead of cross-entropy loss, the focal loss is used to evaluate each task error in the inner loop.
Step 5: The attention-based loss is employed to update the parameters in the outer loop (lines 20-21).
Step 6: The trained local meta-model $\theta_h'$ of each hospital is prepared for uploading (returned) to the central server.
Step 7: Return to the first step until the model converges.

B. Attention-based meta-learning for disease prediction
The MAML [8] model pays equal attention to all tasks.
However, in practice, it is common for simple tasks to reach good classification accuracy while difficult tasks reach poor accuracy. To address this challenge, we propose an attention-based meta-learning approach for each hospital $h$, as shown in Algorithm 1. In each episode, a series of tasks is randomly sampled according to a task distribution $p(\mathcal{T})$; each task consists of a support set and a query set. The initialization parameters $\theta$ are dispatched to each hospital (client) by the server. Then $\theta$ is updated to $\theta_h$ by gradient descent on the support set in the "adaptation steps" (inner loop). We use the focal loss [23], $L_{FL}(f_{\theta}(x_h), y_h)$ for a sample $x_h$ with label $y_h$, to improve the learning ability of the model. The focal loss function (Eq. 1, Eq. 2) and gradient descent (Eq. 3) are defined as:

$p_t = p$ if $y_h = 1$, and $p_t = 1 - p$ otherwise (1)
$L_{FL} = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$ (2)
$\theta_h = \theta - \alpha \nabla_{\theta} L_{FL}(f_{\theta}(x_h), y_h)$ (3)

where $p$ is the predicted probability of the class labeled 1, $\alpha_t$ is the weighting factor of the balanced variant of focal loss, $\gamma$ is the focusing parameter satisfying $\gamma \ge 0$, $\alpha$ is the inner-loop learning rate, and $\theta_h$ denotes the base-model trained on the support set.

In the "evaluation steps" (outer loop), the base-model $\theta_h$ is evaluated on the query set, and the resulting loss is used to update the base-model. We propose that lower-accuracy tasks (complex tasks) should contribute more to the loss, so an attention-based loss function is designed to further improve the accuracy of rare disease prediction. In the attention-based meta optimization function, $\varphi$ is a scaling factor that regulates the model's attention to difficult tasks.

We use a Transformer (with two layers) as the main network architecture of ATML (attention-based meta-learning). Specifically, the encoder of the Transformer extracts features into an m-dimensional vector. To alleviate the negative impact of excessive differences between features, batch normalization converts the original data into a b-dimensional vector. The two vectors are then concatenated into an (m + b)-dimensional vector that is fed into the fully connected layer. Because having fewer training classes leads to weak sample diversity, we apply an Adam optimizer and set the weight decay to 0.1 to improve the test performance of the meta-model.
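To make the role of the focal loss in the inner loop more concrete, the following is a minimal sketch of the standard binary focal loss from [23], compared with plain cross entropy. It is an illustration under stated assumptions, not the authors' implementation: the function names `focal_loss` and `cross_entropy` are hypothetical, and the defaults `alpha=0.25` and `gamma=2.0` are the common defaults of the focal loss literature rather than values confirmed by this paper.

```python
import math

def focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-12):
    """Binary focal loss for one sample.

    p is the predicted probability of class 1 and y is the true label (0 or 1).
    alpha and gamma are illustrative defaults, not values confirmed by the paper.
    """
    p_t = p if y == 1 else 1.0 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-balancing weight
    p_t = min(max(p_t, eps), 1.0)               # guard against log(0)
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def cross_entropy(p, y, eps=1e-12):
    """Plain binary cross entropy for one sample, for comparison."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(min(max(p_t, eps), 1.0))

# A well-classified and a poorly classified sample of class 1.
for p in (0.95, 0.30):
    ce = cross_entropy(p, 1)
    fl = focal_loss(p, 1)
    print(f"p={p:.2f}  CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.6f}")
```

The ratio FL/CE equals the modulating term `alpha_t * (1 - p_t) ** gamma`, so a confidently correct prediction is down-weighted far more strongly than a poor one. This is the property that lets hard samples dominate the inner-loop gradient.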
The learning rate is set to 0.001 to perform disease classification.

In federated meta-learning, our aim is for the hospitals (clients) to "share" their private datasets, so that an effective rare disease prediction solution can be built by increasing the number of common disease categories. Each hospital $h$ downloads the parameters $\theta$ from the server for initialization, trains the model on the support set, and then tests the base-model $\theta_h$ on the query set to update the parameters from $\theta_h$ to $\theta_h'$. It then sends the meta-model $\theta_h'$ to the server. The server aggregates the local meta-models to update the global model and dispatches the new initialization parameters to each hospital.

A dynamic-weight based fusion strategy is proposed to improve the prediction result. In particular, all local meta-models are uploaded and average-fused in the first round of aggregation. In each subsequent round $t$, the accuracy $A_h^t$ of each local meta-model on the test set is compared with the accuracy $A^{t-1}$ of the global model from the previous round. Local meta-models that perform better are selected for aggregation; the other hospitals are filtered out in this round. In addition, a fusion weight $w_h^t$ is used in the aggregation process. Specifically, $w_h^t$ is the ratio of the current model's accuracy $A_h^t$ (for a selected hospital $h$) to the sum of the accuracies of all $m$ selected local meta-models: $w_h^t = A_h^t / \sum_{j=1}^{m} A_{h_j}^t$. The details of the training process of the rare disease prediction model are illustrated in Fig. 2 and described in Algorithm 1.

In this section, we evaluate the performance of different models on the Arrhythmia dataset [24]. First, we compare with related models to show the effectiveness of the proposed approach. Then, we conduct ablation experiments to explore which improvement contributes most to the prediction results. We also discuss the effect of different iteration steps on each meta-model.

[Table II. Disease classes used in the experiments; surviving entry: Left bundle branch block, 9.]

The Arrhythmia dataset [24] contains 279-dimensional features.
We use nine classes of diseases, as shown in Table II. Additionally, a medical image dataset [25] is used to demonstrate the effectiveness of our approach. We consider the first five classes, which have large samples, as common diseases (the meta-train dataset) and the other classes as rare diseases (the meta-test dataset). Each task is randomly sampled according to the distribution $p(\mathcal{T})$ from the meta-train and meta-test datasets. In meta-training, each task is a binary classification consisting of two random classes with $K$ samples per class in the support set. The query set has twice as many samples as the support set. Experiments showed that the model converges after 1300 rounds of iteration, so we set the number of iterations to 1500 and applied the attention-based optimization approach. In particular, we set the parameters that =5, = 2, = 2 via experiments. The number of tasks is 10, and the number of iterations per adaptation step is 5. Similarly, in meta-testing, two classes are randomly selected from the meta-test dataset, and $K$ samples of each class are sampled. The final experimental results are the average accuracy over 30 runs. In addition, experiments are performed on the medical image dataset to demonstrate the effectiveness of the attention-based loss in meta-learning, evaluated with AUC as in Li [16].

We compared several meta-learning models, including Relation Net, MAML, and DAML. As shown in TABLE III, in all sample settings, the proposed attention-based meta-learning approach performs better than the other three approaches on both datasets, especially compared with Relation Net. Because the number of training classes is limited, the metric-learning model cannot extract sufficiently diverse features, which is one of the most important reasons for its inferior results. In addition, our approach exploits the MAML architecture and improves the loss function of DAML, which indicates the effectiveness of attention-based meta optimization in rare disease prediction.

Some widely used algorithms are selected for comparison experiments.
Since these baselines have no fine-tuning process, they were trained directly on K shots per rare disease class. Both the LSTM and the Transformer have two layers, while the MLP consists of four linear layers. The number of iteration steps is five, and the other settings are the same as for meta-learning in Section IV. As shown in TABLE IV, the Transformer has a stronger representation capability for feature extraction and does achieve good performance. However, compared with ATML, these classical algorithms are not able to identify rare diseases. In addition, ATML with one shot is much better than the others with five shots, indicating the good performance of our approach for rare disease prediction.

In practical scenarios, it is impossible for a hospital to possess all disease categories for model training. To simulate the actual rare disease prediction situation in each hospital, we randomly selected three of the original five common disease classes as a training support set. Specifically, each local meta-model is trained on three classes of common diseases to predict four classes of rare diseases; this setting is denoted ATML3 to allow a fair comparison with DWA-FML. In addition, classical federated learning with average fusion, where the Transformer is used as the feature extractor, is denoted FedAvg for a comparison experiment. Similarly, the integration of FedAvg with MAML is denoted FedAvg(MAML), and the integration of FedAvg with ATML3 is denoted FedAvg(ATML3). As shown in TABLE VI, DWA-FML achieves significantly higher accuracy than the other methods. FedAvg is trained with five classes, yet its results are even worse than those of ATML3 trained with three classes. The reason is that the Transformer has no capability to identify unseen classes. The results of ATML are much better than those of MAML in the experiment of Section IV B(a), while the performance of FedAvg(MAML) is a little better than that of ATML3.
This comparison shows that it is essential for the meta-model to be trained on more common disease categories. Compared with FedAvg(ATML3), DWA-FML is more effective thanks to the dynamic-weight based fusion strategy.

We now explore the impact of relevant parameters on DWA-FML. Fig. 3 shows the prediction results of each hospital's local meta-model over the fusion rounds. The accuracy continues to rise for the first 150 steps and then gradually stabilizes. Since each meta-model is iterated for five rounds locally, our model needs only 750 rounds to converge. Compared with the meta-learning method without federated learning, which converged in 1300 rounds, our model learns faster. This indicates that the proposed federated learning improves the prediction accuracy of rare diseases and accelerates model training. In addition, all hospitals achieve their best results almost simultaneously, which further indicates the effectiveness and generalizability of our approach. As shown in TABLE VII, the performance of FedAcc (federated learning with our accuracy-based fusion in place of average fusion) is slightly better than that of ATML3. This indicates the effectiveness of our fusion strategy.

We further explore the effects of the fusion strategy on DWA-FML. The accuracy of FedAcc(MAML) is 2.14% higher than that of FedAvg(MAML), and that of DWA-FML is 1.57% higher than that of FedAvg(ATML3). The accuracy of FedAvg(ATML3) is 4.09% higher than that of FedAvg(MAML), and that of DWA-FML is 3.52% higher than that of FedAcc(MAML). Therefore, although the fusion strategy improves prediction performance, the attention-based loss function makes the larger contribution to the proposed approach.

Because fast adaptation to new tasks is necessary, we vary the number of fine-tuning steps from one to ten in our experiments. As shown in Fig. 4, performance increases rapidly in the first five fine-tuning steps and gradually levels off afterwards. The experimental results further illustrate that our approach has good potential for rare disease prediction.
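The ablation figures above can be cross-checked with a few lines of arithmetic. The sketch below assumes that the quoted gaps are absolute differences in percentage points (the text does not say whether they are absolute or relative); all variable names are hypothetical. It verifies that the two routes from FedAvg(MAML) to DWA-FML give the same total gain, compares the average gain from the fusion strategy with that from the attention-based loss, and restates the convergence comparison.

```python
# Accuracy gaps quoted in the ablation, read as percentage points (an assumption).
fusion_gain_with_maml = 2.14   # FedAcc(MAML) over FedAvg(MAML)
fusion_gain_with_atml = 1.57   # DWA-FML over FedAvg(ATML3)
loss_gain_with_fedavg = 4.09   # FedAvg(ATML3) over FedAvg(MAML)
loss_gain_with_fedacc = 3.52   # DWA-FML over FedAcc(MAML)

# Two routes from FedAvg(MAML) to DWA-FML should give the same total gain.
total_via_loss_first = loss_gain_with_fedavg + fusion_gain_with_atml
total_via_fusion_first = fusion_gain_with_maml + loss_gain_with_fedacc

# Average gain attributable to each component.
mean_fusion_gain = (fusion_gain_with_maml + fusion_gain_with_atml) / 2
mean_loss_gain = (loss_gain_with_fedavg + loss_gain_with_fedacc) / 2

# Convergence: about 150 fusion steps, each with five local iterations.
federated_rounds = 150 * 5
standalone_rounds = 1300

print(round(total_via_loss_first, 2), round(total_via_fusion_first, 2))
print(round(mean_fusion_gain, 3), round(mean_loss_gain, 3))
print(federated_rounds, round(federated_rounds / standalone_rounds, 3))
```

Both routes sum to 5.66, so the four reported gaps are mutually consistent under this reading, and the average gain from the attention-based loss is roughly twice that from the fusion strategy, in line with the conclusion drawn in the text.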
Fig. 5 shows that most hospitals achieve their best results at the fifth training step. However, the models' performance degrades in the subsequent training steps, because repeated training on common diseases may lead to over-fitting. In addition, compared with other clients, Hospital2 requires more training steps (eight) to achieve its best results, implying that the prediction performance of the meta-model in Hospital2 is weak. Furthermore, we find that Hospital4's prediction accuracy is higher than that of the other hospitals in both the fine-tuning (Fig. 4) and training processes, indicating that its meta-model is easier to train and more stable. Because of the dynamic-weight based fusion strategy, we believe that the local meta-model of the fourth hospital contributes more to the model aggregation on the server.

As shown in Table VIII, the time consumption of the federated learning model is much lower than that of federated meta-learning, because each local model (Transformer) is trained directly on rare diseases and then used to predict rare diseases. This process can be seen as the testing part of meta-learning. Since they differ only in the loss function, ATML and MAML have similar time consumption, which is in line with our expectations. In addition, our proposed fusion strategy reduces time consumption, as shown by comparing FedAcc(MAML) with FedAvg(MAML) and DWA-FML with FedAvg(MAML), respectively. Each local meta-model is uploaded only when it performs better than the previous version of the global model, so the time saved is model upload time. The small difference in time consumption between FedAvg(Transformer) and FedAcc(Transformer) indicates that the global model's accuracy is too weak, so the local model is uploaded almost every round. In general, the proposed approach reduces time consumption significantly compared with the original federated meta-learning.
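The selection and weighting rule discussed above (upload only when a local meta-model beats the previous global model, then weight each uploaded model by its share of the total accuracy) can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code: the names `fusion_weights` and `fuse` are hypothetical, model parameters are reduced to short lists of floats, the accuracies are made-up toy values, and keeping the previous global model when no hospital qualifies is an assumption, since the text does not cover that case. The plain average used in the first aggregation round is omitted.

```python
def fusion_weights(local_acc, prev_global_acc):
    """Return accuracy-proportional weights for hospitals that beat the previous global model."""
    selected = {h: acc for h, acc in local_acc.items() if acc > prev_global_acc}
    total = sum(selected.values())
    return {h: acc / total for h, acc in selected.items()}

def fuse(global_params, local_params, weights):
    """Weighted sum of the selected local parameter vectors.

    If no hospital was selected, the previous global parameters are kept
    (an assumption; the text does not cover this case).
    """
    if not weights:
        return list(global_params)
    fused = [0.0] * len(global_params)
    for h, w in weights.items():
        for k, value in enumerate(local_params[h]):
            fused[k] += w * value
    return fused

# Toy values, purely illustrative.
local_acc = {"Hospital1": 0.70, "Hospital2": 0.62, "Hospital3": 0.68, "Hospital4": 0.75}
local_params = {
    "Hospital1": [0.10, 0.20],
    "Hospital2": [0.90, 0.90],
    "Hospital3": [0.30, 0.10],
    "Hospital4": [0.20, 0.40],
}
prev_global_acc = 0.65
global_params = [0.25, 0.25]

weights = fusion_weights(local_acc, prev_global_acc)
new_global = fuse(global_params, local_params, weights)
print(weights)
print(new_global)
```

In this toy round, the hospital whose accuracy falls below the previous global accuracy is filtered out and uploads nothing, while the most accurate hospital receives the largest weight. That mirrors the observation that a consistently accurate client contributes more to the aggregation and that skipped uploads save communication time.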
This paper proposes a novel and effective rare disease prediction approach based on federated meta-learning. First, we present a novel federated meta-learning-based approach for rare disease prediction. Second, we propose an attention-based meta-learning approach that enhances the model's attention to difficult tasks. Third, we design a dynamic-weight based fusion strategy in which each client decides whether its local meta-model participates according to its performance. The evaluation results show that the proposed approach achieves higher classification accuracy across hospitals than its counterparts. In addition, our model achieves better performance and lower time consumption than the original federated meta-learning. In future work, we will explore the communication efficiency of our approach and investigate how to use blockchain technology to further improve data security and privacy.

RDAD: a machine learning system to support phenotype-based rare disease diagnosis
Using recurrent neural network models for early detection of heart failure onset
Time series forecasting of COVID-19 transmission in Canada using LSTM networks
Evaluation of deep learning approaches for identification of different corona-virus species and time series prediction
A smart healthcare monitoring system for heart disease prediction based on ensemble deep learning and feature fusion
Efficient treatment of outliers and class imbalance for diabetes prediction
Metapred: Meta-learning for clinical risk prediction with limited patient electronic health records
Model-agnostic meta-learning for fast adaptation of deep networks
Federated learning of deep networks using model averaging
Federated meta-learning with fast convergence and efficient communication
Matching networks for one shot learning
Learning to compare: Relation network for few-shot learning
A meta-learning approach to the regularized learning-Case study: Blood glucose prediction
Meta-SGD: Learning to learn quickly for few-shot learning
Difficulty-aware meta-learning for rare disease diagnosis
Federated learning for healthcare informatics
The future of digital health with federated learning
Federated uncertainty-aware learning for distributed hospital EHR data
Dynamic fusion-based federated learning for COVID-19 detection
Personalized federated learning: A meta-learning approach
Federated Meta-Learning for Fraudulent Credit Card Detection
Focal loss for dense object detection
The impact of the MIT-BIH Arrhythmia Database
Rare diseases, genomics and public health: an expanding intersection