key: cord-0213441-9u132ex7 authors: Gupta, Garima; Vig, Lovekesh; Shroff, Gautam title: DRTCI: Learning Disentangled Representations for Temporal Causal Inference date: 2022-01-20 journal: nan DOI: nan sha: f543589a29a11aab66528be60259a25f9aeb3d29 doc_id: 213441 cord_uid: 9u132ex7

Medical professionals evaluating alternative treatment plans for a patient often encounter time-varying confounders: covariates that affect both future treatment assignment and the patient outcome. The recently proposed Counterfactual Recurrent Network (CRN) accounts for time-varying confounders by using adversarial training to balance recurrent historical representations of patient data. However, this work assumes that all time-varying covariates are confounding and thus attempts to balance the full state representation. Given that the actual subset of confounding covariates is in general unknown, recent work on counterfactual evaluation in the static, non-temporal setting has suggested disentangling the covariate representation into separate factors, each of which influences treatment selection, patient outcome, or both. This disentanglement helps isolate selection bias and restricts balancing efforts to the factors that influence outcome, leaving the remaining treatment-only factors unbalanced.

Critical medical decisions pertaining to treatment plan selection require reasoning about potential future outcomes given a patient's current state and observed longitudinal clinical data. Standard randomized control trials, the most reliable technique for evaluating different treatment options, are often impractical in the temporal setting, as it may not be feasible to conduct a trial for every temporal variation of a particular treatment plan. There is thus an acute need to estimate the effects of different future treatment sequences over time from observational patient data.
However, a major hurdle in this endeavour is the presence of time-varying confounders: patient covariates that are causally influenced by past treatments and also influence future treatments and outcomes (Cole et al., 2010). Consider a Covid-19 patient having low fever (a time-varying confounder) in the early stages of the infection. The patient is administered steroids, which spike the fever and prompt administration of anti-virals, but the patient does not survive. Without accounting for the time-varying confounder (fever), one may reach the incorrect conclusion that anti-virals are harmful to Covid-19 patients.

A notable and closely related attempt to address the problem of time-varying confounders for evaluation of treatment plans is the recently proposed Counterfactual Recurrent Network (CRN) (Bica et al., 2020), which uses a GRU-based recurrent model to predict potential future treatment effects by learning hidden representations Φ(·) that balance out time-varying confounding. However, these balanced representations curtail the representation of covariates which influence both future treatment assignment and outcome prediction. To mitigate this, we take inspiration from the work in (Hassanpour & Greiner, 2019) (Figure 1) and disentangle the temporal representation into the factors that exclusively determine future treatment (Φ_Γ(·)), response (Φ_Υ(·)), or both treatment and response (Φ_Δ(·)). Extending the work by (Negar & Russell, 2019) to learn disentangled, temporal representations, we propose Disentangled Representations for Temporal Causal Inference (DRTCI), a novel sequence-to-sequence technique for evaluating potential treatment plans.

Figure 1. Disentanglement of static covariates (x) into factors Φ_Γ(x), Φ_Υ(x) and Φ_Δ(x) determining treatment (A), outcome (Y), or both, respectively (Hassanpour & Greiner, 2019).

arXiv:2201.08137v1 [cs.LG] 20 Jan 2022
DRTCI's novelty lies in the disentanglement of the latent temporal representation into three factors which influence treatment assignment, response, or both (confounding). The resulting factorization prevents needless balancing of factors that solely influence treatment assignment and yields improved response prediction over full-state balancing (Bica et al., 2020), especially when the degree of confounding is high. The disentanglement is more effective for predicting future treatment outcomes because only the relevant factors Φ_Δ(·) and Φ_Υ(·) of the representation are used to predict the outcome; the selection bias is isolated to the factors Φ_Γ(·) and Φ_Δ(·), while balancing via an adversarial loss is restricted to Φ_Υ(·). One could argue for employing balancing on all factors; however, since balancing is time dependent, the assumption that the temporal variation of treatment and that of non-confounders remain correlated might not hold. We therefore apply balancing selectively, in contrast to the CRN, which balances the entire representation, including the Φ_Γ(·) factor that has no influence on the future outcome and, more crucially, the Φ_Δ(·) factor that represents the confounding between treatment and future outcome y. Instead, we demonstrate improved prediction by retaining the confounding factors Φ_Δ(·) for outcome prediction and mitigating the resulting selection bias via importance weights that are learnt by conditioning treatment prediction solely on the Φ_Δ(·) factor via an independent arm of the model.
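As a concrete illustration, the three-way factorization can be sketched as three separate representation heads over a shared state. This is a minimal numpy sketch with untrained, random placeholder weights (the class and its names are illustrative, not the authors' implementation); in DRTCI each head is a two-layer ELU MLP trained with the losses described later.

```python
import numpy as np

def elu(x):
    # ELU activation, as used in the paper's representation MLPs
    return np.where(x > 0, x, np.exp(x) - 1.0)

class DisentangledHeads:
    """Split a state s_t into treatment (Gamma), outcome (Upsilon)
    and confounding (Delta) factors via three separate two-layer
    ELU MLPs. Weights here are random placeholders."""

    def __init__(self, state_dim, rep_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.heads = {
            name: (rng.normal(size=(state_dim, rep_dim)),
                   rng.normal(size=(rep_dim, rep_dim)))
            for name in ("gamma", "upsilon", "delta")
        }

    def __call__(self, s_t):
        # each factor sees the full state but has its own parameters
        return {name: elu(elu(s_t @ w1) @ w2)
                for name, (w1, w2) in self.heads.items()}

heads = DisentangledHeads(state_dim=8, rep_dim=4)
reps = heads(np.ones(8))  # three factor representations of one state
```

Because the heads share no parameters, balancing one factor (here, "upsilon") leaves the others free to encode treatment-predictive information.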
Our contributions in this paper are as follows:
1) We propose a novel recurrent model, DRTCI, to predict future treatment effects in a temporal setting with time-varying confounding.
2) We demonstrate that disentangling the recurrent hidden representations in DRTCI yields substantial improvements in the prediction of future treatment effects over the recently proposed CRN model, the current state of the art.
3) We further show that the disentanglement is especially effective in situations with extreme confounding and is also beneficial for early prediction with less historical patient data.
4) We conduct ablation studies to determine which components of DRTCI contribute most to the improved performance and provide an analysis of the same.

While the field of causal inference has made significant strides towards counterfactual estimation, solutions have largely focused on settings where the treatment and outcome depend only on a static snapshot of the covariates. These include statistical techniques that employ reweighting of the data instances to balance the source and target distributions (Austin, 2011), (Bottou et al., 2013), (Swaminathan & Joachims, 2015), which, while unbiased, are known to suffer from high variance. Recent applications of representation learning techniques (Shalit et al., 2017), (Johansson et al., 2016) have shown promise, especially for settings with high cardinality of treatments (Ankit et al., 2020) and continuous treatment dosages, but these too are not applicable to temporal treatment settings with time-varying confounding and cannot capture the effects of time-varying exposure to treatments. An additional regularization was proposed to augment representation learning by using propensity scores to ensure that data points (x) that are nearby in the data space are also nearby in the representation space (Φ(x)), i.e. by matching P(A|x) and P(A|Φ(x)), where A is the treatment (Sharma et al., 2020), (Yao et al., 2018).
However, these methods suffer when there are factors in the data that only determine treatment and do not influence the outcome, as such factors will affect the similarity matching when ideally they should be excluded prior to the matching (Negar & Russell, 2019). A recent model that is closely related to this paper is the work by (Hassanpour & Greiner, 2019), which demonstrated the benefits of learning disentangled representations for counterfactual regression (see Figure 1). The authors argue that learning a single representation that removes the selection bias completely is not ideal, as there may be confounding factors that influence not only treatment selection but also outcome. Therefore, many of the previously proposed loss functions for representation learning (Shalit et al., 2017) had to play a delicate balancing act: removing selection bias, yet retaining sufficient predictive information in the representations to make accurate counterfactual predictions. Notably, while this approach led to superior counterfactual estimation, it was again restricted to the non-temporal setting.

Recent attempts at addressing causal inference in the temporal setting advocate utilizing recurrent neural networks to model the effects of time-varying exposures. This includes the work by (Lim, 2018), which utilizes a sequence-to-sequence model to estimate the inverse probability of treatment weights (IPTW) for improving Marginal Structural Models (MSMs) (Robins et al., 2000), (Mansournia et al., 2017), and the more recent CRN model described in the previous section.

We now describe the problem of estimating the response for factual or counterfactual treatments in a longitudinal setting, and highlight the major components and loss functions employed by our proposed approach for temporal causal inference. We make the standard assumptions made by (Robins et al., 2000) to estimate treatment effects, i.e.
consistency, positivity and no hidden confounders (sequential strong ignorability) (Pearl & Mackenzie, 2018). We assume an observational dataset for N patients containing time-dependent patient covariates x^i_t and static covariates v^i for each patient i. Each patient receives one of K treatments from the available treatment options in the set {A_1, .., A_k, ..., A_K} at each timestep; the treatment received at each timestep is denoted by a one-hot vector a^i_t. A response y^i_{t+1} is observed at each timestep and is appended to the covariates x^i_{t+1}. The patient superscript i is omitted in the text later for simplicity. Let a course of τ treatments received from timestep t to t + τ − 1 be ā_{t,τ} = [a_t, a_{t+1}, .., a_{t+τ−1}], with history h_t and current covariates x_t available at timestep t. The patient history at timestep t is represented by the sequence of the patient's time-varying covariates x̄_{1,t−1} = [x_1, ..., x_{t−1}], the sequence of treatments ā_{1,t−1} = [a_1, ..., a_{t−1}] and the static covariates v, as h_t = [x̄_{1,t−1}, ā_{1,t−1}, v]. We are interested in estimating the response y_{t+τ} for a course of treatments ā_{t,τ} received at timestep t, with history h_t and covariates x_t as observants, as in Eq. 1:

ŷ_{t+τ} = E[y_{t+τ} | ā_{t,τ}, h_t, x_t]    (1)

An impediment in the estimation of y_{t+τ} from observational data is the bias introduced during treatment assignment. Covariates x_t and history h_t at timestep t causally affect the treatment assignment ā_{t,τ}, and are thus time-varying confounders. To compensate for the effect of time-varying confounders and enable unbiased causal inference in the longitudinal scenario, we propose DRTCI, a model for learning disentangled representations for temporal causal inference, described below. In DRTCI, we employ a recurrent neural network to map the observant comprising the (variable-length) patient history h_t and the covariates at timestep t to a latent, fixed-length state representation s_t:

s_t = RNN(h_t, x_t)    (2)

We propose to learn three disentangled representations of the latent state at each timestep.
These representations are: 1) the Outcome Representation (Φ_Υ(s_t)), which affects the future outcome y_{t+1}; 2) the Confounding Representation (Φ_Δ(s_t)), which affects both the future treatment assignment a_t and the outcome y_{t+1}; and 3) the Treatment Representation (Φ_Γ(s_t)), which affects the treatment assignment a_t. These disentangled representations are learned by passing the state at each timestep through three separate, two-layered ELU Multi-Layer Perceptrons (MLPs), whose parameters are learned using a set of four loss functions, namely the Prediction Loss, Treatment Loss, Imbalance Loss, and Weighting Loss.

Prediction Loss & Weighting Loss: The concatenation of the confounding representation, outcome representation and treatment at timestep t is passed to a regression arm (W_Υ) comprising MLP-ELU-MLP layers to predict the outcome ŷ_{t+1}. We use a squared error loss for learning the parameters of the confounding and outcome representations. However, since the objective is to predict counterfactual as well as factual outcomes while the observational dataset contains only factual outcomes, we use an importance weighting of the squared error to emphasize those instances that are more useful for the prediction of counterfactual outcomes. This importance weight extends the binary-treatment importance weighting introduced in (Negar & Russell, 2019) to K treatments. In Eq. 4, p(a_t) and p(A_k) are the marginal probabilities of the factual treatment a_t and of treatment option A_k respectively, computed from the observational dataset. Further, the propensity p(A_k | Φ_Δ(s_t)) is obtained by passing Φ_Δ(s_t) through W_Δ(·) (an MLP and a softmax layer), trained with the weighting loss L_{t,W}. The weighted squared error is then used as the outcome regression loss (L_{t,P}).

Treatment Loss: The treatment loss is employed for learning the logging policy that guides treatment assignment.
For this, both the confounding (Φ_Δ(s_t)) and treatment (Φ_Γ(s_t)) representations are passed through W_Γ(·) (an MLP and a softmax layer) to learn the propensity of treatment A_k as p(A_k | [Φ_Δ(s_t), Φ_Γ(s_t)]) ∀k. The treatment loss (L_{t,T}) is then used for learning the weights of the confounding and treatment representations:

L_{t,T} = − Σ_{k=1}^{K} a_t(k) log(p(A_k | [Φ_Δ(s_t), Φ_Γ(s_t)]))

Imbalance Loss: To overcome the bias due to time-varying confounders, we use an adversarial imbalance loss (L_{t,I}):

L_{t,I}(Φ_Υ) = − Σ_{k=1}^{K} a_t(k) log(p(A_k | Φ_Υ(s_t)))

which ensures that the outcome representation (Φ_Υ(s_t)) is not predictive of treatment assignment, i.e., p(Φ_Υ(s_t) | A_1) = ... = p(Φ_Υ(s_t) | A_K). We compute the propensity p(A_k | Φ_Υ(s_t)) by passing the outcome representation through a randomly initialized, fixed-weight one-layer MLP followed by a softmax. We describe the usage of these loss functions for obtaining disentangled state representations in the proposed DRTCI architecture in Section 4 and Figure 2.

In this section, we describe the architecture and procedure for predicting the outcome y_{t+τ} (Eq. 1). We use a sequence-to-sequence architecture as described in (Bica et al., 2020; Lim, 2018), with the added novelty of using disentangled state representations in the longitudinal setting. The encoder and decoder networks of the sequence-to-sequence architecture are discussed in the text that follows. The encoder network uses a recurrent network with an LSTM unit to process the observants comprising the history h_t and current covariates x_t to obtain the state representation s_t (Eq. 2). The state representation is then disentangled into an outcome representation (Φ^E_Υ(s_t)), a confounding representation (Φ^E_Δ(s_t)) and a treatment representation (Φ^E_Γ(s_t)) using a combined loss at timestep t (Eq. 6), a weighted sum of the above losses in which the imbalance term is scaled by β, where β is obtained using hyperparameter tuning as described in Section 5.3 and the superscript E denotes parameters pertaining to the encoder.
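The weighted outcome loss and the adversarial imbalance loss can be sketched numerically as follows. The imbalance loss follows the cross-entropy form given in the text; the per-instance weight shown is only an illustrative marginal-over-propensity ratio in the spirit of (Negar & Russell, 2019), since the paper's exact Eq. 4 is not reproduced here.

```python
import numpy as np

def importance_weight(p_marginal, p_propensity, k_factual):
    # illustrative marginal / propensity ratio for the factual
    # treatment index k_factual (a stand-in for the paper's Eq. 4)
    return p_marginal[k_factual] / p_propensity[k_factual]

def weighted_outcome_loss(y_true, y_pred, w):
    # importance-weighted squared error, L_{t,P}
    return float(np.mean(w * (y_true - y_pred) ** 2))

def imbalance_loss(a_onehot, p_treat_given_upsilon):
    # L_{t,I}(Phi_Upsilon) = -sum_k a_t(k) log p(A_k | Phi_Upsilon(s_t))
    return float(-np.sum(a_onehot * np.log(p_treat_given_upsilon)))
```

When the propensity computed from the outcome representation is uniform over the K treatments, the imbalance loss equals log K regardless of the factual treatment, which is the balanced state the adversarial training drives Φ_Υ towards.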
The outcome representation, confounding representation and treatment at timestep t are then used to forecast the outcome at the next timestep, ŷ_{t+1}, as described in Eq. 3. Our larger objective is to obtain y_{t+τ} (Eq. 1), i.e., to forecast the outcome for a sequence of τ future treatments. To that end, we use a sequence-to-sequence architecture which extends the next-timestep outcome prediction of our encoder to τ-step-ahead prediction using a decoder, where an optimal encoder is learned first and a decoder is then updated as described in the text that follows. We concatenate the disentangled representations from the optimal encoder to obtain a unified representation Ψ_t = [Φ^E_Υ(s_t), Φ^E_Δ(s_t), Φ^E_Γ(s_t)], computed at each timestep for all N patients. The observational dataset is processed by splitting each patient's trajectory into shorter sequences of τ timesteps, {{Ψ_t} ∪ {y_{t+m}, a_{t+m}, y_{t+m+1}}_{m=1..τ−1} ∪ {v}} for t = 1, ..., T − τ, for training the decoder. The representation Ψ_t is used to initialize the state of the decoder recurrent network and finally to obtain a state representation for the patient m timesteps in the future, where the inputs to the LSTM unit of the decoder network are the static covariates v, the previous timestep's outcome y_{t+m} and the previous treatment a_{t+m−1}, with the superscript D pertaining to decoder parameters. The state representation (s_{t+m}) is disentangled using the same loss as in Eq. 6, computed for m timesteps ahead of the t-th timestep, for learning the decoder parameters. We then predict the outcome ŷ_{t+m+1} for m + 1 timesteps ahead of t, and proceed iteratively for all values of m up to τ − 1 timesteps ahead of t. Figure 2 shows the sequence-to-sequence network architecture for next-step-ahead and τ-step-ahead prediction. It should be noted that the true y_{t+m} is not available as input to the decoder at test time. Hence, the previous timestep's outcome predicted by the decoder, ŷ_{t+m}, is auto-regressively used as input to the decoder's recurrent network (Eq.
7) for predicting ŷ_{t+m+1} (refer to Figure 2).

Figure 2. Architecture of DRTCI: the encoder builds disentangled representations that are unified as Ψ_t to initialize the decoder, which is updated iteratively to predict the factual/counterfactual outcome of a sequence of treatments.

In this section, we discuss the experimental set-up used for measuring the efficacy of the proposed model. We first describe the dataset, followed by the baselines, evaluation metrics and implementation specifics for the evaluation of DRTCI. As real-world datasets lack counterfactual outcomes, we use a bio-model (Geng et al., 2017) to simulate a lung cancer dataset which contains the outcomes of different treatment options (no treatment, chemotherapy, radiotherapy, and combined chemotherapy and radiotherapy) on tumour growth volume in a longitudinal manner. The generation model is the same as that used by (Lim, 2018; Bica et al., 2020), and we briefly describe the simulation model in Section C. We generate a dataset for a maximum of 60 timesteps and evaluate DRTCI under different degrees of time-dependent confounding (γ_c, γ_r) and time horizons (τ). For each setting of γ_c, γ_r, τ, we simulate 10000 patients for training, 1000 for validation and 1000 for testing. For testing next-step-ahead prediction using the encoder, we simulate the volume V(t + 1) at each timestep t under all treatment options for each patient, while for τ-step-ahead prediction using the decoder, we generate 2τ counterfactual sequences of outcomes, with each sequence giving chemotherapy at one of t, ..., t + τ − 1 and similarly radiotherapy at one of t, ..., t + τ − 1 timesteps, for each timestep t in the patient's trajectory.
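The autoregressive decoder rollout described in the preceding section can be sketched as a simple loop. Here step_fn and predict_fn are hypothetical stand-ins for the decoder's LSTM update and outcome regression arm; the point is that at test time the decoder consumes its own previous prediction, never the true future outcome.

```python
def decoder_rollout(psi_t, v, y_t, a_seq, step_fn, predict_fn):
    """Roll the decoder forward for len(a_seq) steps.

    psi_t initialises the decoder state (the unified encoder
    representation); a_seq is the planned treatment sequence.
    """
    state, y_prev, outputs = psi_t, y_t, []
    for a_prev in a_seq:
        # inputs: static covariates, previous (predicted) outcome,
        # previous treatment -- mirroring the decoder's LSTM inputs
        state = step_fn(state, v, y_prev, a_prev)
        y_prev = predict_fn(state)  # becomes next step's input
        outputs.append(y_prev)
    return outputs

# toy linear dynamics just to show the autoregressive wiring
preds = decoder_rollout(
    psi_t=1.0, v=0.0, y_t=1.0, a_seq=[0.0, 0.0],
    step_fn=lambda s, v, y, a: s + y + a,
    predict_fn=lambda s: 2.0 * s,
)
```

Each predicted outcome feeds the next state update, so prediction errors can compound over the τ-step horizon, which is why decoder error is reported separately per τ.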
We benchmark DRTCI against: 1) Marginal Structural Models (MSMs) (Robins et al., 2000), which use logistic regression for inverse propensity weighting and linear regression for prediction; 2) Recurrent Marginal Structural Networks (RMSNs) (Lim, 2018), which use a sequence-to-sequence architecture for handling confounders and outcome prediction; and 3) the state-of-the-art Counterfactual Recurrent Network (CRN), which introduces adversarial balancing for handling confounders and outcome prediction. We measure the performance of DRTCI in terms of the normalized root mean squared error percentage (NRMSE %), with NRMSE % = (RMSE / maximum tumour volume) * 100, where

RMSE = ( (1 / NT) Σ_i Σ_t (ŷ_{t+1} − y_{t+1})^2 )^(1/2)

for the encoder, with y_{t+1} replaced by y_{t+τ} for the decoder error computation. Implementation specifics are discussed in Section A.

In this section, we evaluate our proposed approach DRTCI and compare its performance with the baselines for varying degrees of time-varying confounding and for τ-step-ahead counterfactual predictions. We compare DRTCI for varying degrees of confounding, (a) γ_c = γ_r = 5, (b) γ_r = 5, γ_c = 0, (c) γ_r = 0, γ_c = 5, with the baseline approaches mentioned in Section 5.2. We analyze this comparison for τ-step-ahead counterfactual prediction with τ increasing in steps of 1 from 3 to 7 in Table 1. DRTCI substantially outperforms RMSN, MSM and the current state-of-the-art CRN for all values of τ and degrees of confounding (γ_r, γ_c). We further analyze the performance of DRTCI for very high degrees of confounding by setting γ_r = γ_c = γ ∈ {6, 7, 8, 9, 10} for τ = 1 (next-step-ahead prediction) and τ = 3, 5 in Table 2. The results show that DRTCI's performance under high confounding is superior to that of CRN for next-step-ahead prediction, and better by large margins for τ-step-ahead prediction, achieving more than 32% and 25% reductions in error for τ = 3 and τ = 5 respectively.
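The evaluation metric defined above is straightforward to implement; this sketch assumes predictions and targets are pooled over all N patients and T timesteps.

```python
import numpy as np

def nrmse_percent(y_pred, y_true, max_tumour_volume):
    # NRMSE % = (RMSE / maximum tumour volume) * 100, with the RMSE
    # averaged over all patients and timesteps
    diff = np.asarray(y_pred) - np.asarray(y_true)
    rmse = np.sqrt(np.mean(diff ** 2))
    return float(rmse / max_tumour_volume * 100.0)
```

With the paper's maximum tumour volume of 1150 cm^3, a constant prediction error of 11.5 corresponds to an NRMSE of exactly 1.0 %.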
We believe this is because DRTCI learns a confounding representation which plays a significant role in outcome prediction, especially when the degree of confounding is high; this is in contrast to prior approaches, where confounding covariates are balanced away to reduce bias. Moreover, while in prior approaches a shared outcome regression arm would lead to noisy interference between the treatment predictions, in DRTCI only the relevant factors are used to predict outcomes, so noisy interference is greatly reduced and common knowledge is better utilized across treatments. We also test the quality of treatment sequence prediction p_2(s_t) in Section B for varying degrees of confounding.

Table 2. NRMSE % of DRTCI for high γ (high confounding).

An extensive ablation study was carried out to identify the key factors contributing to the performance of DRTCI. We perform this study for 5-step-ahead prediction with varying degrees of confounding, γ = 3, 7, 10. We examine the following conditions in Table 3: (a) Φ_Υ: DRTCI with no disentangled representation but only Φ_Υ as the representation layer; (b) Φ_Υ ⊕ Φ_Δ: DRTCI with an outcome representation (Φ_Υ(·)), a confounding representation (Φ_Δ(·)) and no separate treatment-only representation (Φ_Γ(·)); (c) Φ_Υ ⊕ Φ_Γ: DRTCI with an outcome representation (Φ_Υ(·)), a treatment representation (Φ_Γ(·)) and no confounding representation (Φ_Δ(·)). Analyzing which representation layer contributes most to performance, we see that the outcome representation alone (column Φ_Υ) does not provide significant performance gains, while learning a treatment or confounding representation along with the outcome representation (Φ_Υ ⊕ Φ_Δ, Φ_Υ ⊕ Φ_Γ) gives a significant performance gain.
It can be safely concluded that using the proposed representations which retain the confounding and treatment factors, as in DRTCI, has the most significant impact on improving counterfactual prediction in longitudinal data, especially when the degree of confounding is high.

We also evaluate DRTCI with limited historical information for early treatment prediction, measuring its performance for next-step-ahead and 5-step-ahead prediction with varying history sizes. We depict our analysis for high confounding with γ = 8 in Figure 3. For next-step prediction, DRTCI performs marginally better than CRN with minimal history, while for 5-step prediction DRTCI performs significantly better for all history sizes. We also analyze DRTCI for different history lengths in Figure 4, using the mean of the NRMSE % values computed under high confounding. DRTCI significantly outperforms CRN for early prediction. Thus, when making distant early predictions in high-confounding cases, DRTCI is more reliable.

While many real-world applications of causal inference involve temporal observational data, the literature has largely focused on techniques for the static data setting. Recent work has harnessed the power of representation learning to mitigate selection bias during counterfactual estimation, and has demonstrated the value of disentangling representations for counterfactual estimation in the static data setting. This paper presents DRTCI, a model that extends this disentanglement to the temporal setting and demonstrates significant benefits over the current state of the art for treatment prediction in a realistic medical scenario, especially for early predictions over long time horizons in high-confounding situations. Future work involves extending this work to high-cardinality treatments and employing meta-learning techniques to reduce data requirements and enhance adaptability to novel datasets, similar to (Ankit et al., 2019).
Table 4 illustrates the search range of hyperparameters for the encoder and decoder networks in DRTCI. From this search space, we select the optimal hyperparameters based on the minimum NRMSE % in predicting the factual outcome on the validation dataset. For the evaluation of NRMSE %, the maximum tumour volume in the data is 1150 cm^3. We use a Tesla V100 with 4 GB GPU memory and 32 GB RAM for training and hyperparameter optimization in our experiments.

We test treatment sequence prediction on 1000 patients with treatments sampled as â_t ∼ p_2(s_t), evaluated using accuracy as the metric for different degrees of confounding γ (Table 5). We observe that accuracy of treatment prediction increases with γ: as the degree of confounding increases, factual treatment assignment is more biased and thus easier to predict.

Table 5. Accuracy of treatment sequence prediction.

To begin with, the initial cancer stage and initial tumour volume are obtained from prior distributions for each patient. The volume of the tumour t days after diagnosis is then obtained using the following mathematical model:

V(t + 1) = (1 + ρ log(κ / V(t)) − β_c C(t) − (α_r d(t) + β_r d(t)^2) + e(t)) V(t)    (9)

where κ, ρ, β_c, α_r, β_r, e(t) are sampled as described in (Geng et al., 2017). Time-varying confounding is introduced by modelling chemotherapy and radiotherapy assignment as Bernoulli random variables, with the probabilities of chemotherapy (p_c) and radiotherapy (p_r) depending on tumour volume: p_c(t) = σ((γ_c / D_max)(D̄(t) − δ_c)) and p_r(t) = σ((γ_r / D_max)(D̄(t) − δ_r)), where γ_c, γ_r are the factors controlling the degree of time-varying confounding, D̄(t) is the average tumour diameter over the last 15 days, D_max = 13 cm, σ(·) is the sigmoid function, and δ_r = δ_c = D_max / 2. The chemotherapy drug concentration is given by C(t) = 5 + C(t − 1)/2 and the radiotherapy dose is d(t) = 2 if applied at timestep t. The implementation of the data simulator can be obtained from the git repository.
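The tumour-growth update and the confounded treatment-assignment probability can be sketched directly from the stated equations (parameter values in the usage line are arbitrary placeholders; in the simulator they are sampled as in (Geng et al., 2017)):

```python
import math

def tumour_volume_step(V, C, d, kappa, rho, beta_c, alpha_r, beta_r, e):
    # V(t+1) = (1 + rho*log(kappa/V(t)) - beta_c*C(t)
    #           - (alpha_r*d(t) + beta_r*d(t)^2) + e(t)) * V(t)
    return (1.0 + rho * math.log(kappa / V)
            - beta_c * C
            - (alpha_r * d + beta_r * d ** 2)
            + e) * V

def assignment_probability(D_bar, gamma, D_max=13.0):
    # p(t) = sigmoid((gamma / D_max) * (D_bar(t) - delta)),
    # with delta = D_max / 2; gamma controls confounding strength
    z = (gamma / D_max) * (D_bar - D_max / 2.0)
    return 1.0 / (1.0 + math.exp(-z))
```

At the threshold diameter D_max / 2, the assignment probability is 0.5 for any γ; larger γ makes assignment more deterministic in tumour diameter, i.e. more strongly confounded, which matches the accuracy trend reported in Table 5.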
It should be noted that although the representation size of DRTCI is three times that of baselines such as CRN, we simulate data of the same size as in (Bica et al., 2020) to resemble realistic data and to ensure fair comparisons.

References:
Meta-learning for causal inference in a heterogeneous population
Deep causal inference in high dimensions
An introduction to propensity score methods for reducing the effects of confounding in observational studies
Estimating counterfactual treatment outcomes over time through adversarially balanced representations
Counterfactual reasoning and learning systems: The example of computational advertising
Illustrating bias due to conditioning on a collider
Prediction of treatment response for combined chemo- and radiation therapy for non-small cell lung cancer patients using a bio-mathematical model
Learning disentangled representations for counterfactual regression
Learning representations for counterfactual inference
Forecasting treatment responses over time using recurrent marginal structural networks
Handling time varying confounding in observational research
Counterfactual regression with importance sampling weights
The book of why: the new science of cause and effect
Marginal structural models and causal inference in epidemiology
Estimating individual treatment effect: generalization bounds and algorithms
Multimbnn: Matched and balanced causal inference with neural networks
The self-normalized estimator for counterfactual learning (Advances in Neural Information Processing Systems)
Representation learning for treatment effect estimation from observational data