MED-TEX: Transferring and Explaining Knowledge with Less Data from Pretrained Medical Imaging Models

Authors: Nguyen-Duc, Thanh; Zhao, He; Cai, Jianfei; Phung, Dinh
Date: 2020-08-06

Deep learning methods usually require a large amount of training data and lack interpretability. In this paper, we propose a novel knowledge distillation and model interpretation framework for medical image classification that jointly addresses these two issues. Specifically, to address the data-hungry issue, a small student model is learned with less data by distilling knowledge from a cumbersome pretrained teacher model. To interpret the teacher model and assist the learning of the student, an explainer module is introduced to highlight the regions of an input that are important for the predictions of the teacher model. Furthermore, the joint framework is trained in a principled way derived from an information-theoretic perspective. Our framework outperforms state-of-the-art methods on both the knowledge distillation and model interpretation tasks on a fundus dataset.

We consider a practical scenario of medical image classification applications [1], where a central hospital headquarters gathers data from multiple local branches (Fig. 1(a)). The headquarters has developed a large CNN model for disease classification, trained on a big dataset with excellent performance; this is the global model to be distributed to the branches. Given its limited computation, a branch wants to develop a customized smaller model using its local data. The branch cannot access the big dataset of the headquarters because of privacy and sensitivity concerns. To assist the development of the local model, the knowledge from the global model is transferred to the local one [2]. In the medical domain, model interpretation is highly desirable.
Therefore, the local model should have two capabilities: explaining the global model and transferring the knowledge of the global model to the local model using its local data only.

Model interpretation is defined here as the ability to identify the areas of an input image that are important to the prediction of the classifier. Neural saliency methods such as Grad-CAM [3] are used to locate the features that contribute the most to the classification output. Feature selection [4], hard attention [5] and soft attention [6] are used to generate different weights for different features. However, they are not designed for explaining a pretrained global model. The recent Learning-to-Explain (L2X) [4] trains an explainer to explain a pretrained global model by maximizing the mutual information between selected instance-wise features and the teacher outputs. L2X does not address the lack of large training data, and its effectiveness on high-resolution image classification has not been confirmed.

* thanh.nguyen4@monash.edu; † corresponding author; † Monash University

Knowledge distillation (KD) is the process of transferring knowledge from a complicated global model (called the teacher) to a smaller, lighter-weight one (called the student). The small student model can significantly reduce the deployment cost of the local branch. KD was first introduced by Hinton et al. [2] to distill knowledge from the distribution of class probabilities predicted by the teacher model. Recently, Ahn et al. [7] took an information-theoretic perspective and transferred knowledge by maximizing the mutual information between the teacher and the student, an approach named variational information distillation (VID). In the medical domain, Wang et al. [8] used KD to train a student model that speeds up the inference of a 3D neuron segmentation model. However, these previous approaches do not consider interpreting the complicated teacher model.
In this paper, we propose an end-to-end framework that addresses the above two requirements simultaneously by learning a small medical image classification model with less training data but better interpretability. Our contributions are as follows:
a) a new end-to-end MEDical Transfer and EXplain framework (MED-TEX) built on a pretrained global model, which combines knowledge distillation and pixel-level model interpretation, whereas existing methods focus on only one of them;
b) a joint training objective for our framework, derived from an information-theoretic perspective, which is both theoretically and practically appealing;
c) experimental results demonstrating that our proposed method outperforms other methods on model interpretability and knowledge distillation.

Fig. 1: (a) Problem setting: a headquarters gathers data from multiple branches to produce a shared cumbersome teacher. A branch builds a local small and interpretable model. (b) An overview of our framework: fixed pretrained teacher, learnable explainer and learnable student. The explainer explains to the student by producing a simplified $X'$ from the input $X$. The knowledge from the teacher is transferred to the student by maximizing the mutual information (MI). (c) The detailed architecture.

We denote the CNN-based global classifier (the teacher) by $T$. The student $S$ is another CNN-based classifier that can potentially be a hundred times smaller, which significantly reduces computational complexity, and that is trained with less data, available only from the local branch. Given an input image $X \in \mathbb{R}^{C \times H \times W}$ ($C$, $H$, $W$ are the channels, height and width of the image, respectively), the explainer $E$, inspired by the undercomplete autoencoder with skip connections, produces the selection scores $\Theta$, which are high for the pixels that are important for the decision of the teacher and low for the unimportant ones. In our framework, $\Theta$ has the same size as $X$ (i.e., it is defined at pixel level) and is element-wise multiplied by $X$ to obtain a simplified input image, denoted by $X'$.
This $X'$ is then input to the student $S$ to perform predictions. Our goal is to train the student to mimic the behavior of the teacher by pushing the student's outputs from the last and intermediate layers close to those of the teacher, while the explainer produces $\Theta$ to guide the student by highlighting the regions of $X$ used to generate $X'$, as illustrated in Fig. 1(c).

Proposed framework. The teacher's and student's predicted distributions over the labels are denoted $y^T \in \Delta^L$ and $y^S \in \Delta^L$, respectively, where $L$ is the number of labels and $\Delta^L$ denotes the $L$-dimensional simplex. $T$ and $S$ have $M$ and $N$ layers, respectively, where the last layer is a fully connected layer and the other layers are convolutional layers (or convolutional blocks). $N$ can be different from $M$ as in [7]; however, we simplify the formulas by setting $N = M$. We have $y^T = T(X)$ (i.e., $p(y^T_l \mid X) \propto T(X)_l$), $X' \mid X = E(X)$, and $y^S \mid X' = S(X')$ (i.e., $q(y^S_l \mid X') \propto S(X')_l$). We formulate our preliminary goals of explaining the teacher and transferring its knowledge to the student as the following loss, derived from mutual information (see Eq. (10)):

$$\mathcal{L}_{S,E} = -\mathbb{E}_{X'}\,\mathbb{E}_{y^T \mid X'}\big[\log q(y^T \mid X')\big] \tag{1}$$

where $q$ corresponds to our student, acting as the variational distribution in the derivation of mutual information. Minimizing Eq. (1) is similar to minimizing the cross-entropy loss between the outputs of the teacher and the student, with $X'$ generated by element-wise multiplication between $X$ and $\Theta$. This pushes the predictions of the student close to those of the teacher, with the help of the explainer:

$$\mathcal{L}_{\mathrm{CE}} = -\mathbb{E}_{X}\Big[\sum_{l=1}^{L} T(X)_l \,\log S(X \odot \Theta)_l\Big] \tag{2}$$

Given an input image $X$, the explainer generates an importance score for each of its pixels; its last layer is a $1 \times 1$ convolution layer with sigmoid activation. The higher the importance score, the more important the corresponding pixel is to the prediction of the teacher. All the importance scores form the importance map, denoted as $\Theta \in [0, 1]^{C \times H \times W}$. The output of the explainer can be expressed as:

$$X' = E(X) = X \odot \Theta \tag{3}$$

where $\odot$ is the element-wise multiplication.
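The masking step and the teacher-student cross-entropy described above can be illustrated with a minimal sketch in plain Python. This is not the authors' implementation: the function names `apply_mask` and `distillation_cross_entropy` are hypothetical, images are assumed to be flattened lists of pixel values, and the teacher and student outputs are assumed to be already normalized class probabilities.

```python
import math

def apply_mask(image, theta):
    """Element-wise product of an image with importance scores in [0, 1].

    Both arguments are flattened lists of equal length (an assumption made
    for brevity; a real image would be a C x H x W tensor).
    """
    if len(image) != len(theta):
        raise ValueError("image and theta must have the same size")
    return [t * x for t, x in zip(theta, image)]

def distillation_cross_entropy(teacher_probs, student_probs, eps=1e-12):
    """Cross-entropy between teacher and student class probabilities
    for a single example; eps guards against log(0)."""
    return -sum(p * math.log(q + eps)
                for p, q in zip(teacher_probs, student_probs))

# Toy example: 4 "pixels" and 2 classes (all values are made up).
x = [0.2, 0.8, 0.5, 1.0]
theta = [0.0, 1.0, 0.5, 1.0]           # hypothetical explainer output
x_masked = apply_mask(x, theta)        # simplified input for the student

y_teacher = [0.9, 0.1]
loss_close = distillation_cross_entropy(y_teacher, [0.85, 0.15])
loss_far = distillation_cross_entropy(y_teacher, [0.10, 0.90])
```

A student whose prediction is close to the teacher's (`loss_close`) incurs a smaller loss than one that disagrees (`loss_far`). In a real training loop the same quantities would be computed on batched tensors so that gradients flow to both the student and the explainer.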
Inspired by the idea of knowledge distillation in [7], we introduce an additional loss to maximize the mutual information between the outputs of each $i$-th intermediate layer of the teacher, $T^i(X)$, and of the student, $S^i(X')$. We simplify the notation by setting $i = j$ in $T^i(X)$ and $S^j(X')$, but $i \neq j$ is also possible as in [7]:

$$\mathcal{L}^{i}_{\mathrm{int}} = -\mathbb{E}_{X}\big[\log r(T^i(X) \mid S^i(X'))\big] \tag{4}$$

where $r(T^i(X) \mid S^i(X'))$ is a variational distribution used for approximating $p(T^i(X) \mid S^i(X'))$, which is derived from the information-theoretic perspective (see Eq. (12)). Recall that the output of the $i$-th layer of the teacher is a $C_i \times H_i \times W_i$ feature map (note that the output of the $i$-th layer of the student has the same spatial dimensions but a smaller number of channels). Following [7], we model $T^i(X)$ as the following Gaussian distribution conditioned on $S^i(X')$:

$$r(T^i(X) \mid S^i(X')) = \prod_{c=1}^{C_i}\prod_{h=1}^{H_i}\prod_{w=1}^{W_i} \mathcal{N}\Big(T^i_{c,h,w}(X) \;\Big|\; \mu^i_{c,h,w}\big(S^i(X')\big),\; {\sigma^i_c}^2\Big) \tag{5}$$

where $\mu^i$ is a subnetwork with $1 \times 1$ convolutional layers that matches the channel dimensions between $T^i(X)$ and $S^i(X')$, $\mu^i_{c,h,w}$ is a single output unit, and ${\sigma^i_c}^2$ is a learnable parameter specific to each channel at the $i$-th layer. For ${\sigma^i_c}^2$, we use the softplus function ${\sigma^i_c}^2 = \log(1 + e^{\alpha^i_c}) + \epsilon$, where $\alpha^i_c$ is a learnable parameter and $\epsilon$ is used for numerical stability. With Eq. (5), we can write Eq. (4) as:

$$\mathcal{L}^{i}_{\mathrm{int}} = \mathbb{E}_{X}\Bigg[\sum_{c=1}^{C_i}\sum_{h=1}^{H_i}\sum_{w=1}^{W_i} \Bigg(\log \sigma^i_c + \frac{\big(T^i_{c,h,w}(X) - \mu^i_{c,h,w}(S^i(X'))\big)^2}{2\,{\sigma^i_c}^2}\Bigg)\Bigg] + \mathrm{const} \tag{6}$$

Finally, the overall loss function of our framework can be written as

$$\mathcal{L} = \mathcal{L}_{\mathrm{CE}} + \lambda \sum_{i} \mathcal{L}^{i}_{\mathrm{int}} \tag{7}$$

where $\mathcal{L}_{\mathrm{CE}}$ is the cross-entropy loss of Eq. (2), the sum runs over the intermediate layers, and $\lambda$ is the weight of the losses of the intermediate layers.

Derivation from information-theoretic perspective. The objective function above has an intuitive interpretation. Here we additionally show that it can be derived in a principled way from mutual information, which is a widely used measure of the dependence between two random variables and captures how much knowledge of one random variable reduces the uncertainty about the other. In particular, minimizing the training losses in Eq. (2) and Eq. (4) is equivalent to maximizing the mutual information $I(X'; y^T)$ and $I(T^i(X); S^i(X'))$, respectively:

$$\max_{E,\,S}\; I(X'; y^T) + \lambda \sum_{i} I\big(T^i(X); S^i(X')\big) \tag{8}$$
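As an illustration of the Gaussian intermediate-layer loss of Eq. (4) and Eq. (5) with the softplus-parameterized variance, the following plain-Python sketch computes the negative log-likelihood for one layer, up to an additive constant. It is a sketch under stated assumptions rather than the authors' code: the names `softplus_variance` and `gaussian_layer_loss` are hypothetical, feature maps are represented as nested lists with one flattened spatial list per channel, and `student_mean` is assumed to be the output of the channel-matching subnetwork.

```python
import math

def softplus_variance(alpha, eps=1e-6):
    """Per-channel variance log(1 + exp(alpha)) + eps, always positive."""
    return math.log1p(math.exp(alpha)) + eps

def gaussian_layer_loss(teacher_feat, student_mean, alphas, eps=1e-6):
    """Gaussian negative log-likelihood of teacher features given the
    student's predicted means, summed over channels and positions.

    teacher_feat, student_mean: lists of C channels, each a flat list of
    H*W values. alphas: list of C unconstrained variance parameters.
    The constant 0.5 * log(2 * pi) per element is dropped.
    """
    total = 0.0
    for t_c, m_c, alpha in zip(teacher_feat, student_mean, alphas):
        var = softplus_variance(alpha, eps)
        for t, m in zip(t_c, m_c):
            # 0.5 * log(var) equals log(sigma)
            total += 0.5 * math.log(var) + (t - m) ** 2 / (2.0 * var)
    return total

# Toy feature maps with 2 channels and 2 spatial positions (made-up values).
teacher = [[1.0, 2.0], [0.5, -0.5]]
alphas = [0.0, 0.0]
loss_match = gaussian_layer_loss(teacher, teacher, alphas)
loss_mismatch = gaussian_layer_loss(teacher, [[0.0, 0.0], [0.0, 0.0]], alphas)
```

When the predicted means equal the teacher features, only the log-variance term remains, so `loss_match` is smaller than `loss_mismatch`. Learning `alphas` lets each channel trade off the squared error against its variance, which down-weights channels the student cannot reproduce well.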
Given the definition of mutual information, the first term of Eq. (8) can be derived as:

$$I(X'; y^T) = \mathbb{E}_{X'}\,\mathbb{E}_{y^T \mid X'}\big[\log p(y^T \mid X')\big] + H(y^T) \tag{9}$$

where the entropy $H(y^T)$ does not depend on the explainer or the student. In general, it is intractable to compute expectations under the conditional distribution $p(y^T \mid X')$. Hence, we define a variational distribution $q(y^T \mid X')$ that approximates $p(y^T \mid X')$:

$$\mathbb{E}_{y^T \mid X'}\big[\log p(y^T \mid X')\big] = \mathbb{E}_{y^T \mid X'}\big[\log q(y^T \mid X')\big] + D_{\mathrm{KL}}\big(p(y^T \mid X') \,\|\, q(y^T \mid X')\big) \;\geq\; \mathbb{E}_{y^T \mid X'}\big[\log q(y^T \mid X')\big] \tag{10}$$

where $D_{\mathrm{KL}}$ is the Kullback-Leibler divergence and equality holds if and only if $q(y^T \mid X')$ and $p(y^T \mid X')$ are equal in distribution. Note that it is not hard to show that our student corresponds to the variational distribution $q$. For the second term of Eq. (8), we have:

$$I\big(T^i(X); S^i(X')\big) = H\big(T^i(X)\big) + \mathbb{E}_{T^i(X),\,S^i(X')}\big[\log p(T^i(X) \mid S^i(X'))\big] \tag{11}$$

Given Eq. (11), we can derive the following bound, similar to Eq. (10):

$$I\big(T^i(X); S^i(X')\big) \;\geq\; H\big(T^i(X)\big) + \mathbb{E}_{T^i(X),\,S^i(X')}\big[\log r(T^i(X) \mid S^i(X'))\big] \tag{12}$$

where $r$ is the variational distribution that approximates the conditional distribution. By using the two variational distributions $q$ and $r$, the problem in Eq. (8) can be relaxed to Eq. (13), i.e., maximizing the variational lower bounds:

$$\max_{E,\,S}\; \mathbb{E}_{X'}\,\mathbb{E}_{y^T \mid X'}\big[\log q(y^T \mid X')\big] + \lambda \sum_{i} \mathbb{E}_{T^i(X),\,S^i(X')}\big[\log r(T^i(X) \mid S^i(X'))\big] \tag{13}$$

In this section, we present the experiments conducted on a real-world dataset to evaluate the performance of the proposed MED-TEX against state-of-the-art methods.

Architectures and settings of MED-TEX. For the teacher and student, we adopt a deep architecture with four convolutional blocks, where each block consists of a convolutional layer, batch normalization, max pooling and ReLU activation. Due to its smaller number of filters, the student model is much (226 times) smaller than the teacher, i.e., 1.7k parameters for the student versus 390.5k parameters for the teacher. We found empirically that the current teacher architecture works well with our fundus dataset. However, it is important to note that our framework is general enough to be applied to various teacher and student architectures. For the explainer, we adopt an autoencoder with skip connections. The last layer of the explainer is a $1 \times 1$ convolution with sigmoid activation.

Dataset. We conducted our experiments on a fundus dataset with two classes, normal and abnormal.
Finally, we have 1873 images in total, consisting of 1073 normal and 800 abnormal images. Of the abnormal images, 200 have fine-grained lesion segmentation. The dataset is split into a training set (773 normal and 500 abnormal images) and a testing set (300 normal and 300 abnormal images). All 200 images with lesion segmentation are in the testing set. We simulate the limited-data scenario by reducing the number of training images: 25% and 50% of the training images are used, denoted as Fundus-25% and Fundus-50%, respectively.

Compared methods. To our knowledge, there is no existing method that solves exactly the same problem as ours. Thus, we separately compare MED-TEX to knowledge distillation methods (KD [2] and VID [7]) and model interpretation methods (hard attention using the Gumbel-softmax trick [5], soft attention [6], Grad-CAM [3] and L2X [4]). These comparison methods use a ResNet18 backbone, which has significantly more parameters than the combination of our explainer and student. To evaluate the effectiveness of the intermediate-layer losses, we compare MED-TEX with a variant without the information transfer losses (Eq. (4)), denoted as MED-EX. The importance of the explainer is illustrated by another variant, Student (only), which has neither the explainer nor the intermediate-layer losses and is trained directly on the input image $X$. All models are trained using Adam with learning rate 0.001, $\lambda = 0.01$ and a batch size of 64.

Evaluation metrics. Two metrics are used to evaluate our framework. The post-hoc metric [4] compares the predictive distributions of the student given $X'$ and the teacher given $X$. In other words, we compute the accuracy and F1 score by comparing $y^T$ and $y^S$ for knowledge distillation evaluation. Note that these post-hoc metrics do not compare against human labels.
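The post-hoc agreement described above can be sketched as follows, treating the teacher's predicted labels as the reference and the student's as the predictions. This is only an illustration: the function name `posthoc_scores` is hypothetical, the labels are assumed to be hard class indices obtained by taking the arg max of $y^T$ and $y^S$, and the F1 score is assumed to be the binary F1 with the abnormal class (label 1) as the positive class, which the text does not specify.

```python
def posthoc_scores(teacher_labels, student_labels, positive=1):
    """Accuracy and binary F1 of student labels measured against teacher
    labels (not against human annotations)."""
    if not teacher_labels or len(teacher_labels) != len(student_labels):
        raise ValueError("label lists must be non-empty and of equal length")
    pairs = list(zip(teacher_labels, student_labels))
    accuracy = sum(t == s for t, s in pairs) / len(pairs)
    tp = sum(t == positive and s == positive for t, s in pairs)
    fp = sum(t != positive and s == positive for t, s in pairs)
    fn = sum(t == positive and s != positive for t, s in pairs)
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else 0.0
    return accuracy, f1

# Made-up predictions for six test images (1 = abnormal, 0 = normal).
teacher_pred = [1, 1, 0, 0, 1, 0]
student_pred = [1, 0, 0, 0, 1, 1]
acc, f1 = posthoc_scores(teacher_pred, student_pred)
```

In this toy case the student agrees with the teacher on four of six images, and it has two true positives, one false positive and one false negative with respect to the teacher, so both scores equal 2/3.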
Intersection over Union (IoU) compares the highlighted image regions with the ground-truth lesion segmentation of abnormal images for interpretation evaluation. For a better comparison, we rank the feature scores and select the pixels corresponding to the top $K$ highest scores. IoU is then computed as

$$\mathrm{IoU} = \frac{|\Theta_{\mathrm{top}K} \cap X_{\mathrm{lesion}}|}{|\Theta_{\mathrm{top}K} \cup X_{\mathrm{lesion}}|}$$

where $\Theta_{\mathrm{top}K}$ indicates the selected pixels corresponding to the top $K$ feature scores and $X_{\mathrm{lesion}}$ denotes the ground-truth lesion segmentation pixels.

Results. For knowledge distillation performance, shown in Table 1, our method consistently outperforms other methods on the post-hoc metric [4] in terms of both accuracy and F1 score. The Student (only) variant, trained directly on raw input images, does not perform well. This suggests that the explainer, with feature selection at pixel level, plays a critical role in guiding the student to better performance. MED-TEX outperforms MED-EX, which indicates that it is beneficial to leverage the information in the intermediate layers.

For model interpretation, Fig. 2 shows the IoU results in bar charts. The explainer of MED-TEX achieves significantly higher IoU than the others. Fig. 3 shows the visualization results of the topK=6 highlighted image regions of different methods. Hard attention, soft attention, Grad-CAM and L2X can only give patch-based region selection maps, while our MED-EX and MED-TEX produce pixel-level selection scores. Our method highlights lesion regions that match the ground-truth lesions well in Fig. 3. Moreover, MED-TEX clearly outperforms MED-EX because of the intermediate knowledge distillation losses, as shown in Fig. 4.

4. CONCLUSION

In this paper, we have introduced our novel framework MED-TEX, a joint knowledge distillation and model interpretation framework that learns a significantly smaller student (compared to the teacher) and an explainer by leveraging knowledge only from the pretrained teacher model.
In our experiments, we show that MED-TEX outperforms several widely used knowledge distillation and model interpretation techniques.

[1] "The future of digital health with federated learning."
[2] "Distilling the knowledge in a neural network."
[3] "Grad-CAM: Visual explanations from deep networks via gradient-based localization," in ICCV.
[4] "Learning to explain: An information-theoretic perspective on model interpretation."
[5] "Categorical reparameterization with gumbel-softmax."
[6] "Show, attend and tell: Neural image caption generation with visual attention."
[7] "Variational information distillation for knowledge transfer."
[8] "Segmenting neuronal structure in 3D optical microscope images via knowledge distillation with teacher-student network," in ISBI.

Additional Visual Results of MED-TEX (Lesion, Selected and K denote the expert segmentations, the feature selection scores and topK x1024, respectively).