key: cord-0121196-o3sblyst authors: Gao, Junyi; Sharma, Rakshith; Qian, Cheng; Glass, Lucas M.; Spaeder, Jeffrey; Romberg, Justin; Sun, Jimeng; Xiao, Cao title: STAN: Spatio-Temporal Attention Network for Pandemic Prediction Using Real World Evidence date: 2020-07-23 journal: nan DOI: nan sha: a7db59411ec407b5dbd4d2b69ccd432c2d0a1b64 doc_id: 121196 cord_uid: o3sblyst

Objective: The COVID-19 pandemic has created many challenges that need immediate attention. Various epidemiological and deep learning models have been developed to predict the COVID-19 outbreak, but all have limitations that affect the accuracy and robustness of their predictions. Our method aims to address these limitations and make earlier and more accurate pandemic outbreak predictions by (1) using patients' EHR data from different counties and states that encode local disease status and medical resource utilization; (2) considering demographic similarity and geographical proximity between locations; and (3) integrating pandemic transmission dynamics into deep learning models.

Materials and Methods: We propose a spatio-temporal attention network (STAN) for pandemic prediction. It uses an attention-based graph convolutional network to capture geographical and temporal trends and predicts the number of cases for a fixed number of days into the future. We also design a physical law-based loss term to enhance long-term prediction. STAN was tested using both massive real-world patient data and open-source COVID-19 statistics provided by Johns Hopkins University across all U.S. counties.

Results: STAN outperforms epidemiological modeling methods such as SIR and SEIR, as well as deep learning models, on both long-term and short-term predictions, achieving up to 87% lower mean squared error than the best baseline prediction model.
Conclusions: By using information from real-world patient data and geographical data, STAN better captures disease status and medical resource utilization and thus provides more accurate pandemic modeling. With regularization based on pandemic transmission laws, STAN also achieves good long-term prediction performance.

Pandemic diseases such as the novel coronavirus disease (COVID-19) have been spreading rapidly across the world and pose a serious threat to global public health. As of July 2020, COVID-19 had affected 14.1 million people and caused more than 597K deaths worldwide [1], along with significant disruption to daily life and huge economic losses. It is therefore critical to predict pandemic outbreaks early and accurately to inform policy and reduce losses. Many epidemiological models (e.g., susceptible-infected-removed (SIR) and susceptible-exposed-infected-removed (SEIR)) and deep learning models (e.g., Long Short-Term Memory networks, LSTM) have been applied to predict the COVID-19 pandemic [1] [2] [3] [4]. However, they face three major limitations: (1) They usually build a separate model for each location (e.g., one model per county) without incorporating geographic proximity and interactions with nearby regions, or their forecasts depend only on patterns observed at other locations [2, 3], even though inter-regional interactions can provide valuable information about future progression. In fact, a location often shows disease patterns similar to those of nearby or demographically similar locations because of population movement or demographic similarity [5]. (2) Existing models are mainly built on COVID-19 case report data, which are known to suffer from serious under-reporting and other missingness issues. (3) Epidemiological models that use disease transmission dynamics, such as SIR and SEIR, are designed to capture long-term trends but may sacrifice short-term prediction accuracy.
Conversely, deep learning-based models can only reproduce data patterns they have already seen, and they make accurate predictions only over a short time period. Therefore, while there are techniques for either short-term or long-term prediction of disease outbreaks, existing models do not provide accurate predictions over both time horizons.

In this work, we propose a new Spatio-Temporal Attention Network for pandemic prediction using real world evidence, named STAN. We map locations (e.g., a county or a state) to nodes on a graph and construct the edges based on geographical proximity and demographic similarity between locations. Each node is associated with a set of static and dynamic features extracted from multiple sources of real-world evidence in medical claims data, which capture disease prevalence at different locations and medical resource utilization. We use a graph convolutional network (GCN) with an attention mechanism to incorporate the interactions between a node and its neighboring locations. We then predict the number of infected patients for a fixed period into the future while imposing physical constraints on the predictions according to the transmission dynamics of epidemiological models. We apply STAN to predict both state-level and county-level pandemic progression, achieving up to 87% lower mean squared error than the best baseline model.

Traditional epidemic prediction models use compartmental or agent-based models that hardcode predefined disease transmission dynamics at the population level, such as SIR, SEIR and their variants [2, 3]. Some works also use time series learning approaches for pandemic prediction, for example curve fitting [3] or autoregression [4]. Besides these traditional statistical models, deep learning models have been developed that cast epidemic or pandemic modeling as a time series prediction problem.
Many works [5] [6] [7] combine deep neural networks (DNN) with causal models for influenza-like illness (ILI) incidence forecasting. Deng et al. [8] proposed a graph message passing framework that combines learned feature embeddings and an attention matrix to model disease propagation over time. However, DNN-based methods have a major issue: they can only predict trends already present in the input data. Thus, at the early stage of a pandemic, if all input data are increasing, these models are unlikely to predict a future decline. Yang et al. [3] used data from a previous pandemic to pretrain an LSTM and then applied it to predict COVID-19 progression in China. However, different pandemics have different infectivity, so directly transferring a previous pandemic's progression at the early stage of a new pandemic may lead to inferior predictions. Kapoor et al. [9] used a simple graph neural network for COVID-19 prediction, but their model only predicts the next day rather than long-term progression. It remains challenging for deep learning models to achieve good long-term prediction performance by exploiting observed progression patterns from neighboring locations or transmission dynamics.

Recently, several studies have attempted to incorporate knowledge about physical systems into deep learning. For example, Wu et al. and Beucler et al. [10, 11] introduced statistical and physical constraints in the loss function to regularize model predictions. However, these studies focused only on spatial modeling without temporal dynamics, and their regularization is ad hoc, with hyper-parameters that are difficult to tune. Seo et al. [12, 13] integrated physical laws into graph networks. However, they used physical laws to optimize node-edge transitions rather than prediction results, and their models only predict graph signals for the next time point instead of long-term outcomes.
In our work, we also incorporate physical laws, i.e., disease transmission dynamics, to regularize model predictions and overcome the limitations of prior models. These regularizations are applied over a time range so that the model can predict long-term pandemic progression. In addition, they are applied to the extracted temporal and spatial feature embeddings of locations as an extra loss term, so they introduce no extra hyper-parameters and are less likely to cause exploding or vanishing gradients.

In this paper, we develop STAN to predict the number of COVID-19 positive cases for a fixed number of days into the future, at the county or state level across the USA. STAN takes the following input data: the county-level historical number of positive cases, county-level population statistics, and relevant medical codes extracted from medical claims data. Our goal is to better predict the number of cases by using the rich information captured by these different data modalities. Throughout the paper, we use $N$ to denote the number of spatial locations (counties) and $X$ to denote the feature matrix of size $N \times (F_s + F_d)$, where $F_s$ is the number of "static" features per county and $F_d$ is the number of "dynamic" features per county. $T$ denotes the total number of time steps (i.e., days) for which we have data. Finally, we are interested in predicting $y(t)$, the number of infected patients at the $t$-th time step for all locations.

We construct the location graph with location-wise dynamic and static features as nodes and geographic proximity as edges. The graph is fed into graph convolutional networks with an attention mechanism to extract spatial features and learn the graph embedding for the target location. The graph embedding is then fed into a GRU to extract temporal relationships, and the hidden states of the GRU are used to predict the future numbers of infected and recovered cases.
We use an additional physical loss based on pandemic transmission dynamics to optimize the model. As shown in Figure 1, STAN consists of the following components: 1) a graph neural network that captures geographic trends in disease transmission; 2) an RNN that captures temporal disease patterns at each location; and 3) a short-term prediction loss and a long-term physical law constraint loss that together regularize the learned hidden representations of the node embeddings. We describe each of these components below.

To capture spatio-temporal epidemic/pandemic dynamics, we represent the input data as a 3D tensor whose three dimensions are location (e.g., states or counties), time stamp (e.g., days or weeks), and the features (both static and dynamic) associated with each location. We also consider the geographic proximity and demographic similarity between locations.

Graph nodes: We construct an attributed graph $G = (V, \mathcal{E})$ to represent the input data. Each location is modeled as a graph node and is associated with a feature matrix that contains both static and dynamic features across all time stamps for that location. In total we have 3007 nodes, one for each county in the United States.

Graph edges: The edges are constructed based on geographical proximity and the populations of the nodes (i.e., locations). Specifically, we assign the edge between nodes $i$ and $j$ a weight $w_{ij}$ that depends on $p_i$ and $p_j$, the populations of the two nodes, and on $d_{ij}$, the geographical distance between them, and is controlled by three hyperparameters. This design is based on the idea that disease transmission patterns depend strongly on crowd mobility: if there is a high mobility rate between a pair of nodes, the nodes can be expected to have similar disease spread parameters, so our process assigns a large weight to the edge between them. Note that the distance $d_{ij}$ can incorporate any notion of distance, including air traffic.
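The edge construction can be sketched in Python as follows. Because the exact weighting formula is not reproduced in this text, the sketch assumes a hypothetical gravity-style weight $w_{ij} = (p_i p_j)^{\alpha} \exp(-d_{ij}/r)$, which grows with the populations and decays with distance. The exponent `alpha`, the distance scale `r`, the function names and the `threshold` value are illustrative assumptions, not the authors' settings. The sketch also keeps only edges whose weight reaches a threshold.

```python
import math

def edge_weight(p_i, p_j, d_ij, alpha=0.5, r=100.0):
    """Hypothetical gravity-style weight (an assumption, not the paper's formula).

    Larger populations and shorter distances suggest more mobility between
    two locations, so they yield a larger weight.
    """
    return (p_i * p_j) ** alpha * math.exp(-d_ij / r)

def build_edges(populations, distances, threshold, alpha=0.5, r=100.0):
    """Return {(i, j): w_ij} for i < j, keeping only weights >= threshold.

    populations: list of node populations p_i
    distances:   symmetric matrix (list of lists) of distances d_ij
    """
    n = len(populations)
    edges = {}
    for i in range(n):
        for j in range(i + 1, n):
            w = edge_weight(populations[i], populations[j], distances[i][j], alpha, r)
            if w >= threshold:  # discard weak edges
                edges[(i, j)] = w
    return edges

# Three toy locations: 0 and 1 are close and populous, 2 is small and far away.
pops = [1000, 4000, 100]
dist = [[0, 10, 300],
        [10, 0, 310],
        [300, 310, 0]]
edges = build_edges(pops, dist, threshold=100.0)
print(sorted(edges))  # [(0, 1)]
```

Only the pair of nearby, populous locations survives the threshold; in a real pipeline the hyperparameters and the threshold would be tuned rather than fixed as here.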
Finally, we threshold the weights to retain only edges with significant weights.

Node features (static): Each node $i$ has an associated static feature vector of size 4, consisting of latitude, longitude, population and population density.

Node features (dynamic): Each node $i$ also has a set of dynamic features in the form of a matrix. The dynamic features include the number of active cases, total cases, the current number of hospitalizations due to COVID-19, and the counts of each of the 48 COVID-19 related diagnosis and procedure codes extracted from claims data according to the Centers for Disease Control and Prevention guideline (https://www.cdc.gov/nchs/data/icd/COVID-19-guidelines-final.pdf). We list the specific diagnosis and procedure codes in the dataset description.

Capturing complex spatial dependencies is a key problem in pandemic prediction. By exploiting spatial similarity, the model can make more accurate predictions for a target location by considering the disease transmission status of similar locations. Here we employ a Graph Convolutional Network (GCN) to extract such spatial features. For each location, the GCN captures the topological relationship between that location and its similar locations. Concretely, we use a two-layer GCN to extract spatial features from the graph constructed above, built from both the latest data and part of the historical data within a sliding window. Mathematically, denote the set of node features as $X \in \mathbb{R}^{N \times T \times F}$, where $N$, $T$ and $F$ are the number of nodes in the graph, the number of time stamps and the number of features at each time stamp; $w$ denotes the length of the input sliding window, and $F_d$ is the number of dynamic features per node. A two-layer GCN can be expressed as:

$$f(X, A) = \sigma\left(\hat{A}\, \mathrm{ReLU}\left(\hat{A} X W_0\right) W_1\right)$$

where $A$ is the adjacency matrix of the graph $G$, $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ is the normalized adjacency matrix produced by a preprocessing step, $\tilde{A} = A + I_N$ is the adjacency matrix with self-connections, $\tilde{D}$ is the degree matrix with $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$, and $W_0$ and $W_1$ denote the weight matrices of the first and second layers.

Furthermore, we account for the real-world scenario in which neighboring locations have different impacts on the infection status of the target location. For example, a city with a large population and increasing infected cases may have a larger impact on its neighboring cities. To model this, we apply a graph attention mechanism (GAT) on the GCN layers. GAT learns the hidden embedding of each node by using node features to compute pairwise similarity scores:

$$e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^\top \left[W z_i \,\|\, W z_j\right]\right)$$

where $z_i$ is the GCN embedding of node $i$, $\|$ denotes concatenation, $\mathbf{a}$ is a learnable attention vector, and the weight matrix $W$ casts the input into another feature space of dimension $F'$. The attention coefficient is then calculated as:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}$$

where $\mathcal{N}_i$ is the set of neighbors of node $i$. Following a self-attention strategy [14], we use $K$ heads to compute independent attention mechanisms and then sum all heads to obtain the final representation $z_i'$ of the $i$-th node:

$$z_i' = \sigma\left(\sum_{k=1}^{K} \sum_{j \in \mathcal{N}_i} \alpha_{ij}^{k} W^{k} z_j\right)$$

where $W^k$ denotes the weight matrix of the $k$-th head. The resulting node embedding $z_i'$ contains the spatial features extracted from the graph.

We also want to use historical temporal patterns to better predict future trends. Concretely, we feed the node embeddings into a Gated Recurrent Unit (GRU) [15] network to learn temporal features. Since we build a separate model for each location, we use max-pooling to integrate the embeddings of all nodes and reduce the embedding dimension (all following equations are for one specific location, and we omit the location index to reduce clutter):

$$g_t = \mathrm{MaxPool}\left(z_1', z_2', \ldots, z_N'\right)$$

We then calculate the GRU's hidden representation as:

$$\mathbf{h}_t = \mathrm{GRU}\left(g_t, \mathbf{h}_{t-1}\right)$$

The resulting $\mathbf{h}_t$ can be regarded as the final embedding at the $t$-th time stamp for the specific location, containing the important spatial and temporal features learned from real-world data.

Our objective is to predict the future number of infected cases accurately over both the long term and the short term. We address this with a multi-task learning framework that jointly considers short-term and long-term prediction performance.
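The GCN normalization, a two-layer propagation, the GAT attention softmax and the max-pooling step can be sketched in plain Python (lists of lists instead of tensors, to stay dependency-free). This is an illustrative sketch rather than the STAN implementation: the helper names are hypothetical, the outer activation $\sigma$ is left out of `two_layer_gcn`, and the LeakyReLU slope of 0.2 is an assumed value (the default used in the original GAT paper, not stated in this text).

```python
import math

def normalized_adjacency(adj):
    """A_hat = D~^{-1/2} (A + I_N) D~^{-1/2} for a dense adjacency matrix."""
    n = len(adj)
    # A~ = A + I_N: add a self-connection to every node.
    a_tilde = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # D~_ii = sum_j A~_ij; scale rows and columns by D~^{-1/2}.
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_tilde]
    return [[inv_sqrt_deg[i] * a_tilde[i][j] * inv_sqrt_deg[j] for j in range(n)]
            for i in range(n)]

def matmul(a, b):
    """Plain matrix product of two list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(m):
    return [[max(0.0, x) for x in row] for row in m]

def two_layer_gcn(adj, x, w0, w1):
    """A_hat ReLU(A_hat X W0) W1, with the outer activation sigma omitted."""
    a_hat = normalized_adjacency(adj)
    hidden = relu(matmul(matmul(a_hat, x), w0))
    return matmul(matmul(a_hat, hidden), w1)

def leaky_relu(v, slope=0.2):
    return v if v > 0 else slope * v

def attention_coefficients(wz_i, wz_neighbors, a):
    """alpha_ij = softmax_j(LeakyReLU(a^T [W z_i || W z_j])) over the neighbours of i.

    wz_i and each entry of wz_neighbors are already-projected vectors W z.
    """
    scores = [leaky_relu(sum(ak * v for ak, v in zip(a, wz_i + wz_j)))
              for wz_j in wz_neighbors]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def max_pool(node_embeddings):
    """Element-wise max over all node embeddings, giving one vector g_t."""
    return [max(column) for column in zip(*node_embeddings)]
```

For a two-node graph with one edge, every entry of the normalized adjacency is 0.5 (up to rounding), so each node's output averages itself and its neighbor; a deep learning framework would replace these loops with batched tensor operations.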
The idea is to use both a short-term prediction loss and a long-term physical law constraint loss to regularize the learned hidden representations $\mathbf{h}_t$. To achieve this, the model output consists of two tasks:

1. Transmission/recovery rate. The traditional SIR-based model simply assumes that the disease transmission rate $\beta$ and recovery rate $\gamma$ do not change over time. In practice, however, they may change because of policies or disease evolution. To address this, we define a prediction window $T_p$ within which $\beta$ and $\gamma$ are assumed constant. The prediction labels are therefore segmented into $T/T_p$ parts, and "a time stamp" refers to a prediction window. At each time stamp, the model predicts $\beta_t$ and $\gamma_t$ for the next prediction window as:

$$\beta_t, \gamma_t = \mathrm{sigmoid}\left(\mathrm{MLP}(\mathbf{h}_t)\right)$$

where $\mathrm{MLP}(\cdot)$ denotes a multi-layer perceptron; we use the sigmoid activation because both $\beta$ and $\gamma$ lie between 0 and 1.

2. Number of infected/recovered cases. At each time stamp, the model predicts the increments of the numbers of infected and recovered cases, $\Delta \hat{\mathbf{I}}_t$ and $\Delta \hat{\mathbf{R}}_t$, as:

$$\Delta \hat{\mathbf{I}}_t, \Delta \hat{\mathbf{R}}_t = \mathrm{MLP}(\mathbf{h}_t)$$

Note that $\Delta \hat{\mathbf{I}}_t$ and $\Delta \hat{\mathbf{R}}_t$ are vectors, since we predict for $T_p$ days. The final predicted numbers of infected and recovered cases are then calculated as:

$$\hat{\mathbf{I}}_t = I_{t-1} + \overrightarrow{\mathrm{cumsum}}\left(\Delta \hat{\mathbf{I}}_t\right), \qquad \hat{\mathbf{R}}_t = R_{t-1} + \overrightarrow{\mathrm{cumsum}}\left(\Delta \hat{\mathbf{R}}_t\right)$$

where $\overrightarrow{\mathrm{cumsum}}$ denotes the cumulative sum from left to right, and $I_{t-1}$ and $R_{t-1}$ denote the actual numbers of infected and recovered cases on the day before the current prediction window.

The loss function also consists of two parts. The first is a physical law constraint loss that regularizes long-term prediction trends. From the predicted transmission and recovery rates and the SIR differential equations, we calculate the physical law-based increments of infected and recovered cases as:

$$\Delta I^{p}_{t_0} = \beta_t \frac{S_{t_0-1} I_{t_0-1}}{P} - \gamma_t I_{t_0-1}, \qquad \Delta R^{p}_{t_0} = \gamma_t I_{t_0-1}, \qquad S_{t_0-1} = P - I_{t_0-1} - R_{t_0-1}$$

where $t_0$ denotes the first day of the current prediction window and $P$ denotes the population of the current location. After obtaining the first increments $\Delta I^{p}_{t_0}$ and $\Delta R^{p}_{t_0}$, the increments for the following days can be calculated iteratively.
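The iterative SIR calculation, together with the cumulative-sum reconstruction of the predicted counts, can be sketched as follows. This is a minimal plain-Python sketch under stated assumptions: `beta` and `gamma` are held fixed within one prediction window, the susceptible count is taken as $S = P - I - R$, the rollout starts from the observed counts on the day before the window, and all function names are hypothetical. A plain MSE helper is included because the physical constraint loss compares these rolled-out counts with the ground truth.

```python
from itertools import accumulate

def reconstruct_counts(last_observed, increments):
    """I_hat = I_{t-1} + cumsum(delta_I_hat), applied left to right (same for R)."""
    return [last_observed + c for c in accumulate(increments)]

def sir_rollout(beta, gamma, i_prev, r_prev, population, window):
    """Discrete SIR rollout over one prediction window.

    beta, gamma:    transmission and recovery rates, constant within the window
    i_prev, r_prev: observed infected/recovered counts on the day before the window
    population:     population P of the location
    Returns the physics-based infected and recovered counts for each day.
    """
    infected, recovered = float(i_prev), float(r_prev)
    i_path, r_path = [], []
    for _ in range(window):
        susceptible = population - infected - recovered
        d_infected = beta * susceptible * infected / population - gamma * infected
        d_recovered = gamma * infected
        infected += d_infected
        recovered += d_recovered
        i_path.append(infected)
        r_path.append(recovered)
    return i_path, r_path

def mse(pred, truth):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, truth)) / len(truth)
```

For example, `sir_rollout(0.5, 0.1, 10, 0, 1000, 1)` gives one day with $S = 990$, $\Delta I = 0.5 \cdot 990 \cdot 10 / 1000 - 1 = 3.95$ and $\Delta R = 1$, so the infected count moves from 10 to 13.95. Because $\beta$ and $\gamma$ come from the network in STAN, the same arithmetic would be written with differentiable tensor operations so that the loss can be backpropagated.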
We can then calculate the physics-based numbers of infected and recovered cases, $\hat{\mathbf{I}}^{p}_t$ and $\hat{\mathbf{R}}^{p}_t$, for the entire prediction window. Finally, the physical constraint loss is calculated as:

$$\mathcal{L}_{p} = \frac{1}{T_p} \sum_{\tau=t_0}^{t_0+T_p-1} \left[\left(\hat{I}^{p}_{\tau} - I_{\tau}\right)^2 + \left(\hat{R}^{p}_{\tau} - R_{\tau}\right)^2\right]$$

where $I_\tau$ and $R_\tau$ denote the ground-truth numbers of infected and recovered cases. This term is the mean squared error of the physical law-based predictions, which keeps the predictions in line with the long-term trend of the pandemic.

The second part is the prediction loss, a standard mean squared error loss for the second task:

$$\mathcal{L}_{d} = \frac{1}{T_p} \sum_{\tau=t_0}^{t_0+T_p-1} \left[\left(\hat{I}_{\tau} - I_{\tau}\right)^2 + \left(\hat{R}_{\tau} - R_{\tau}\right)^2\right]$$

This term keeps the predictions as close as possible to the short-term variation. Combining the two terms, the final loss function is:

$$\mathcal{L} = \mathcal{L}_{d} + \mathcal{L}_{p}$$

In this paper, we used a US county-level dataset consisting of COVID-19 related data from two sources: the Johns Hopkins University (JHU) Coronavirus Resource Center and IQVIA's claims data. The JHU Coronavirus Resource Center data were collected starting Mar 22, 2020, and contain the numbers of active cases, confirmed cases and deaths related to COVID-19 for different locations in the US. To ensure data accuracy, we selected states with more than 1000 confirmed cases by May 17, which yielded 45 states and 193 counties. For these counties, we set the number of cases before their respective first record dates to zero. The IQVIA claims data are from the IQVIA US9 Database. We collected patient claim and prescription data starting Mar 22, 2020, from which we obtained the number of hospital visits per county per day as well as the term frequency of each medical code outlined in Table 1. The dataset has records for a total of 453,089 patients across the entire timespan of the JHU dataset. There are 48 unique ICD-10 codes related to COVID-19 that were claimed from the set of codes considered (the detailed table is shown in the Supplementary Material).

We compare STAN with the following baselines. 1.
SIR: the susceptible-infected-removed (SIR) model, a basic disease transmission model that uses differential equations to simulate an epidemic. S, I and R represent the numbers of susceptible, infected, and recovered individuals. 2. SEIR: the susceptible-exposed-infected-removed (SEIR) epidemiological model, another physical constraint-based baseline. Compared to the SIR model, SEIR adds an exposed population to the equations. 3. GRU: We input the latest number of infected cases into a naïve GRU and predict future numbers. 4. ColaGNN [8]: ColaGNN uses a location graph to extract spatial relationships for predicting pandemics. Unlike STAN's graph nodes, ColaGNN's graph nodes consist only of time series of infected case counts. 5. CovidGNN [9]: CovidGNN uses a graph neural network with skip connections to predict pandemics. It uses the graph embedding to directly predict the future number of cases without using an RNN to extract temporal relationships.

To explore the performance gains from the physical constraints and graph structures, we also compare STAN with the following reduced models. 1. STAN-PC removes the physical constraints from STAN.

The implementation details of all models are given in the Supplementary Material. Our code is available in a public repository (https://github.com/v1xerunt/STAN).

We predict the future number of active cases at both the county level and the state level. To evaluate STAN on both long-term and short-term prediction, we set the prediction window $T_p$ to 5, 15 and 20, i.e., we predict for the next 5, 15 and 20 days. All training sets start from Mar 22 and all test sets start from May 17. We also split 8 days from the training sets as validation sets to determine model hyper-parameters. We use the mean squared error (MSE) and mean absolute error (MAE) to evaluate our model, and we also use the average concordance correlation coefficient (CCC) to evaluate the results.
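The three evaluation metrics can be sketched in plain Python as follows. The CCC follows Lin's standard definition, $\rho_c = 2\,\mathrm{cov}(x, y) / (\sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2)$, computed with population (1/n) moments. Averaging the per-location CCC with a simple mean is an assumption about how the "average CCC" is formed, and the function names are hypothetical.

```python
def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def ccc(x, y):
    """Lin's concordance correlation coefficient; 2*rho*sx*sy equals 2*cov(x, y)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mean_x) ** 2 for a in x) / n
    var_y = sum((b - mean_y) ** 2 for b in y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y)) / n
    return 2 * cov / (var_x + var_y + (mean_x - mean_y) ** 2)

def average_ccc(per_location_true, per_location_pred):
    """Simple mean of the CCC computed separately for each location (assumed)."""
    scores = [ccc(t, p) for t, p in zip(per_location_true, per_location_pred)]
    return sum(scores) / len(scores)
```

For instance, `ccc([1, 2, 3], [2, 3, 4])` evaluates to 4/7 (about 0.571): the two series are perfectly correlated, but the constant offset is penalized through the $(\mu_x - \mu_y)^2$ term, which is what distinguishes CCC from the plain correlation coefficient.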
The CCC measures the agreement between two variables and is computed as:

$$\rho_c = \frac{2 \rho \sigma_x \sigma_y}{\sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2}$$

where $\mu_x$ and $\mu_y$ are the means of the two variables, $\sigma_x^2$ and $\sigma_y^2$ are their variances, and $\rho$ is the correlation coefficient between the two variables. Note that we do not use the coefficient of determination ($R^2$) because its range is $(-\infty, 1]$, so a few extreme values may significantly affect the average. In contrast, the CCC ranges from -1 to 1, which allows a more reasonable evaluation of model results.

The results show that STAN makes more accurate long-term and short-term predictions than the SIR and SEIR models at both the state and county levels. Since county-level graph data are more granular, STAN benefits more from such data than the traditional dynamics-based models do. It is also worth noting that both reduced models, STAN-PC and STAN-Graph, also outperform the other baselines. This indicates that both the physical constraints and the real-world evidence provide valuable information for predicting pandemic progression. We report the detailed performance for each location in the Supplementary Material.

In this section, we discuss the advantages and limitations of our model. We plot the predicted curves for the 20 days from May 16 to Jun 5 for two counties, El Paso, TX and Lake, IN. As shown in Table 3, for these two counties STAN achieves up to 99% lower MSE than the SEIR and SIR models. As shown in Figure 2, its predicted curves also fit the actual trends better for both counties. One obvious drawback of the SIR and SEIR models is overfitting: they tend to predict that the peak will occur immediately after the most recent data, which is especially obvious in the prediction curve for Lake. This is because these traditional models do not incorporate the influence and interdependency of transmission between geographic regions.
The transmission characteristics of communicable diseases in one region are unlikely to be decoupled from those of nearby regions unless there are barriers to interaction between the regions, such as topography (rivers with limited bridges or mountain ranges with limited road connections) or controlled borders. Such decoupling is infrequent between counties in the USA. Failing to account for this geographic interdependency omits an important variable from the SIR and SEIR models and impedes their ability to predict future progression from limited data at the early stage of a pandemic.

Although deep learning-based methods can outperform traditional statistical methods in various time series analysis and prediction tasks, our work still has two major limitations. The first concerns the prediction window setting in our method. The traditional SIR and SEIR models take data from all time steps as input, while our model divides historical data into prediction windows. Although this setting provides more flexibility to learn time-varying transmission and recovery rates, it makes the model sensitive to data quality within each window. If the number of cases fluctuates drastically due to inaccuracies in data collection, or if the pandemic is temporarily brought under control, it is difficult for the model to learn valid transmission and recovery rates. This issue could be mitigated by applying dynamic data smoothing to obtain a smoother curve. The second limitation is that the physical constraints may be too simple to reflect real-world conditions such as home isolation and pandemic control policies. Many studies have improved the traditional SIR model by adding more population groups and transmission equations, and in future work we can easily extend our physical constraints in the same way.
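As one illustration of the data smoothing mentioned above, a trailing moving average over a daily case series can be sketched as follows. This is only an assumed, simple choice (the text does not specify a smoothing method), and the function name and default window length are hypothetical.

```python
def trailing_moving_average(series, window=7):
    """Smooth a daily series by averaging each value with up to window-1 previous values.

    The first few points average over however many values are available,
    so the output has the same length as the input.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running = 0.0
    for i, value in enumerate(series):
        running += value
        if i >= window:
            running -= series[i - window]  # drop the value that left the window
        smoothed.append(running / min(i + 1, window))
    return smoothed

print(trailing_moving_average([0, 10, 0, 10], window=2))  # [0.0, 5.0, 5.0, 5.0]
```

A trailing (rather than centered) window uses only past values, so the same smoothing could be applied at prediction time without looking ahead into the data being forecast.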
In this work, we propose a spatio-temporal attention network (STAN) for COVID-19 pandemic prediction. We map locations (e.g., a county or a state) to nodes on a graph. We construct the nodes from a set of static and dynamic features extracted from multiple sources of real-world evidence, including real-world medical claims data, and we construct the edges from geographical proximity and demographic similarity between locations. We use a graph convolutional network with an attention mechanism to incorporate the varying influence of a node's neighboring locations and predict the number of infected patients for a fixed period into the future. We also impose physical constraints on the predictions according to transmission dynamics. STAN achieves better prediction performance than the traditional SIR and SEIR models and shows less overfitting at the early stage of the pandemic. We hope our model can help governments and researchers better allocate medical resources and make policies to control the pandemic earlier. In future work, our model can also be easily extended to predict COVID-19 hospitalizations.

Junyi Gao and Rakshith Sharma implemented the method and conducted the experiments. All authors were involved in developing the ideas and writing the paper.

The IQVIA claims data are from the IQVIA US9 Database. We collected patient claim and prescription data starting Mar. 22, 2020, from which we obtained the number of hospital visits per county per day as well as the term frequency of each medical code. The dataset has records for a total of 453,089 patients across the entire timespan of the JHU dataset. There are a total of 48 unique ICD-10 codes related to COVID-19 that were claimed from the set of codes considered, as shown in Table 4. All methods are implemented in PyTorch 1.1 and trained on a server equipped with an Intel Xeon E5-2620 Octa-Core CPU, 256GB of memory and a Titan V GPU.
For the hyper-parameters of the baseline models, we follow the recommended settings when they are available in the original papers; otherwise, we determine their values by grid search on the validation set. For STAN, the hidden dimension of the GRU is set to 200 and the hidden dimension of the MLP is set to 100. The graph embedding dimension is set to 400 and the graph attention dimension is set to 650. The input sliding window $w$ is set to 6. For the GRU model, the hidden dimension is set to 100. For ColaGNN, the hidden dimension of the GRU is set to 256 and the dimension of the graph node embedding is set to 500. For CovidGNN, we use a two-layer GNN and the dimension of the graph node embedding is set to 256.

We report the prediction MSE for each state and county in Table 5 and Table 6. Due to space limits, we compare only against SIR and SEIR. For all 45 states, when $T_p = 5$, STAN achieves the best performance on 37 states; when $T_p = 15$ and $T_p = 20$, STAN achieves the best performance on 35 states. For all 193 counties, when $T_p = 15$, STAN achieves the best performance on 148 counties; when $T_p = 20$, STAN achieves the best performance on 143 counties. In county-level prediction, STAN achieves better long-term predictions than SIR and SEIR at most locations. This is because the county-level location graph is more granular, so the model can extract detailed spatial interactions between nodes. In addition, at some locations the outbreak had not yet started, so STAN can better predict future progression by considering progression at neighboring locations. When the data are aggregated for state-level prediction, STAN's performance is more consistent across all prediction window lengths.

[1] COVID-19 pandemic data.
[2] Initial Simulation of SARS-CoV2 Spread and Intervention Effects in the Continental US.
[3] Modified SEIR and AI prediction of the epidemics trend of COVID-19 in China under public health interventions.
[4] Time series analysis by state space methods.
[5] DEFSI: Deep learning based epidemic forecasting with synthetic information.
[6] Multi-step-prediction of chaotic time series based on co-evolutionary recurrent neural network.
[7] Prediction of chaotic time series of RBF neural network based on particle swarm optimization. Intelligent Data Analysis and its Applications.
[8] Graph message passing with cross-location attentions for long-term ILI prediction.
[9] Examining COVID-19 Forecasting using Spatio-Temporal Graph Neural Networks.
[10] Enforcing statistical constraints in generative adversarial networks for modeling chaotic dynamical systems.
[11] Enforcing analytic constraints in neural networks emulating physical systems.
[12] Differentiable physics-informed graph networks.
[13] Physics-aware Difference Graph Networks for Sparsely-Observed Dynamics.
[14] Attention is all you need. Advances in Neural Information Processing Systems.
[15] Empirical evaluation of gated recurrent neural networks on sequence modeling.

The authors have no competing interests to declare.