key: cord-0136694-uabdmy4z authors: Liu, Sicen; Wang, Xiaolong; Xiang, Yang; Xu, Hui; Wang, Hui; Tang, Buzhou title: CATNet: Cross-event Attention-based Time-aware Network for Medical Event Prediction date: 2022-04-29 journal: nan DOI: nan sha: eb4744c5bc49d5ddb19573397fb8cdc9f866faf2 doc_id: 136694 cord_uid: uabdmy4z Medical event prediction (MEP) is a fundamental task in the medical domain, which needs to predict medical events, including medications, diagnosis codes, laboratory tests, procedures, outcomes, and so on, according to historical medical records. The task is challenging as medical data is a type of complex time series data with heterogeneous and temporal irregular characteristics. Many machine learning methods that consider the two characteristics have been proposed for medical event prediction. However, most of them consider the two characteristics separately and ignore the correlations among different types of medical events, especially relations between historical medical events and target medical events. In this paper, we propose a novel neural network based on attention mechanism, called cross-event attention-based time-aware network (CATNet), for medical event prediction. It is a time-aware, event-aware and task-adaptive method with the following advantages: 1) modeling heterogeneous information and temporal information in a unified way and considering temporal irregular characteristics locally and globally respectively, 2) taking full advantage of correlations among different types of events via cross-event attention. Experiments on two public datasets (MIMIC-III and eICU) show CATNet can be adaptive with different MEP tasks and outperforms other state-of-the-art methods on various MEP tasks. The source code of CATNet will be released after this manuscript is accepted. 
Nowadays, with the development of electronic medical record (EMR) systems, massive clinical data about patients is available for medical research and healthcare [1] [2] [3] [4] [5] [6] [7] [8] . Many machine learning methods, especially deep learning methods, have been used for medical event prediction (MEP) [9] [10] [11] [12] [13] [14] [15] (as shown in Fig. 1 ) to support clinical decisions. Formally, given the historical medical records of a patient p with demographic information and T visits, each of which includes various kinds of medical events (usually denoted by medical codes) such as diagnoses, procedures, laboratory tests (labtests), medications, and other medical events, the goal of medical event prediction is to predict the medical events at time T+1, denoted by V_{T+1}. The medical events at time T+1 may be diagnoses, procedures, labtests, medications, or other medical events of the same types as the medical events in the historical records, or of new types different from the medical events in the historical records. According to the type of target events, MEP can be classified into risk prediction [16] [17] [18] , patient treatment trajectory prediction [19, 20] , readmission prediction [21] [22] [23] [24] , mortality prediction [17, [25] [26] [27] [28] , diagnosis code prediction [29] [30] [31] [32] , prescription prediction [33, 34] , etc. These MEP tasks usually face the following three challenges: 1) Information decay rates over time are related to the irregular temporal intervals between two neighbor visits (i.e., Δt) as well as to the medical events themselves. The same medical events with irregular temporal intervals should correspond to different decay rates, and different medical events should decay in different ways. 2) There are correlations among the medical events in a visit, which should be considered appropriately. 
For instance, when a patient has acute renal failure and his/her potassium value exceeds the safe limit in a laboratory test, the physician will prescribe Furosemide, Calcium Gluconate, and Potassium Chloride to the patient. 3) If the medical events in the historical records are of the same type as the target medical events, these types of medical events should be paid more attention than other types of medical events. Methods for MEP should be task-adaptive. Most machine learning methods focus on solving one or two of the three challenges. For example, T-LSTM [35] is a time-aware LSTM (long short-term memory network) designed to incorporate the temporal interval between two neighbor visits into the memory cell of the basic LSTM unit to adjust the information transmission from the previous visit to the next visit. It improves the performance of the standard LSTM. HiTANet [18] is a hierarchical time-aware attention network based on Transformer, which models temporal information in local and global stages to imitate the decision-making process of doctors in risk prediction. HiTANet achieves much better performance than T-LSTM on three disease prediction tasks. Using the same structure as T-LSTM, the heterogeneous LSTM (LSTM-DE) [33] adds a gate into the basic LSTM unit to model potential relations between two types of heterogeneous information, where the medical events of the same types as the target events are primary events and the other medical events of different types from the target events are auxiliary events. RetainEx [36] is an attention-based recurrent neural network (RNN) for risk prediction, which learns event-level weights within a visit as well as visit-level weights for interpretability. The attention mechanism is used to model correlations among medical events in a visit. 
In addition, RetainEx also considers irregular temporal intervals by incorporating visit temporal intervals into the input of the RNN. To tackle all three challenges mentioned above, we design a novel neural network based on a novel attention mechanism, called the cross-event attention-based time-aware network (CATNet), for medical event prediction. In CATNet, we regard temporal intervals as a "new" type of medical event and design a novel attention mechanism (called cross-event attention) to model correlations among medical events in the same visit, including temporal intervals. The cross-event attention has two forms corresponding to the type of target medical events, which makes CATNet task-adaptive. Moreover, in order to model temporal information globally, CATNet also introduces a global time convertor that treats the times medical events have decayed as impact factors for MEP. In summary, the proposed CATNet has the following contributions: • We design a time-aware, event-aware, and task-adaptive network to model heterogeneous and irregular time series medical data. • We regard temporal information as a "new" type of medical event and propose a novel attention mechanism (cross-event attention) to learn the information decay rate for each visit and the correlations among medical events in each visit in a unified way. This attention mechanism is time-aware and task-adaptive. • We introduce a visit-level attention to model the relations among historical visits, and a global time convertor to model temporal information globally. • CATNet can be used as a universal framework for MEP. We conduct experiments on two public real-world datasets to verify the effectiveness of the proposed CATNet. Ablation studies and result analysis show the effectiveness of the proposed method. 
In recent years, various deep learning methods have been proposed for medical event prediction on EMR data, including the multi-layer perceptron (MLP) [37] , convolutional neural network (CNN) [20, 38, 39] , RNN [21, 26, 34, 40] and Transformer [18] [41] . Among them, RNN is the representative one, and Transformer also shows great potential. We review these studies from three perspectives corresponding to the three challenges mentioned above. Most time-aware deep learning methods are extended from the basic RNN, GRU (gated recurrent unit network), or LSTM, which are suitable for sequential data with identical temporal intervals. T-LSTM [35] uses a decay function of the temporal interval to control information transmission between neighbor visits and integrates the function into the memory cell of the basic LSTM unit. It achieves much better performance than LSTM on risk prediction. StageNet [17] is a stage-aware LSTM that integrates temporal intervals into the basic LSTM cell to represent the disease progression stage. Timeline [41] applies a time-aware disease progression function to determine how much disease information is transmitted from the current visit to the target visit before inputting the visit representation into LSTM. DATA-GRU [11] and RetainEx [36] incorporate visit temporal intervals as additional features into the input vectors of RNN. Besides RNN and LSTM extensions, the attention mechanism [42] , [43] and time-aware graph neural networks [34] are also used to model irregular temporal intervals. ConCare [44] applies time-aware attention to the output of GRU. HiTANet [18] , a Transformer-based method, uses hierarchical time-aware attention to model temporal information at local and global levels. RGNN_TG_ATT [34] connects medical events in neighbor visits with edges weighted by temporal intervals to form a time-aware graph. Some studies investigate medical event prediction at both the visit level and the event level. 
RetainEx [36] adopts an attention mechanism to model the correlations among medical events in a visit. ConCare [44] uses GRU to model each event sequence. Task-adaptive methods consider the types of target events. They assume that medical events in the historical records of the same type as the target events should be paid more attention. LSTM-DE [33] , proposed for medication prediction, considers medications as primary events and labtests as auxiliary events, and adopts a network similar to T-LSTM that emphasizes the primary events by introducing a decomposition structure into LSTM. RGNN_TG_ATT extends LSTM-DE by introducing an additional time-aware graph neural network. DA-LSTM [45] is an extension of T-LSTM that uses a context fusion module to enhance auxiliary inputs by considering current inputs and whole input sequences instead of the decomposition structure in T-LSTM. Most existing deep learning methods focus on one or two challenges of MEP. The proposed CATNet considers all three challenges of MEP comprehensively in a unified framework. This section defines some notations and describes the MEP tasks. Medical events are events related to healthcare in EMRs, including medications, diagnoses, labtests, procedures, mortality status, etc. They are usually normalized by different types of codes. In this study, we only consider five types of medical events, that is, medications, diagnoses, labtests, procedures, and mortality status, each of which is denoted by E_c, c ∈ {m, d, l, p, mo}, where m, d, l, p, and mo denote medication, diagnosis, labtest, procedure, and mortality respectively. Suppose that there are N_m codes in E_m, N_d codes in E_d, N_l codes in E_l, and N_p codes in E_p. 
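The code-based notation above amounts to multi-hot encoding each event type per visit. A minimal sketch of this encoding (the vocabulary and diagnosis codes below are illustrative, not taken from the paper's data):

```python
import numpy as np

def encode_visit(event_codes, vocab):
    """Multi-hot encode one visit: v[k] = 1 iff the k-th code occurs.

    `vocab` maps a medical code (e.g. an ICD-9 diagnosis code) to a
    column index; codes of one event type share one vocabulary.
    """
    v = np.zeros(len(vocab), dtype=np.float32)
    for code in event_codes:
        v[vocab[code]] = 1.0
    return v

# Toy diagnosis vocabulary (hypothetical codes for illustration).
diag_vocab = {"584.9": 0, "276.7": 1, "428.0": 2}
visit = encode_visit(["584.9", "276.7"], diag_vocab)  # -> [1., 1., 0.]
```

Stacking one such vector per event type and per visit yields the three-dimensional binary patient matrix described in the text.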
A patient p with T historical visits (V_1, V_2, ⋯, V_T) and demographics S can be represented by a three-dimensional binary matrix as shown in Figure 1 , where, for each event type in the i-th visit V_i, the k-th entry indicates whether the k-th event code appears in the i-th visit. Given a patient with T visits and demographics S, the goal of medical event prediction is to predict which medical events will appear in the next visit. In this study, we investigate the following two cases: 1) the historical visits include medical events of the same type as the target events, for example, medication prediction according to medications, diagnoses, labtests, and procedures in the historical visits; 2) the target events are of new types different from all medical events in the historical visits, for example, mortality prediction according to medications, diagnoses, labtests, and procedures in the historical visits. This section presents our proposed method (as shown in Fig. 2 ). The temporal interval Δt between neighbor visits is embedded as e_Δt = W_Δ2 tanh(W_Δ1 Δt + b_Δ1) + b_Δ2 (Eq. (1)), where W_Δ1 ∈ R^b, b_Δ1 ∈ R^b, W_Δ2 ∈ R^{b×b}, and b_Δ2 ∈ R^b are all parameters. Thus, each visit is paired with the embedding of its temporal interval. Figure 2 (b) gives two cases of cross-event attention: the upper is an example of task-unaware attention, where the task is mortality prediction according to medications, diagnoses, labtests, and procedures in historical visits; the lower is an example of task-aware attention, where the task is medication prediction with medications as primary events and the others as auxiliary events. For the visit embeddings, we design the cross-event attention and apply it to them to obtain time-aware, event-aware, and task-adaptive visit embeddings (for details, refer to the cross-event attention section). We further input the new visit embeddings into a common sequence modeling network (e.g., GRU, LSTM, or Transformer) to obtain the hidden representation h_t of each visit. 
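Given the stated parameter shapes, the interval embedding maps the scalar interval up to b dimensions and then transforms it. A minimal sketch under the assumption that Eq. (1) composes the two layers with a tanh nonlinearity (the exact functional form is an assumption; all weights here are random placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)
b = 8  # embedding size (illustrative)

# Parameter shapes follow the text: W1, b1 in R^b lift the scalar
# interval, W2 in R^{b x b} and b2 in R^b transform the result.
W1, b1 = rng.normal(size=b), rng.normal(size=b)
W2, b2 = rng.normal(size=(b, b)), rng.normal(size=b)

def embed_interval(delta_t):
    """e = W2 @ tanh(W1 * delta_t + b1) + b2 -- assumed form of Eq. (1)."""
    return W2 @ np.tanh(W1 * delta_t + b1) + b2

e = embed_interval(3.0)  # e.g. an interval of 3 days between visits
```

The resulting vector e can then be stacked alongside the event embeddings of the visit as the "new" temporal event.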
Visit-level attention and a global time convertor are applied to h_t and Δt to utilize temporal information as well as historical medical events at the visit level and the global level respectively. The representation obtained by visit-level attention and the representation obtained by the global time convertor are combined together. We also consider the demographic information of the patient as a static characteristic. Finally, a fully connected network with the sigmoid function is used for prediction. In the following sections, we present cross-event attention, local attention, the global time convertor, and event prediction in detail. There are two cases of cross-event attention according to the type of the target events: task-unaware attention and task-aware attention. The task-unaware attention corresponds to the case where the target events are of new types different from all events in the historical visits, while the task-aware attention corresponds to the case where the historical visits include medical events of the same type as the target events (i.e., primary events). We apply self-attention to only the primary events (taking medication prediction as an example) in the task-aware attention (Eq. (2)), but to all events in the task-unaware attention (Eq. (3)), where Ê ∈ R^{(N_m + N_d + N_l + N_p + 1)×b} is the matrix of all event and temporal embeddings. The output of cross-event attention is the attention-weighted combination of these embeddings. Because of the attention weights between the primary events (or all events) and the temporal intervals, our proposed CATNet is time-aware. As the attention is applied at the event level, CATNet is event-aware. Moreover, the attention mechanism can be adapted to different tasks. Therefore, we obtain time-aware, event-aware, and task-adaptive visit embeddings via cross-event attention. For convenience, the CATNet using task-unaware attention is denoted by CATNet-I, and that using task-aware attention is denoted by CATNet-II. CATNet-II can be regarded as a specific version of CATNet-I. 
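The two forms of cross-event attention can be sketched as a single-head scaled dot-product self-attention over the event-plus-temporal embedding matrix, with queries restricted to the primary-event rows in the task-aware case. This is a minimal sketch; the paper's exact projections in Eqs. (2)-(3), and how the attended rows are pooled into a visit embedding, may differ:

```python
import numpy as np

def softmax(x, axis=-1):
    z = np.exp(x - x.max(axis=axis, keepdims=True))
    return z / z.sum(axis=axis, keepdims=True)

def cross_event_attention(E, primary_idx=None):
    """Self-attention over the event (+ temporal) embeddings of one visit.

    E: (n_events + 1, d) matrix -- rows are medication, diagnosis,
       labtest, and procedure embeddings plus the interval embedding
       treated as a "new" event.
    primary_idx: row indices of primary events for the task-aware case;
       None gives the task-unaware case (queries = all rows).
    """
    Q = E if primary_idx is None else E[primary_idx]
    A = softmax(Q @ E.T / np.sqrt(E.shape[1]))  # scaled dot-product weights
    return A @ E  # attended, time-aware event representations

rng = np.random.default_rng(1)
E = rng.normal(size=(5, 16))             # 4 event types + interval row
v_unaware = cross_event_attention(E)      # e.g. mortality prediction
v_aware = cross_event_attention(E, [0])   # e.g. medication as primary
```

Because the interval embedding is one of the attended rows, each event representation absorbs temporal information, which is what makes the resulting visit embedding time-aware.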
Any sequence modeling network can be used as the backbone to model historical visit sequences, such as GRU, LSTM, or Transformer. The hidden representations of the time-aware, event-aware, and task-adaptive visit embedding sequence [v_1, v_2, …, v_T] are obtained as h_t = Backbone(v_1, …, v_t), where h_t is the hidden state for the t-th visit aggregating all the medical information and Backbone is a sequence modeling network. After obtaining h, we employ visit-level attention to generate an attention weight for each visit from h_{1:T}, the matrix of hidden states for visits 1 to T, and the patient is then represented as the attention-weighted sum of the hidden states. The local attention module can capture the local relations among historical visits. Besides temporal intervals between neighbor visits, we also consider the times medical events have decayed. A global time convertor is designed to generate the impact score of each medical decay time globally. Similar to Eq. (1), a global medical event decay time Δt can be embedded as e_Δt = W_Δ2 tanh(W_Δ1 Δt + b_Δ1) + b_Δ2, where W_Δ1 ∈ R^b, b_Δ1 ∈ R^b, W_Δ2 ∈ R^{b×b}, and b_Δ2 ∈ R^b are all parameters. Furthermore, we adopt the sigmoid activation function σ to obtain the impact score of each medical decay time globally. We then obtain a patient representation that considers temporal information globally. For event prediction, we consider the patient representation obtained by the local attention module and the patient representation obtained by the global time convertor. To obtain the representation of the patient's demographic information S, we apply a fully connected network to S. Finally, a fully connected network with the sigmoid function is used to make a binary (vector) prediction, where the output dimension is the total number of events that need to be predicted. Binary cross-entropy is used as the loss function for each event. 
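The steps above (visit-level attention, global time convertor, and the prediction head) can be sketched end-to-end as follows. How the local and global representations are combined is not fully specified in the text, so the elementwise sum below is an assumption, and all weights are random placeholders rather than trained parameters:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(2)
T, d = 4, 8                       # number of visits, hidden size
h = rng.normal(size=(T, d))       # backbone outputs h_1..h_T
e_glob = rng.normal(size=(T, d))  # global decay-time embeddings

# Visit-level (local) attention over the hidden states.
w_a = rng.normal(size=d)
alpha = softmax(h @ w_a)          # one weight per visit, sums to 1
h_local = alpha @ h

# Global time convertor: sigmoid impact score per decay-time embedding.
w_g = rng.normal(size=d)
impact = sigmoid(e_glob @ w_g)    # one impact score per visit
h_global = (impact[:, None] * h).sum(axis=0)

# Combine with a demographics embedding and predict per-event probs.
s = rng.normal(size=d)                  # demographic representation
z = np.concatenate([h_local + h_global, s])
W_out = rng.normal(size=(10, 2 * d))    # 10 target event codes (toy)
y_hat = sigmoid(W_out @ z)              # independent per-event probabilities
```

Each entry of y_hat is a per-code probability, which is why binary cross-entropy is applied per event rather than a single softmax over codes.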
We conduct experiments on two public real-world datasets, i.e., MIMIC-III [46] and the eICU Collaborative Research Database [47] , to evaluate the performance of the proposed CATNet. On the MIMIC-III dataset, we select the patients with at least two visits and consider five prediction tasks: medication prediction, diagnosis prediction, labtest prediction, procedure prediction, and mortality prediction. We apply CATNet with task-aware attention to the former four event prediction tasks and CATNet with task-unaware attention to mortality prediction. The eICU data contains the medical records of over 200K patients admitted to ICUs between 2014 and 2015. We also select the patients with at least two visits for experiments. Three types of medical events, that is, medications, labtests, and procedures, are used for event prediction. Finally, we obtain 9215 patients and apply CATNet with task-aware attention to medication prediction, labtest prediction, and procedure prediction respectively. The statistical details of the selected data in the two datasets are summarized in Table 1 , where "# *" denotes the number of *, and "Avg" and "Max" are abbreviations of "average" and "maximum". In this study, we split each dataset into training, validation, and test sets across patients with ratios of 8:1:1 in all experiments. We compare CATNet with the following state-of-the-art methods: 1) DoctorAI [9] : a classical medical event prediction method that uses the basic RNN to represent patient historical visits. 2) T-LSTM [35] : a time-aware LSTM designed to incorporate the temporal interval between two neighbor visits into the memory cell of the basic LSTM unit to control information transmission. 3) LSTM-DE [33] : a task-aware LSTM that introduces a gate into the basic LSTM unit to model potential correlations between two types of medical events; it has a structure similar to T-LSTM. 4) RetainEx [36] : an attention-based RNN that learns event-level and visit-level weights and incorporates visit temporal intervals into its input. 5) RGNN_TG_ATT [34] : an extension of LSTM-DE with an additional time-aware graph neural network. 6) StageNet [17] : a stage-aware LSTM that integrates temporal intervals into the basic LSTM cell. 7) ConCare [44] : a GRU-based method that applies time-aware attention to the GRU output. 8) HiTANet [18] : a Transformer-based method that uses hierarchical time-aware attention at local and global levels. The baselines fall into two categories: RNN variants (1-7) and Transformer variants (8) . 
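The 8:1:1 split across patients can be sketched as follows; splitting by patient ID (rather than by visit) keeps all of a patient's records in one set. The seed is an arbitrary choice for illustration:

```python
import random

def split_patients(patient_ids, seed=42):
    """Patient-level 8:1:1 split so no patient leaks across sets."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)
    n = len(ids)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    return (ids[:n_train],
            ids[n_train:n_train + n_val],
            ids[n_train + n_val:])

train_ids, val_ids, test_ids = split_patients(range(100))
```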
It should be noted that among the eight baseline methods, LSTM-DE and RGNN_TG_ATT cannot be directly applied to mortality prediction as they are task-adaptive, and we only apply them to medication prediction as they can only consider one auxiliary type of medical event. In addition, we can only conduct mortality experiments on the MIMIC-III dataset, as only it contains mortality status. We implement CATNet in the PyTorch framework and use the released source codes of DoctorAI, T-LSTM, RetainEx, LSTM-DE, RGNN_TG_ATT, StageNet, ConCare, and HiTANet as their implementations. For all the methods, we train the models with randomly initialized parameters 10 times on the training sets, save the best model on the validation set each time, and test the best models on the test sets. The mean performance of the 10 independent runs on the test sets and the standard deviations are reported in the "Experiments and Analysis" section. Three metrics are used for performance evaluation: the area under the receiver operating characteristic curve (AUC), the area under the precision-recall curve (AUPR), and top-k recall. During model training, we set the number of epochs to 100, the learning rate to 0.0001, the hidden size of CATNet to 256 on the MIMIC-III dataset and 512 on the eICU dataset, the dropout rate to 0.3, and the other hyperparameters to their defaults. The comparisons of CATNet with the other methods for medication prediction, diagnosis prediction, labtest prediction, and procedure prediction on the two datasets in AUC and AUPR are reported in Table 2 and Table 3 , where "NA" denotes no result. The comparisons of CATNet with other methods for mortality prediction on the MIMIC-III dataset in AUC and AUPR are reported in Table 4 . By analyzing the results, the following conclusions can be drawn. 
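AUC and AUPR can be computed with scikit-learn's roc_auc_score and average_precision_score. Top-k recall is less standardized, so the sketch below encodes one common reading (the fraction of a patient's true next-visit events recovered among the k highest-scored codes, averaged over patients); the paper's exact definition may differ:

```python
import numpy as np

def top_k_recall(y_true, y_score, k):
    """Per-patient recall of true events within the top-k scored codes,
    averaged over patients that have at least one true event."""
    recalls = []
    for t, s in zip(y_true, y_score):
        top = np.argsort(s)[::-1][:k]   # indices of k highest scores
        n_true = t.sum()
        if n_true:
            recalls.append(t[top].sum() / n_true)
    return float(np.mean(recalls))

# Toy example: 2 patients, 5 candidate codes each.
y_true = np.array([[1, 0, 1, 0, 0],
                   [0, 1, 0, 0, 1]])
y_score = np.array([[0.9, 0.1, 0.8, 0.2, 0.1],
                    [0.2, 0.7, 0.1, 0.3, 0.6]])
r = top_k_recall(y_true, y_score, k=2)  # -> 1.0 on this toy example
```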
Table 3 AUC and AUPR of different methods on medication prediction, diagnosis prediction, procedure prediction, and labtest prediction on the eICU dataset. Firstly, the proposed CATNet, whether CATNet-I or CATNet-II, shows stable and outstanding performance and achieves state-of-the-art scores on most metrics with small deviations. The superiority of CATNet over the other methods is greater on the MIMIC-III dataset than on the eICU dataset. In addition, the differences between CATNet and the other methods in AUPR are much bigger than those in AUC. Secondly, as a method based on the basic RNN, the performance of DoctorAI is not outstanding, but it is stable. It even performs much better than several RNN variants such as T-LSTM, RetainEx, and ConCare on several prediction tasks. For example, DoctorAI outperforms T-LSTM on all the prediction tasks in Table 2 and Table 3 . The reason may be that some prediction tasks depend on many factors, all of which should be comprehensively considered, as in CATNet. From the results shown in Table 2 and Table 3 , we can see that CATNet has a strong task-adaptive ability. Thirdly, comparing the two CATNet methods, CATNet-II shows slightly better performance in AUC and much better performance in AUPR than CATNet-I, indicating that the task-aware attention is meaningful for the type of tasks listed in Table 2 and Table 3 . Furthermore, we also show the top-k (k=10, 20, …, 50) recalls of different methods for the four tasks listed in Table 2 on the MIMIC-III dataset. As shown in Fig. 3 , CATNet outperforms all the other baseline methods consistently on all the tasks, although the difference is small on some tasks (e.g., labtest prediction). This result is consistent with that in Table 2 . To investigate the importance of each component of CATNet, we compare CATNet-II with variants that remove some parts of the full CATNet, using the same settings as the previous experiments. 
We still run each experiment 10 times to obtain the average performance. The results are shown in Table 5 , where "w/o", "Cross", "Vis", and "Global" denote "without", "cross-event attention", "visit-level attention", and "global time convertor" respectively. It should be noted that we do not remove the "local attention" module from the CATNet using Transformer, as it is an intrinsic module of Transformer. We can see that all the components contribute to CATNet. From the results, we conclude that the performance of CATNet decreases when any module is removed, because the "cross-event attention" module appropriately learns correlations among medical events in a visit, the "local attention" module learns relations among historical visits, and the "global time convertor" module models medical event decay rates globally, which simulates a clinician reviewing the patient's historical records and paying relative attention to each temporal point. We also conduct an ablation study on each type of medical event, including the "new type" for temporal intervals, and the results are shown in Table 6 . In the case of medication prediction, all the other four types of medical events are beneficial to medication prediction. This again indicates that the "cross-event attention" module in CATNet can capture correlations among different events in a visit, as well as correlations between temporal intervals and medical events. From the ablation study and the experimental results listed in Table 5 , we can conclude that cross-event attention significantly improves the performance of medical event prediction. To further illustrate the effectiveness of cross-event attention in CATNet, we conduct case studies to interpret and visualize the learned cross-event attention weights. Figure 4 shows the attention weights between three randomly selected medications and the 10 diagnoses most related to them. 
We can see that the cross-event attention can find the strong relationships between the medications and diagnoses. For example, "calcium gluconate" is a conventional medication for "sideroblastic anemia", "ac kidny fail, tubr necr", and "mixed acid-bas bal disorder". This paper proposes a novel time-aware, event-aware, and task-adaptive deep learning method for medical event prediction, namely, the cross-event attention-based time-aware network (CATNet). In CATNet, irregular temporal intervals between neighbor visits are regarded as a new type of medical event, and cross-event attention is designed to model correlations among different types of medical events, including temporal intervals. The cross-event attention has two cases, task-aware and task-unaware. Because of the cross-event attention, CATNet is time-aware, event-aware, and task-adaptive. In addition, CATNet also considers each visit at local and global levels. Experiments on two public real-world datasets show that the proposed CATNet is effective and outperforms other state-of-the-art methods on various medical event prediction tasks. 
References
Big Data In Health Care: Using Analytics To Identify And Manage High-Risk And High-Cost Patients
Big Data for Health
Medical analytics for healthcare intelligence - Recent advances and future directions
A guide to deep learning in healthcare
Deep EHR: A Survey of Recent Advances in Deep Learning Techniques for Electronic Health Record (EHR) Analysis
Recommendations for enhancing the usability and understandability of process mining in healthcare
How to identify and treat data inconsistencies when eliciting health-state utility values for patient-centered decision making
Automated machine learning: Review of the state-of-the-art and opportunities for healthcare
Doctor AI: Predicting Clinical Events via Recurrent Neural Networks
Learning the Joint Representation of Heterogeneous Temporal Events for Clinical Endpoint Prediction
DATA-GRU: Dual-Attention Time-Aware Gated Recurrent Unit for Irregular Multivariate Time Series
GluNet: A Deep Learning Framework for Accurate Glucose Forecasting
Uncertainty-Aware Deep Ensembles for Reliable and Explainable Predictions of Clinical Time Series
EPTs-TL: A two-level approach for efficient event prediction in healthcare
Improving prediction for medical institution with limited patient data: Leveraging hospital-specific data based on multicenter collaborative research network
High-risk Prediction of Cardiovascular Diseases via Attention-based Deep Neural Networks
Stage-Aware Neural Networks for Health Risk Prediction
HiTANet: Hierarchical Time-Aware Attention Networks for Risk Prediction on Electronic Health Records
On Clinical Event Prediction in Patient Treatment Trajectory Using Longitudinal Electronic Health Records
Dynamic Prediction in Clinical Survival Analysis Using Temporal Convolutional Networks
Mining high-dimensional administrative claims data to predict early hospital readmissions
Predicting Hospital Readmission: A Joint Ensemble-Learning Model
A Public-Private Partnership Develops and Externally Validates a 30-Day Hospital Readmission Risk Prediction Model
A Medical Distance Based Manifold Learning Approach for Heart Failure Readmission Prediction
Sepsis mortality prediction based on predisposition, infection and response
Mortality prediction in pediatric trauma
Deep Learning with Heterogeneous Graph Embeddings for Mortality Prediction from Electronic Health Records
Prediction of Patient Length of Stay on the Intensive Care Unit Following Cardiac Surgery: A Logistic Regression Analysis Based on the Cardiac Operative Mortality Risk Calculator
Time Series Forecasting for Healthcare Diagnosis and Prognostics with the Focus on Cardiovascular Diseases
Diagnosis Prediction via Medical Context Attention Networks Using Deep Generative Modeling
Enhancing Dialogue Symptom Diagnosis with Global Attention and Symptom Graph
Knowledge-based Attention Model for Diagnosis Prediction in Healthcare
A Treatment Engine by Predicting Next-Period Prescriptions
A hybrid method of recurrent neural network and graph neural network for next-period prescription prediction
Patient Subtyping via Time-Aware LSTM Networks
Visual Analytics with Interpretable and Interactive Recurrent Neural Networks on Electronic Medical Records
Enhancing Multi-layer Perceptron for Breast Cancer Prediction
Medical Hyperspectral Image Classification Based on End-to-End Fusion Deep Neural Network
DeepAISE - An interpretable and recurrent neural survival model for early prediction of sepsis
Interpretable Representation Learning for Healthcare via Capturing Disease Progression through Time
RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism
Diagnosis Prediction in Healthcare via Attention-based Bidirectional Recurrent Neural Networks
Personalized Clinical Feature Embedding via Capturing the Healthcare Context
Prediction of Treatment Medicines with Dual Adaptive Sequential Networks
MIMIC-III, a freely accessible critical care database
The eICU Collaborative Research Database, a freely available multi-center database for critical care research
The authors declare no conflict of interest.