key: cord-0699333-0s5ze3w7 authors: Sun, Chenxi; Dui, Hongna; Li, Hongyan title: Interpretable time-aware and co-occurrence-aware network for medical prediction date: 2021-11-02 journal: BMC Med Inform Decis Mak DOI: 10.1186/s12911-021-01662-z sha: ff53c873ea6818fc0cfaea75dca9f4a6d23b13d7 doc_id: 699333 cord_uid: 0s5ze3w7 BACKGROUND: Disease prediction based on electronic health records (EHRs) is essential for personalized healthcare. However, it is difficult because of the special structure of EHR data and the requirement that methods be interpretable. EHR data are hierarchical: each patient has a sequence of admissions, and each admission contains a set of co-occurring diagnoses. Existing methods model these characteristics only partially and offer little interpretation for non-specialists. METHODS: This work proposes a time-aware and co-occurrence-aware deep learning network (TCoN). TCoN fits the EHR data structure and is also interpretable. The co-occurrence-aware self-attention (CS-attention) mechanism and the time-aware gated recurrent unit (T-GRU) model multilevel relations, and the interpretation path and the diagnosis graph make the results interpretable. RESULTS: The method is evaluated on a real-world dataset for mortality prediction, readmission prediction, disease prediction, and next-diagnoses prediction. TCoN outperforms the baselines with 2.01% higher accuracy. In addition, the method provides an interpretation of causal relationships and a diagnosis graph for each patient. CONCLUSIONS: This work proposes a novel model, TCoN. It is an interpretable and effective deep learning method that models the hierarchical medical structure and predicts medical events. The experiments show that it outperforms all compared state-of-the-art methods. Future work can apply graph embedding techniques to additional knowledge sources such as doctor notes.
Open Access. *Correspondence: sun_chenxi@pku.edu.cn. 1 School of Electronics Engineering and Computer Science, Peking University, No. 5 Yiheyuan Road, Beijing 100871, People's Republic of China. Full list of author information is available at the end of the article.
Electronic Health Records (EHRs) are increasingly popular and widely used in hospitals for better healthcare management. A typical EHR dataset contains a wealth of patient information, including demographic and medical information. The medical information has an irregular, hierarchical patient-visit-code (patient-admission-diagnosis) form, shown in Fig. 1a. (1) Each patient has many visit records, because he/she may see a doctor many times. The visit records have corresponding time stamps and form a sequence. (2) Each visit contains many codes, which are usually disease diagnoses. These codes co-occur without an order. For example, a patient record may list chronic kidney disease after a cold. We still cannot conclude that the patient did not have chronic kidney disease before catching the cold, so the time relation between the two diagnoses is uncertain. We call such relations, for example complication, causation, and continuity, co-occurrence relations. Thus, EHR data have both a time relation and a co-occurrence relation. Medical tasks such as disease prediction [1] [2] [3], concept representation [4, 5], and patient typing [6] [7] [8] are essential for personalized healthcare and medical research. Nevertheless, these tasks are challenging for physicians because of the complex patient states, the large number of diagnoses, and the real-time requirement. A data-driven approach that learns from large, accessible EHRs is therefore desirable. In recent years, Deep Learning (DL) models have made remarkable achievements thanks to their strong learning ability and flexible architectures [9] [10] [11] [12] [13]. Some DL methods can model the sequential time relation of medical data.
For example, RETAIN [3] uses the gated recurrent unit (GRU) [14, 15] to predict medical events. Dipole [1] uses a bidirectional RNN (BRNN) [16] to integrate information from the past and the future. T-LSTM [8, 17] injects a time-decay effect to handle irregular time intervals. These methods model the EHR structure as in Fig. 1b. Other DL methods model the co-occurrence relation of medical data. For example, Word2Vec [18, 19], Med2Vec [4], and MiME [5] model medical relations to represent the original data better, following the idea of representation learning [20] [21] [22] [23]. These methods model the EHR structure as in Fig. 1c. However, no existing method models both relations simultaneously, because the two relations conflict. The time relation arranges data longitudinally, whereas the co-occurrence relation arranges data like a bipartite graph. Considering both relations gives the EHR structure shown in Fig. 1d. Meanwhile, in real-world applications, data-driven methods must be interpretable to be useful to doctors [24] [25] [26]. However, DL methods are black-box models that suffer from poor interpretability [27] [28] [29] [30] [31] [32]. To address these issues, we define EHR as a hierarchical co-occurrence sequence and propose a novel model called the Time-aware and Co-occurrence-aware Network (TCoN). TCoN models the two relations simultaneously and also provides interpretation. It uses a pre-train and fine-tune mechanism to handle imbalanced data and is more accurate than all baselines in medical prediction tasks. In this section, we first introduce the MIMIC-III dataset and the data preprocessing process. Then, we describe the proposed method in detail.
MIMIC-III is a freely accessible, de-identified medical dataset developed and maintained by the Massachusetts Institute of Technology Laboratory for Computational Physiology [33]. From MIMIC-III, we extract data to form three datasets. First, we extract the records of patients with more than one visit. The resulting dataset comprises 19,993 hospital admissions of 7537 patients and 260,326 diagnoses with 4,893 unique codes. The codes are defined by the International Classification of Diseases, 9th revision (ICD-9). Each patient has 2.66 visits on average. Each visit has 13.02 codes on average and at most 39. Second, following the latest Sepsis-3 definition [34], we extract 1232 sepsis patients whose SOFA score is greater than or equal to 2. Third, based on ICD-9 codes, we extract 1608 heart failure patients with diagnosis codes 428.x. In the sepsis and heart failure datasets, the extracted data are the records in which these diagnoses appear for the first time. Both datasets are imbalanced. Detailed statistics are shown in Table 1.
Fig. 1 The data structure of EHR under different methods. a Original EHR data structure. b EHR data structure based on the time relation. c EHR data structure based on the co-occurrence relation. d Data relation under our TCoN model. Data form b arranges codes in a random order, but different orders affect the results differently. For example, the sequence 'heart disease -> influenza -> coronary' implies a closer relation between 'heart disease' and 'influenza' than the sequence 'heart disease -> coronary -> influenza'. Data form c gives every pair of codes an equal relation. This fails if 'heart disease', 'atrial fibrillation', and 'diabetes' occur in three different visits, because the time intervals among them differ. Data form d combines both.
In addition, demographic information $I$ is recorded for patients $P$.
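The cohort construction above (keep multi-visit patients, report visits per patient and codes per visit, flag 428.x heart failure codes) can be sketched in a few lines. This is a minimal sketch, not the authors' preprocessing code. It assumes diagnoses arrive as flat `(patient_id, admission_id, icd9_code)` tuples. The helper names `build_cohort`, `cohort_stats`, and `has_heart_failure` are hypothetical. The 428 prefix check assumes codes are stored without the decimal point, as in MIMIC-III's diagnosis table (e.g. `4280`).

```python
from collections import defaultdict

def build_cohort(rows):
    """Group codes by patient and admission, keeping only patients with more than one visit.

    rows: iterable of (patient_id, admission_id, icd9_code) tuples (hypothetical layout).
    """
    visits = defaultdict(lambda: defaultdict(list))
    for pid, hadm, code in rows:
        visits[pid][hadm].append(code)
    return {pid: dict(adm) for pid, adm in visits.items() if len(adm) > 1}

def cohort_stats(cohort):
    """Summary statistics of the kind reported in the text (visits per patient, codes per visit)."""
    per_visit = [len(codes) for adm in cohort.values() for codes in adm.values()]
    n_patients, n_visits = len(cohort), len(per_visit)
    return {
        "patients": n_patients,
        "visits": n_visits,
        "visits_per_patient": n_visits / n_patients,
        "codes_per_visit": sum(per_visit) / n_visits,
        "max_codes_per_visit": max(per_visit),
    }

def has_heart_failure(codes):
    """ICD-9 428.x heart failure, assuming dot-less codes such as '4280'."""
    return any(code.startswith("428") for code in codes)
```

For toy rows `[("p1", "a1", "4280"), ("p1", "a1", "25000"), ("p1", "a2", "41401"), ("p2", "b1", "486")]`, only `p1` survives the multi-visit filter. It has 2 visits and 1.5 codes per visit on average.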
Definition 2 (Medical prediction tasks) A medical prediction task uses a set of medical records $R$ to predict a specific target $Y = \{y_1, y_2, \dots, y_n\}$. If $n = 2$, it is a binary classification task. If $n > 2$, it is a multi-class classification task. The prediction task is $f_p: R \rightarrow Y$. Interpretation uses the correlations $\mathcal{R}$ of medical pairs.
Task 1 (Mortality prediction). Predict whether the patient will die during the hospitalization.
Task 2 (Readmission prediction). Predict whether the patient will be hospitalized again.
Task 3 (Disease prediction). Two disease prediction tasks: sepsis and heart failure. Early diagnosis is critical for improving patient outcomes [35].
Task 4 (Next diagnoses prediction). Predict the diagnoses of the patient in the next admission.
Note that Tasks 1, 2, and 3 are binary classification tasks and Task 4 is a multi-class classification task.
The area under the curve of precision (P) versus recall (R) is a better measure for imbalanced data [36].
Evaluation 3 (Accuracy@k). The fraction of positive (correct) predictions among the top-k prediction values. It is the evaluation metric for the multi-class classification task.
As shown in Fig. 2, our TCoN model contains a code block and a visit block. The code block is implemented by co-occurrence-aware self-attention (CS-attention). The visit block is implemented by the time-aware gated recurrent unit (T-GRU). The two blocks are linked by an attention connection. Self-attention [32] in natural language processing captures the semantic and grammatical relations between different words in a sentence. For each input, it computes three vectors: Query (Q), Key (K), and Value (V). The multi-head self-attention is designed as in [32]:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)W^{O}, \quad \mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V})$$
In this work, we redesign self-attention as CS-attention (Eq. 5) to handle the relations among EHR codes. CS-attention has two kinds of heads: the local head and the global head. The local head learns the co-occurrence relations between every two codes in the same visit. Within a visit, a code is affected by the other codes equally.
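The local head described above can be illustrated as standard scaled dot-product attention [32] with a mask that keeps only pairs of codes from the same visit. This is a minimal sketch under stated assumptions, not the authors' Eq. 5. It assumes the codes of all visits are stacked row-wise into a matrix `C`, with a visit index per row. The function names, the inclusion of each code's self-pair, and the random projection matrices in the test are all hypothetical choices. No time weighting is applied inside a visit, which reflects "a code is affected by the other codes equally".

```python
import numpy as np

def masked_softmax_rows(scores, mask):
    """Row-wise softmax restricted to entries where mask is True (masked entries get weight 0)."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                              # exp(-inf) = 0 for masked pairs
    return weights / weights.sum(axis=1, keepdims=True)

def local_head(C, visit_idx, Wq, Wk, Wv):
    """One local head: each code attends only to codes of its own visit (itself included).

    C: (n_codes, d) stacked code embeddings; visit_idx: visit number of each row.
    Returns the new representations and the attention matrix (alpha in the text).
    """
    visit_idx = np.asarray(visit_idx)
    same_visit = visit_idx[:, None] == visit_idx[None, :]
    Q, K, V = C @ Wq, C @ Wk, C @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])   # scaled dot product as in [32]
    A = masked_softmax_rows(scores, same_visit)
    return A @ V, A
```

Because every code belongs to its own visit, each row has at least one unmasked entry, so the softmax is always well defined here.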
The global head learns the co-occurrence relations between every two codes in different visits. Here the effect of one code on another depends on the time interval between their visits. Together, the two kinds of heads learn a new representation $\tilde{C}$ of each code $C$ from its neighbors $C_{nb}$. $C$ is the original matrix of input codes and $\tilde{C}$ is the new representation matrix. $Q_i$, $K_i$, and $V_i$ are as in Eq. (4). $i = 1$ denotes the local head and $i = 2$ the global head. $d_k$ is the dimension of $Q$ and $K$, and $T$ is the time decay function $g(\Delta t)$. The numbers of local heads and global heads are both adjustable. As shown in Eq. 3, T-GRU comprises an update gate $z_t$ and a reset gate $r_t$. These gates control how much of the previous state $h_{t-1}$ is brought into the current state $h_t$ and into the current candidate state $\tilde{h}_t$, respectively. To model time irregularity, we add a time gate $d_t$ (Eq. 6). This gate takes the time interval into account. It controls the information passed from the previous visit to the current visit through the time decay function $g(\Delta t)$, which determines how much of the history state is injected into the current unit. In Eq. 3, $x_t$ is the current input, and $W$, $U$, and $b$ are parameters. The output is the current state $h$. We propose three time decay functions (Eq. 7), where $\Delta t$ is the time interval between two visits and $\alpha$ is the decay rate:
Reciprocal form: $g(\Delta t) = \frac{1}{1 + \alpha \Delta t}$
Logarithmic form: $g(\Delta t) = \frac{1}{\log(e + \alpha \Delta t)}$
Exponential form: $g(\Delta t) = e^{-\alpha \Delta t}$
When $\alpha = 1$, the exponential form is more suitable for small elapsed times and the logarithmic form for large elapsed times. The reciprocal form is a compromise.
Fig. 2 TCoN structure
Between the code block and the visit block, we design the connection method in Eq. 8. $X_{v_i}$ is the $i$th input of visit $v$, $C_i$ is the output matrix whose rows are the codes of the $i$th visit, and $W_\beta$ is a parameter vector. Demographic information $I$ can optionally be included in the input.
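The global head and the three time decay functions of Eq. 7 can be combined into a short sketch. The decay forms follow the formulas above exactly. Everything else is an assumption, because Eq. 5 is not recoverable from this text. This sketch attends only to codes from other visits and then multiplies each attention weight by $g(\Delta t)$ between the two visits. That is one guess at where $T$ enters, not the authors' formula. Names, the input layout (stacked codes plus per-code visit index and per-visit time in years), and the zero output for single-visit patients are hypothetical.

```python
import numpy as np

def time_decay(dt, alpha=1.0, form="log"):
    """Eq. 7 decay forms; each equals 1 at dt = 0 and decreases as dt grows."""
    dt = np.asarray(dt, dtype=float)
    if form == "reciprocal":
        return 1.0 / (1.0 + alpha * dt)
    if form == "log":
        return 1.0 / np.log(np.e + alpha * dt)
    if form == "exp":
        return np.exp(-alpha * dt)
    raise ValueError(f"unknown decay form: {form}")

def global_head(C, visit_idx, visit_time, Wq, Wk, Wv, alpha=1.0, form="log"):
    """Hypothetical global head: attend to codes of *other* visits, damped by g(dt).

    C: (n_codes, d) stacked codes; visit_idx: visit of each row; visit_time: time of each visit.
    """
    visit_idx = np.asarray(visit_idx)
    t = np.asarray(visit_time, dtype=float)[visit_idx]       # time stamp of each code's visit
    other_visit = visit_idx[:, None] != visit_idx[None, :]
    dt = np.abs(t[:, None] - t[None, :])                     # pairwise elapsed time
    Q, K, V = C @ Wq, C @ Wk, C @ Wv
    scores = np.where(other_visit, Q @ K.T / np.sqrt(K.shape[1]), -np.inf)
    row_max = scores.max(axis=1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)   # rows with no other visit
    weights = np.exp(scores - row_max)                       # masked pairs become 0
    totals = weights.sum(axis=1, keepdims=True)
    A = np.divide(weights, totals, out=np.zeros_like(weights), where=totals > 0)
    A = A * time_decay(dt, alpha, form)                      # assumed placement of T
    return A @ V, A
```

With $\alpha = 1$ and $\Delta t = 1$ year, the three forms give about 0.5 (reciprocal), 0.76 (logarithmic), and 0.37 (exponential). The logarithmic form keeps more history at moderate intervals.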
With demographic information, the input takes the concatenated form $X_{v_i} = \mathrm{concat}(\beta^{T} C_i, I_i)$. We also propose a method to interpret TCoN, based on the correlation values among codes, visits, and predictions. It uses the correlations $\mathcal{R}$, which contain two kinds of correlation. The code-code correlation is obtained from $\alpha$ of CS-attention. $\alpha_{ij}$ is the effect of code $j$ on code $i$, and a large $\alpha_{ij}$ means that code $j$ could be a cause, complication, or early symptom of code $i$. The code-visit correlation is obtained from $\tilde{\beta}$ of the attention connection, where a larger $\tilde{\beta}$ indicates a closer relation. The interpretation path is a code sequence obtained by reverse lookup, starting from the prediction result. For a prediction $P$, let the last visit be $v_n$. In $v_n$, we find the code $c_{ni}$ that contributes the most to $v_n$ according to $\tilde{\beta}$. For $c_{ni}$, we find the closest code $c_{(n-1)i}$ in visit $v_{n-1}$ according to the largest $\alpha_{*c_{ni}}$. Similarly, we find $c_{(n-2)i}, c_{(n-3)i}, \dots, c_{1i}$. This yields the path $c_{1i} \rightarrow \cdots \rightarrow c_{ni} \rightarrow P$. The path reads as follows: disease $c_{1i}$ most likely indicates $c_{2i}$, $c_{2i}$ most likely indicates $c_{3i}$, and so on until $c_{(n-1)i}$ most likely indicates $c_{ni}$. Finally, $c_{ni}$ most likely causes $P$. We also apply a training method that enables TCoN to handle imbalanced data [37, 38]. In the pre-train process, we train an auto-encoder network $f_{ae}$ by minimizing the loss in Eq. 9 for unsupervised representation learning. In the fine-tune process, we use the parameters of the encoder layer as the initial parameters of TCoN and train it with the prediction objective in Eq. (10). For TCoN, the input layer is given by Eq. (8), the skip connection by Eq. (12), layer normalization [29] by Eq. (13), and the feed-forward layer by Eq. (14). Self-attention-based algorithms run in parallel, whereas RNN-based algorithms run serially [32]. TCoN contains both structures, connected in series. Thus, the complexity of TCoN is $O(n^2 \cdot d + n \cdot d^2)$, where $d$ is the representation dimension and $n$ is the sequence length.
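The reverse lookup that builds the interpretation path can be sketched directly from the description above. This sketch assumes the trained model already exposes $\tilde{\beta}$ for the codes of the last visit as a dict, and $\alpha$ as a lookup keyed by `(k, i)` code pairs. Following the text's $\alpha_{*i}$ notation, a candidate code `k` for target `i` is scored by `alpha[(k, i)]`. The function name, the dict layouts, and the toy codes `c1` to `c6` are hypothetical. As in the description, the sketch picks exactly one code per earlier visit.

```python
def interpretation_path(visits, beta_last, alpha):
    """Reverse lookup: start at the last visit's most important code and walk back one visit at a time.

    visits: list of visits (oldest first), each a list of code names.
    beta_last: {code: beta_tilde} for the codes of the last visit.
    alpha: {(k, i): value}, read as alpha_{k i}; missing pairs count as 0.
    Returns the path c_1 -> ... -> c_n (the prediction P would follow c_n).
    """
    current = max(visits[-1], key=lambda c: beta_last[c])   # code contributing most to v_n
    path = [current]
    for visit in reversed(visits[:-1]):
        # pick the code k in the earlier visit with the largest alpha_{k, current}
        current = max(visit, key=lambda k: alpha.get((k, current), 0.0))
        path.append(current)
    return list(reversed(path))
```

For example, with visits `[["c1", "c2"], ["c3", "c4"], ["c5", "c6"]]`, `beta_last = {"c5": 0.7, "c6": 0.3}` and `alpha = {("c3", "c5"): 0.9, ("c4", "c5"): 0.2, ("c1", "c3"): 0.1, ("c2", "c3"): 0.8}`, the path is `c2 -> c3 -> c5 -> P`.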
In the complexity of TCoN, $O(n^2 \cdot d)$ is the term for CS-attention, where $n^2$ accounts for operations on every pair of inputs. $O(n \cdot d^2)$ is the term for T-GRU, where $d^2$ accounts for each sequential operation. In our data, the dimensionality $d$ is larger than the sequence length $n$, so the complexity of TCoN is $O(n \cdot d^2)$. For the data, we right-align the time series and use padding and masking to make them equal in length. Each code is represented by a one-hot vector with 4,893 dimensions (the number of ICD-9 codes). The training, validation, and test sets are split in a 0.75:0.1:0.15 ratio. For the model, we set 2 local heads and 2 global heads. We choose the logarithmic time decay with $\alpha = 1$ and one year as the decay unit. We apply the Adam optimizer [39] with $\alpha = 0.001$, $\beta_1 = 0.9$, and $\beta_2 = 0.999$. We use the learning rate decay $\alpha_{current} = \alpha_{initial} \cdot \gamma^{global\_step / decay\_steps}$ with decay rate $\gamma = 0.98$ and decay steps = 2000 [40]. Before the prediction task, we carry out the pre-train step and use early stopping with 5 epochs. We use fivefold cross-validation. The code implementation is publicly available at https://github.com/SCXsunchenxi/MTGRU.
Baselines
• Time-aware methods (RNN-based methods)
• GRU [14]. It uses GRU to embed visits and make the final prediction.
• T-LSTM [8]. It uses an elapsed-time weight to adjust the previous memory in LSTM.
• Co-occurrence-aware methods (Word2Vec-based methods)
• Med2Vec [4]. It applies the skip-gram model and a multi-layer perceptron to obtain representations of codes and visits.
• Dipole [1]. It uses a BRNN with three attention mechanisms to measure the relations among visits for the final prediction.
TCoN predicts more accurately than all baselines. The results of binary classification (mortality, readmission, sepsis, and heart failure) and multi-class classification (next diagnoses) are shown in Table 2 (a, b). The baselines may not match EHR characteristics and model the data features only partially. For example, T-LSTM performs worst because it is not suited to short visit sequences like those in MIMIC-III.
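Several pieces of the setup and evaluation above can be sketched as small helpers:

- right-alignment with padding and masking;
- the learning rate decay formula with $\gamma = 0.98$ and 2000 decay steps;
- the two metrics used in Table 2.

These are illustrative sketches, not the authors' code, and all function names are hypothetical. `accuracy_at_k` follows the text's wording ("the fraction of positive predictions among the top-k prediction values"), which is one possible reading of Accuracy@k. `average_precision` is a common step-wise estimator of the area under the precision-recall curve. The text does not say which estimator the authors used.

```python
def right_align(seqs, pad=0):
    """Left-pad sequences so the most recent visits line up; mask marks real entries with 1."""
    max_len = max(len(s) for s in seqs)
    padded = [[pad] * (max_len - len(s)) + list(s) for s in seqs]
    mask = [[0] * (max_len - len(s)) + [1] * len(s) for s in seqs]
    return padded, mask

def decayed_lr(global_step, initial_lr=0.001, decay_rate=0.98, decay_steps=2000):
    """alpha_current = alpha_initial * gamma ** (global_step / decay_steps)."""
    return initial_lr * decay_rate ** (global_step / decay_steps)

def accuracy_at_k(scores, true_codes, k):
    """Fraction of the k highest-scoring code indices that are in the true next-visit codes."""
    top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return sum(1 for i in top_k if i in true_codes) / k

def average_precision(scores, labels):
    """Step-wise area under the precision-recall curve (ties broken by input order)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    n_pos = sum(labels)
    hits, total = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            total += hits / rank  # precision at each recall step
    return total / n_pos if n_pos else 0.0
```

For example, `decayed_lr(2000)` gives 0.00098. Ranking four patients as `[1, 0, 1, 0]` (positives at ranks 1 and 3) gives an average precision of (1 + 2/3) / 2 = 5/6.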
TCoN performs well on imbalanced datasets. In the binary classification tasks, all datasets are imbalanced, especially the sepsis dataset (6.16%). The results show that the more imbalanced the data, the greater TCoN's advantage over the baselines. TCoN can also accurately predict multiple diagnoses in the next admission. In the multi-class classification task, we evaluate the methods with k = 5, 15, 25, and 35. As shown in Table 2b, the accuracies of all methods decrease as k increases, but the advantage of our approach remains clear. We also vary the dimension of the representation vector in the hidden layers. The results in Fig. 3a show that TCoN performs better than the other methods under all dimensions. We then set different numbers of heads for TCoN. Figure 3b shows that two heads is the key turning point. We choose patient 32,790 in MIMIC-III to illustrate how TCoN produces the interpretation path. He was a white man with 3 admission records who died at age 80. Figure 4a is the heat map of $\alpha$ for the death prediction. The diagnosis 'hypoxemia' contributes the most to the last admission, because the norm of its weighted vector is the largest. For 'hypoxemia', the closest diagnosis is 'pulmonary collapse', with the largest $\alpha_{*i} = 0.892$. For 'pulmonary collapse', the closest diagnosis is 'unspecified pleural effusion', with the largest $\alpha_{*i} = 0.803$. For 'unspecified pleural effusion', the closest diagnosis is 'unspecified sleep apnea', with the largest $\alpha_{*i} = 0.782$. This yields the interpretation path 'unspecified sleep apnea -> unspecified pleural effusion -> pulmonary collapse -> hypoxemia -> death', shown in Fig. 4b. Figure 4c shows interpretation paths for sepsis prediction and heart failure prediction. Each path summarizes the results using the most frequent diagnosis. From these paths, we find sepsis-related pre-diagnoses/symptoms such as 'Fever', 'Chills', 'Immunity disorders', 'Anemia', and 'Coma'.
We also find heart failure-related pre-diagnoses/symptoms such as 'Ventricular fibrillation', 'Myocarditis', 'Coronary atherosclerosis', and 'Hypertension'. In recent years, deep learning (DL) technology has shown superior performance in medical applications [41] [42] [43] [44], such as medical image recognition [45] and [46]. Many methods have achieved good performance for specific disease prediction, such as Alzheimer's disease [47], sepsis [48], and heart disease [49, 50]. However, most of them pursue task accuracy and ignore interpretability. DL-based approaches are black-box models that are hard for non-professionals to understand, especially doctors without an artificial intelligence background. Thus, explainable DL methods are needed. This study addresses this problem and puts forward a solution, the interpretation path, that makes predictions explainable. In EHR, patients' records are irregular in time because of the unpredictability of diseases and inevitable data loss. The current disease may be more closely related to a disease from a week ago than to one from a year ago [8, 9]. Thus, a time perception mechanism is needed. This study addresses this issue with a time gate that explicitly learns the irregular time information through the time decay function. The experiments show that two kinds of heads, one for intra-visit relations and one for inter-visit relations, are necessary. The two relations differ not only in the time interval but also in the pathology. We argue that code relations within the same visit are more likely to be complications. Relations across different visits are more likely to be causations and continuities. For example, in our experiments, 'diabetes' and 'cellulitis and abscess of legs' in one visit are more likely related as a short-term complication. 'diabetes' and 'long-term use of insulin' in two different visits are more likely related as causation.
Thus, for each patient, we can build a disease association graph. The weight of an edge between two diagnoses in the same admission represents the adjoint coefficient. The weight of an edge between two diagnoses in different admissions represents the causal coefficient. Figure 5 shows the diagnosis graph of patient 32,790. The correlation behind the interpretation path is not symmetric, that is, in general $\alpha_{ij} \neq \alpha_{ji}$. The two coefficients have different denominators:
$$\alpha_{ij} = \frac{\#\text{ of } i\text{-}j \text{ co-occurrences}}{\#\text{ of } i \text{ occurrences}}, \quad \alpha_{ji} = \frac{\#\text{ of } i\text{-}j \text{ co-occurrences}}{\#\text{ of } j \text{ occurrences}}$$
For example, let codes $i$, $j$, and $k$ represent the diagnoses 'malaria', 'fever', and 'periodic cold fever', respectively. In our experiment, $i$ is mostly accompanied by $j$, with $\alpha_{ij} = 0.762$. But $j$ is rarely accompanied by $i$, with $\alpha_{ji} = 0.023$, whereas $k$ is mostly accompanied by $i$, with $\alpha_{ki} = 0.701$. Comparing $\alpha_{ji}$ and $\alpha_{ki}$ shows that 'periodic cold fever' is a better explanation for 'malaria' than 'fever'. According to [51], 'periodic cold fever' is a characteristic clinical manifestation of 'malaria', and very few other diseases have this symptom. This illustrates that our interpretation method can explain results by reflecting the relations between diagnoses (such as complication, causation, and continuity). It also shows that $\alpha_{*i}$ is a better criterion than $\alpha_{i*}$ for finding the code that co-occurs most strongly with $i$. In medical applications, data are usually imbalanced. Normal patient states form the majority, while disease records may be a small sample. That small sample is the more important one for disease prediction. Thus, a DL model should be robust on imbalanced datasets, and our pre-train and fine-tune framework helps in this respect. There is still room for improvement. The current modeling method uses EHR data alone. Integrating prior knowledge could make the data relation modeling and medical prediction more accurate and reasonable.
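The asymmetric co-occurrence ratio defined in this section can be computed from raw counts. The sketch below shows why $\alpha_{ij} \neq \alpha_{ji}$: the shared numerator is divided by each code's own frequency. The text does not specify the counting unit. This sketch assumes co-occurrence means "appearing in the same visit" and counts each code at most once per visit. The function name and the toy codes `i`, `j`, `k` are hypothetical.

```python
from collections import Counter
from itertools import combinations

def cooccurrence_ratios(visits):
    """alpha[(i, j)] = #visits containing both i and j / #visits containing i.

    Because the denominator depends on the first code, alpha[(i, j)] != alpha[(j, i)] in general.
    """
    occurrences = Counter()
    pair_counts = Counter()
    for visit in visits:
        codes = set(visit)                 # count each code once per visit
        occurrences.update(codes)
        for a, b in combinations(sorted(codes), 2):
            pair_counts[(a, b)] += 1
            pair_counts[(b, a)] += 1
    return {(i, j): n / occurrences[i] for (i, j), n in pair_counts.items()}
```

With toy visits `[["i", "j", "k"], ["i", "j"], ["j"], ["j"], ["k", "i"]]`, the frequent code `j` gets `alpha[("j", "i")] = 0.5`. The rarer code `k`, which always appears with `i`, gets `alpha[("k", "i")] = 1.0`. This mirrors the 'fever' versus 'periodic cold fever' comparison above.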
For integrating prior knowledge, one available method is knowledge graph embedding based on ICD codes. In addition, more EHR data, such as doctor notes, medications, and laboratory tests, can be used for better performance. Future work will focus on these aspects. Data-driven medical prediction based on interpretable deep learning is essential for healthcare management. In this paper, we propose an interpretable Time-aware and Co-occurrence-aware Network (TCoN) for data modeling and medical prediction. It perceives hierarchical data structures with both the time relation and the co-occurrence relation. It gives an interpretation path to explain each prediction and builds a diagnosis graph for every patient. The experiments show that TCoN outperforms the compared state-of-the-art methods.
Dipole: diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks
Learning to diagnose with LSTM recurrent neural networks
RETAIN: an interpretable predictive model for healthcare using reverse time attention mechanism
Multi-layer representation learning for medical concepts
MiME: multilevel medical embedding of electronic health records for predictive healthcare
Bone disease prediction and phenotype discovery using feature representation over electronic health records
Deep computational phenotyping
Patient subtyping via time-aware LSTM networks
DeepCare: a deep dynamic memory model for predictive medicine
Temporal convolutional neural networks for diagnosis from lab tests
Risk prediction with electronic health records: a deep learning approach
Exploiting convolutional neural network for risk prediction with medical feature embedding
ImageNet classification with deep convolutional neural networks
Empirical evaluation of gated recurrent neural networks on sequence modeling
Learning complex, extended sequences using the principle of history compression
Bidirectional recurrent neural networks
Long short-term memory
Distributed representations of words and phrases and their compositionality
Efficient estimation of word representations in vector space
A neural probabilistic language model
A scalable hierarchical distributed language model
Word representations: a simple and general method for semi-supervised learning
Time-dependent graphs: definitions, applications, and algorithms
Neural machine translation by jointly learning to align and translate
Effective approaches to attention-based neural machine translation
Show, attend and tell: neural image caption generation with visual attention
Image captioning with semantic attention
Multiple object recognition with visual attention
Attention-based models for speech recognition
Teaching machines to read and comprehend
Attention is all you need
MIMIC-III, a freely accessible critical care database
The third international consensus definitions for sepsis and septic shock (Sepsis-3)
Time to treatment and mortality during mandated emergency care for sepsis
The precision-recall plot is more informative than the ROC plot when evaluating binary classifiers on imbalanced datasets
Pre-training of deep bidirectional transformers for language understanding
Improving language understanding by generative pre-training
Adam: a method for stochastic optimization
Decoupled weight decay regularization
DeepCare: a deep dynamic memory model for predictive medicine
An empirical evaluation of generic convolutional and recurrent networks for sequence modeling
Machine learning approaches to drug response prediction: challenges and recent progress
Deep learning for improved risk prediction in surgical outcomes
Research on recognition of medical image detection based on neural network
GAMENet: graph augmented MEmory networks for recommending medication combination
Multi-task dictionary learning based on convolutional neural networks for longitudinal clinical score predictions in Alzheimer's disease. In: HBAI@IJCAI
Model-based reinforcement learning for sepsis treatment
Using recurrent neural network models for early detection of heart failure onset
K-margin-based residual-convolution-recurrent neural network for atrial fibrillation detection
When is fever malaria?
This paper is dedicated to those who want to fight COVID-19.
CS and HL conceptualized the idea. HL initiated and supervised the project. CS collected data, implemented the experiments, and drafted the manuscript. HD reviewed the manuscript and implemented the additional experiments. All authors provided a critical review of the manuscript and approved the final draft for publication.
This work was supported by the National Natural Science Foundation of China (No. 62172018, No. 62102008) and the National Key Research and Development Program of China under Grant 2021YFE0205300 to collect and process data and publish the paper. The code implementation is publicly available at https://github.com/SCXsunchenxi/MTGRU. The data are available at https://mimic.physionet.org. Not applicable. No financial competing interests.