key: cord-0609293-1ab7d7o8 authors: Chen, Xiaocong; Yao, Lina; Zhou, Tao; Dong, Jinming; Zhang, Yu title: Momentum Contrastive Learning for Few-Shot COVID-19 Diagnosis from Chest CT Images date: 2020-06-16 journal: nan DOI: nan sha: 289336027ed7593868e2e7107c953c06bf82343b doc_id: 609293 cord_uid: 1ab7d7o8
The current pandemic, caused by the outbreak of a novel coronavirus (COVID-19) in December 2019, has led to a global emergency that has significantly impacted economies, healthcare systems and personal wellbeing all around the world. Controlling the rapidly evolving disease requires highly sensitive and specific diagnostics. While real-time RT-PCR is the most commonly used test, it can take up to 8 hours and requires significant effort from healthcare professionals. As such, there is a critical need for a quick and automatic diagnostic system. Diagnosis from chest CT images is a promising direction. However, current studies are limited by the lack of sufficient training samples, as acquiring annotated CT images is time-consuming. To this end, we propose a new deep learning algorithm for the automated diagnosis of COVID-19, which only requires a few samples for training. Specifically, we use contrastive learning to train an encoder that captures expressive feature representations on large, publicly available lung datasets, and we adopt the prototypical network for classification. We validate the efficacy of the proposed model in comparison with other competing methods on two publicly available and annotated COVID-19 CT datasets. Our results demonstrate the superior performance of our model for the accurate diagnosis of COVID-19 based on chest CT images.
The latest coronavirus, COVID-19, was initially reported in Wuhan, China toward the end of 2019 and has since spread rapidly around the globe, leading to a worldwide crisis [1]. As an infectious lung disease, COVID-19 can lead to severe acute respiratory distress syndrome (ARDS) and is accompanied by a series of symptoms that include a dry cough, fever, tiredness and shortness of breath. As of May 24th 2020, more than 5.4 million individuals have been confirmed as having COVID-19, with a case fatality rate of roughly 6.3% worldwide, according to the World Health Organization. So far, no effective treatment for COVID-19 has been found. One of the major hurdles is the lack of efficient diagnosis methods. Therefore, an accurate and rapid diagnosis platform is urgently required to conduct screening and prevent further spread. Currently, most tests are based on real-time reverse transcriptase polymerase chain reaction (RT-PCR). However, each RT-PCR test can take several hours to produce results. With the current spread rate of COVID-19, this is not acceptable. Further, the limited number of test kits makes this situation even more serious (Shan et al., 2020; Narin et al., 2020).
Recent studies show that RT-PCR suffers from low sensitivity and low accuracy, so repeated tests are often required (Long et al., 2020; Ai et al., 2020). This implies that patients may not be confirmed in time, which increases the risk of further spread. To address these challenges, scientists around the world are trying to develop new diagnostic systems. Some studies (Li and Xia, 2020) have demonstrated that chest computed tomography (CT) imaging can help in diagnosing COVID-19 rapidly. Salehi et al. (Salehi et al., 2020) concluded that chest CT imaging is sensitive when diagnosing COVID-19 even when patients do not have clinical symptoms. Specifically, three typical radiographic features, namely consolidation, pleural effusion and ground-glass opacification, can be easily observed (Wang et al., 2020; Shi et al., 2020a) in the CT images of COVID-19 patients. With this in mind, several methods based on chest CT images have been developed for diagnosing COVID-19. For instance, Butt et al. (Butt et al., 2020) adopted a convolutional neural network (CNN) to classify patients' CT images. In addition, a few works use 3D CNNs to diagnose COVID-19 from chest CT scans (Gozes et al., 2020; Zheng et al., 2020). Mei et al. (Mei et al., 2020) adopted ResNet for the rapid diagnosis of COVID-19. Besides diagnosis, many works use segmentation techniques for detection (Chen et al., 2020a; Shi et al., 2020b; Chen et al., 2020d). All of these existing methods are trained on the limited available samples, which cover only a small number of patients, and they may not generalize to new patients. It is well known that the lack of labelled training data is a common problem, while deep learning based methods generally require a large volume of data to train accurate models.
Many research efforts have sought to alleviate this problem, for example through data augmentation or generative adversarial networks (GANs) (Zhao et al., 2019; Akkus et al., 2017; Oliveira et al., 2017; Pereira et al., 2016). However, these methods are highly sensitive to parameter selection. Hand-tuned data augmentation methods such as rotation may lead to over-fitting (Eaton-Rosen et al., 2018), and images generated by a GAN cannot faithfully simulate real patients, which may introduce unpredictable bias in the testing phase (Zhao et al., 2019). Recently, few-shot learning has attracted much attention in medical image analysis. In general, few-shot learning aims to leverage existing data to classify new tasks from similar domains. The basic workflow is to first pre-train an embedding network on a large dataset (e.g. ImageNet), then fine-tune the weights of this network, and finally apply it to an unseen small dataset. However, the resulting improvement in performance is limited. One reason is that ImageNet contains a broad range of image categories, so pre-training on it may bring in irrelevant information and fail to learn effective embeddings for lung-specific feature representation. In addition, pre-training on ImageNet has a high computational cost; for example, ImageNet-1B normally requires more than 50 GPU days. To address this challenge, we develop an end-to-end trainable deep few-shot learning framework that makes accurate predictions with minimal training chest CT images. Specifically, we first use the instance discrimination task to force the model to discriminate whether two images are the same instance or not. We generate different views of the same images to augment the original dataset. As the goal at this stage is to increase variance rather than to discriminate, we can avoid the disadvantages of data augmentation mentioned previously.
We then deploy a self-supervised strategy (Oord et al., 2018) powered by momentum contrastive training to further boost the performance. The key idea is to build a dynamic dictionary that performs key-query look-up, where the keys are sampled from the data and encoded by the encoder. However, the keys in the dictionary are noisy and inconsistent because the encoder changes with back-propagation (He et al., 2019). The momentum mechanism mitigates this effect by updating the key encoder and the query encoder at different rates. Finally, we use two public lung datasets to pre-train an embedding network and employ the prototypical network (Snell et al., 2017) for few-shot classification, which learns a metric space where classification is performed by measuring the distances to the prototype representation of each class. Extensive experiments on two new datasets demonstrate that our model provides a promising tool for quick COVID-19 diagnosis with very limited training data.
Due to the shortage of annotated COVID-19 CT images, standard classification methods may not work properly. We therefore formulate COVID-19 diagnosis as a few-shot classification problem. Few-shot learning is designed for classification tasks in which only a few samples are available for a new class. It can be defined as an M-way, C-shot episodic task (Vinyals et al., 2016), where M is the number of classes and C is the number of samples available for each class. The previously unseen training set can be represented as $D = \{(x_1, y_1), \ldots, (x_d, y_d)\}$, where d is the number of samples in this dataset. We randomly select the support set and the query set from D: i) the support set S can be partially or fully made up of the M classes but contains only C + 1 samples for each class; ii) we randomly select 1 sample from the C + 1 samples to form the test set (query set). Hence, COVID-19 diagnosis can be represented as a two-way, C-shot learning problem.
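The episode construction above, together with the prototype-based classification mentioned earlier, can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names (`sample_episode`, `prototype`, `classify`), the class labels and the toy 2-D vectors are hypothetical, and in a real pipeline the vectors would be produced by the pre-trained encoder.

```python
import math
import random

def sample_episode(dataset, classes, c_shot, rng):
    """Build one M-way, C-shot episode: C support samples and 1 query per class."""
    support, query = {}, {}
    for label in classes:
        pool = [x for x, y in dataset if y == label]
        picked = rng.sample(pool, c_shot + 1)                  # C + 1 samples per class
        query[label] = picked.pop(rng.randrange(c_shot + 1))   # hold one out as the query
        support[label] = picked                                # the remaining C samples
    return support, query

def prototype(vectors):
    """Class prototype: element-wise mean of the support embeddings."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def classify(embedding, prototypes):
    """Assign the label of the nearest prototype (Euclidean distance)."""
    return min(prototypes, key=lambda m: math.dist(embedding, prototypes[m]))

# Toy 2-D "embeddings" standing in for encoder outputs (hypothetical data).
rng = random.Random(0)
dataset = [([rng.gauss(0, 0.1), rng.gauss(0, 0.1)], "non-covid") for _ in range(10)]
dataset += [([rng.gauss(3, 0.1), rng.gauss(3, 0.1)], "covid") for _ in range(10)]

# Two-way, three-shot episode.
support, query = sample_episode(dataset, ["covid", "non-covid"], c_shot=3, rng=rng)
prototypes = {m: prototype(vs) for m, vs in support.items()}
predictions = {m: classify(q, prototypes) for m, q in query.items()}
```

Here `predictions` maps the true label of each held-out query to the predicted label, so the two toy clusters should map to themselves. The nearest-prototype rule is the hard-decision counterpart of a softmax over negative distances.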
In this section, we introduce our proposed self-supervised COVID-19 diagnosis method. The overall flowchart is illustrated in Fig. 1. We describe the three major components of our method: data augmentation, representation learning and few-shot classification.
Data augmentation has been widely used in unsupervised representation learning and supervised learning (Parkhi et al., 2012; Donahue and Simonyan, 2019; Donahue et al., 2014). A few existing approaches, for instance Hjelm et al. (Hjelm et al., 2018), define the contrastive classification task as changing an image's structure. In this study, we apply a stochastic data augmentation T, which randomly transforms a given example image x into two different views denoted as $x_i$, $x_j$. We consider the pair $x_i$, $x_j$ as positive. We use two simple augmentation strategies: 1) random cropping, followed by a resizing operation back to the original size with random flipping; 2) random cropping with color distortions, followed by a resizing operation. When a new image is fed into the model, one of the above strategies is randomly selected for augmentation. This process is repeated twice to generate two different views. Note that we implement color distortions using the torchvision package in PyTorch (Paszke et al., 2019).
Figure caption (beginning lost in extraction): [...] (2) is chosen. All the cropped sections are resized back to the original input image size. For the instance discrimination task, the goal, given B, D, F, is to determine whether or not A, C, E are the same instances.
Using contrastive learning to learn visual embeddings was first explored by Hadsell et al. (Hadsell et al., 2006). The task can be defined as follows: given an input x, learn a mapping G such that $s(G(x), G(x^{+})) \gg s(G(x), G(x^{-}))$ (1), where s(·, ·) is a function used to measure the similarity between two inputs and G is designed for dimension reduction and representation learning. Finally, $x^{+}$ and $x^{-}$ represent the positive and negative samples, meaning that $x^{+}$ is similar to x and $x^{-}$ is dissimilar.
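The two-view augmentation described above can be sketched as follows. This is a simplified stand-in rather than the authors' pipeline. It operates on a grey-scale image stored as a nested list, uses nearest-neighbour resizing, and replaces the torchvision color distortion with a single brightness gain. The function names, crop fraction and gain range are assumptions chosen for illustration.

```python
import random

def random_crop(img, rng, min_frac=0.5):
    """Cut a random rectangle covering at least min_frac of each side."""
    h, w = len(img), len(img[0])
    ch = rng.randint(max(1, int(h * min_frac)), h)
    cw = rng.randint(max(1, int(w * min_frac)), w)
    top, left = rng.randint(0, h - ch), rng.randint(0, w - cw)
    return [row[left:left + cw] for row in img[top:top + ch]]

def resize(img, h, w):
    """Nearest-neighbour resize back to h x w."""
    sh, sw = len(img), len(img[0])
    return [[img[r * sh // h][c * sw // w] for c in range(w)] for r in range(h)]

def crop_flip(img, rng):
    """Strategy 1: random crop, resize to the original size, random horizontal flip."""
    out = resize(random_crop(img, rng), len(img), len(img[0]))
    return [row[::-1] for row in out] if rng.random() < 0.5 else out

def crop_distort(img, rng):
    """Strategy 2: random crop with an intensity distortion, then resize."""
    gain = rng.uniform(0.8, 1.2)  # brightness jitter only (assumed form of distortion)
    crop = [[min(255, int(p * gain)) for p in row] for row in random_crop(img, rng)]
    return resize(crop, len(img), len(img[0]))

def two_views(img, rng):
    """Pick a strategy at random, twice, to obtain a positive pair (x_i, x_j)."""
    strategies = (crop_flip, crop_distort)
    return rng.choice(strategies)(img, rng), rng.choice(strategies)(img, rng)

rng = random.Random(0)
image = [[(r * 8 + c) % 256 for c in range(8)] for r in range(8)]  # toy 8x8 image
x_i, x_j = two_views(image, rng)
```

Both views keep the input resolution, so they can be batched together and fed to the same encoder.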
It is worth mentioning that contrastive learning is a type of unsupervised learning. A simple framework for contrastive learning was proposed by Chen et al. (Chen et al., 2020b). The framework learns representations by maximizing the agreement between differently augmented views $x_i$, $x_j$ of the same data example x via a contrastive loss in the latent space. We adopt this framework in our model. Specifically, our representation learning stage consists of three modules: the encoder, the projection head and the contrastive loss function.
Encoder. The neural network based encoder f(·) extracts representations from the augmented images. The framework places no constraints on the choice of network architecture. In this study, we adopt ResNet (He et al., 2016) to obtain the representation $h_i = f(x_i)$, where $h_i \in \mathbb{R}^d$ is the output of the average pooling layer.
Projection Head. The projection head g(·) is a function that maps the resulting representation into the space where the contrastive loss is applied. The most common projection head is the multilayer perceptron (MLP) with one hidden layer (Chen et al., 2020b,c). In this case, we can express $z_i$ (as well as $z_j$) as: $z_i = g(h_i) = W_2\,\sigma(W_1 h_i)$ (2), where $W_1$, $W_2$ are the weights of the hidden layer and output layer respectively. Here σ(·) is the non-linear ReLU activation function, which can be defined as: $\sigma(x) = \max(0, x)$ (3). We will examine the effectiveness of this projection head in Section 4.
Contrastive Loss Function. The contrastive loss function is defined for the contrastive pre-text task. We only consider the instance discrimination task (Wu et al., 2018) in this study. Given a set $\{x_k\}$ including a positive pair $x_i$, $x_j$, the contrastive task aims to identify $x_j$ in the set $\{x_k\}_{k \neq i}$ for a given $x_i$. We define the contrastive task on pairs of augmented images from a randomly selected minibatch with N samples. The augmentation process results in 2N data points.
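As a small illustration of the projection head described above, the sketch below applies a bias-free MLP with one hidden layer and a ReLU activation to a toy vector. The dimensions and random weights are placeholders chosen for the example (the paper reports a 2048-D hidden layer), and the helper names are hypothetical.

```python
import random

def relu(v):
    """sigma(x) = max(0, x), applied element-wise."""
    return [max(0.0, x) for x in v]

def matvec(W, v):
    """Matrix-vector product for a row-major list-of-lists matrix."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def projection_head(h, W1, W2):
    """z = W2 * relu(W1 * h): an MLP with one hidden layer and no biases."""
    return matvec(W2, relu(matvec(W1, h)))

rng = random.Random(0)
d, hidden, out = 6, 4, 3  # toy sizes, not the sizes used in the paper
W1 = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(hidden)]
W2 = [[rng.gauss(0, 1) for _ in range(hidden)] for _ in range(out)]
h_i = [rng.gauss(0, 1) for _ in range(d)]  # stand-in for an encoder output
z_i = projection_head(h_i, W1, W2)
```

The contrastive loss is then computed on `z_i` rather than on `h_i`.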
To create the contrastive task, we need enough negative samples to construct the loss function. Similar to Chen et al. (Doersch and Zisserman, 2017), we treat the other 2N − 2 examples as negative samples. The similarity function s(·, ·) is defined as the cosine similarity: $s(u, v) = \frac{u^{\top} v}{\|u\|\,\|v\|}$ (4), where v, u are two vectors. Based on this, we define the loss function for a pair of positive samples (i, j) as: $\ell_{i,j} = -\log \frac{\exp(s(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(s(z_i, z_k)/\tau)}$ (5). Here $\mathbb{1}_{[k \neq i]} \in \{0, 1\}$ is the indicator, which takes the value 1 when k ≠ i and 0 otherwise, and τ is the temperature parameter. This loss is known as the normalized temperature-scaled cross entropy loss (Sohn, 2016; Bachman et al., 2019; Oord et al., 2018). However, Eq.(5) only considers the positive samples and ignores negative samples. This may lead to potential bias. To avoid this, we introduce the momentum mechanism into our model.
Contrastive learning can also be expressed as training an encoder to conduct a dictionary lookup task. Consider an encoded query q and encoded samples $\{k_0, k_1, k_2, \ldots\}$ that serve as the keys of a dictionary. Based on the above definition, the goal of contrastive learning is to build a discrete dictionary for high-dimensional continuous inputs. The core of MoCo is maintaining the dictionary as a queue. The benefit of this is that the encoder can reuse the encoded keys from previous mini-batches. In addition, the dictionary can be much larger than the mini-batch and its size is easy to adjust. As the number of samples that can be included in the dictionary is fixed, once the dictionary is full, it progressively removes the oldest records. In this way, the consistency of the dictionary can be maintained, as the oldest samples are often out-of-date and inconsistent with the new entries. Another approach, called Memory Bank (Wu et al., 2018), tries to store the historical records of the encoded samples. This approach maintains a bank of all the representations of the dataset. The dictionary then randomly samples from the memory bank directly for each mini-batch without back-propagation.
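The normalized temperature-scaled cross entropy loss described above can be sketched in plain Python. The snippet is a hedged illustration: the function names and the four toy embeddings are hypothetical, and the default temperature is simply the value the paper reports for pre-training (τ = 0.07).

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

def nt_xent(z, i, j, tau=0.07):
    """Loss for the positive pair (i, j) among the 2N embeddings in z."""
    # Every embedding except z[i] itself appears in the denominator.
    logits = [cosine(z[i], z[k]) / tau for k in range(len(z)) if k != i]
    positive = cosine(z[i], z[j]) / tau
    top = max(logits)  # subtract the maximum for numerical stability
    log_denominator = top + math.log(sum(math.exp(l - top) for l in logits))
    return log_denominator - positive

# Toy batch with N = 2 images, hence 2N = 4 views; (0, 1) and (2, 3) are positive pairs.
z = [[1.0, 0.1], [0.9, 0.2], [-0.2, 1.0], [-0.1, 0.9]]
loss_aligned = nt_xent(z, 0, 1)   # view 1 really is the partner of view 0
loss_mismatch = nt_xent(z, 0, 2)  # pretend view 2 were the partner instead
```

The loss is small when the positive pair is the most similar pair in the batch and grows when a dissimilar view is treated as the positive.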
However, this method leads to inconsistency when sampling. To overcome this, back-propagation should be conducted to keep the sampling step up-to-date. A simple solution is to copy the key encoder $f_k$ from the query encoder $f_q$ without the gradient. However, the encoder changes constantly, which can lead to noisy key representations and poor results. Momentum contrast was proposed to relieve this problem by using a different rule to update the parameters of $f_k$: $\theta_k \leftarrow m\,\theta_k + (1 - m)\,\theta_q$ (6), where $\theta_k$ denotes the parameters of $f_k$, $\theta_q$ denotes the parameters of $f_q$ and m ∈ [0, 1) is the momentum coefficient. We use back-propagation to update the parameters $\theta_q$ and use Eq.(6) to update $\theta_k$. Benefiting from the momentum coefficient, the update of $\theta_k$ is smoother than that of $\theta_q$. Because of the different update strategies, the query and key are eventually encoded by different encoders. Based on the above discussion, we use the dictionary as a queue so that the encoder can reuse previously encoded samples. The loss function for the pre-trained model can be written as: $\mathcal{L}(q, q_k) = -\log \frac{\exp(s(q, k^{+})/\tau)}{\sum_{k \in q_k} \exp(s(q, k)/\tau)}$ (7). Different from Eq.(5), here we need to consider the queue and the negative cases, so we slightly modify the loss function by introducing the positive examples $k^{+}$ and negative examples $k^{-}$, where $q_k = k^{+} \cup k^{-}$. In the instance discrimination pre-text task, a positive pair is formed when a query q and a key k are augmented from the same sample; otherwise, a negative pair is created. Once the pre-training step is finished, we extract the pre-trained encoder f(·) and integrate it into our classification module.
Another step in our workflow is classification. In this stage, meta-learning is applied to fine-tune the pre-trained encoder to fit the class changes required by few-shot learning. We then use Prototypical Networks (Snell et al., 2017) to conduct the few-shot classification.
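The momentum update and the queue-based dictionary used during pre-training can be sketched as follows. This is a toy illustration under stated assumptions: encoder parameters are flat lists of floats, keys are 2-D vectors, the queue holds three keys, and all function names are hypothetical. The momentum coefficient 0.999 and temperature 0.07 are the values reported in the paper.

```python
import math
from collections import deque

def momentum_update(theta_k, theta_q, m=0.999):
    """theta_k <- m * theta_k + (1 - m) * theta_q, element-wise."""
    return [m * k + (1.0 - m) * q for k, q in zip(theta_k, theta_q)]

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

def queue_loss(q, k_pos, queue, tau=0.07):
    """Contrast the positive key against the negative keys held in the queue."""
    logits = [cosine(q, k_pos) / tau] + [cosine(q, k) / tau for k in queue]
    top = max(logits)  # subtract the maximum for numerical stability
    return top + math.log(sum(math.exp(l - top) for l in logits)) - logits[0]

# Fixed-size FIFO dictionary: appending to a full deque drops the oldest key.
queue = deque(maxlen=3)
for key in ([0.0, 1.0], [-1.0, 0.0], [0.0, -1.0], [-1.0, 1.0]):
    queue.append(key)

theta_q = [1.0, 2.0]  # query-encoder parameters (toy values)
theta_k = [0.0, 0.0]  # key-encoder parameters (toy values)
theta_k = momentum_update(theta_k, theta_q)  # key encoder moves only slightly

loss = queue_loss(q=[1.0, 0.0], k_pos=[0.9, 0.1], queue=queue)
```

With m close to 1 the key encoder changes slowly, so keys enqueued a few mini-batches ago remain roughly comparable with newly encoded ones.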
The prototypical network learns an embedding in which each class is represented by a mean vector c in the latent space. The goal of the pre-trained encoder is to place similar images close together and dissimilar images far apart in the latent space. The prototypical network has a similar aim, so we use it to fine-tune our pre-trained encoder. For class m, the centroid embedding features can be written as: $c_m = \frac{1}{|S_m|} \sum_{(x_i, y_i) \in S_m} \psi(x_i)$ (8), where $S_m$ denotes the support samples of class m and ψ(·) is the embedding function from the prototypical network. As the prototypical network is a metric based learning method, we use the Euclidean distance to produce the distribution over all classes for a query q: $p(y = m \mid q) = \frac{\exp(-d(\psi(q), c_m))}{\sum_{m'} \exp(-d(\psi(q), c_{m'}))}$ (9), where d(·, ·) is the Euclidean distance. Eq.(9) is a softmax over the distances between the query's embedding and the features of each class. The loss function for this stage can be defined as: $\mathcal{L} = -\log p(y = m \mid q)$ (10), where m is the true class of q. Algorithm 1 shows the whole pre-training workflow of our model.
Algorithm 1: Training algorithm for the pre-training
input: batch size N, τ, $f_k$, $f_q$, g, T, $q_k$
[...]
Select two data augmentation functions from T: t, t′;
[...]
Calculate the similarity using Eq.(4);
end
Update $f_k$ to minimize Eq.(7);
We evaluate our proposed model using two publicly available annotated COVID-19 CT image datasets: (1) COVID-19 CT provided by Zhao et al. and (2) [...] focus on lung diseases. We use those two datasets without labels to pre-train the encoder network. When dividing the support and query sets for classification, we divide the dataset at the patient level instead of the CT level to avoid any possible over-fitting. We report the basic statistics for the COVID-19 CT dataset and MegSeg in Table 1. We combine the two datasets for testing. Note that all CT images were resized to 512 × 512 using OpenCV.
Table 1 (only one row survives extraction): Non-COVID-19: 397, 0.
For pre-training, we use the SGD optimizer with a weight decay of 0.0001 and momentum of 0.9. The momentum update coefficient is 0.999. The mini-batch size is set to 256 across eight GPUs. The initial learning rate is 0.03.
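For concreteness, one optimizer step with the settings above (learning rate 0.03, momentum 0.9, weight decay 0.0001) can be sketched as below. The update rule follows the common convention of adding the weight-decay term to the gradient and keeping a momentum buffer, as PyTorch's SGD optimizer does. That the authors used exactly this form is an assumption, and the parameter and gradient values are toy numbers.

```python
def sgd_step(params, grads, velocity, lr=0.03, momentum=0.9, weight_decay=1e-4):
    """One SGD update with weight decay folded into the gradient and a momentum buffer."""
    new_params, new_velocity = [], []
    for p, g, v in zip(params, grads, velocity):
        g = g + weight_decay * p   # L2 penalty contribution
        v = momentum * v + g       # momentum buffer
        new_params.append(p - lr * v)
        new_velocity.append(v)
    return new_params, new_velocity

params, velocity = [1.0, -2.0], [0.0, 0.0]
grads = [0.5, 0.5]  # toy gradients
params, velocity = sgd_step(params, grads, velocity)
```

This optimizer momentum is distinct from the momentum update coefficient of 0.999, which governs how the key encoder tracks the query encoder.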
The number of epochs is 200, and the learning rate is multiplied by 0.1 after 120 and 160 epochs, as described in (Wu et al., 2018). The encoder is ResNet-50. The two-layer MLP projection head has a 2048-D hidden layer with a ReLU activation function. The weights are initialized using He initialization (He et al., 2015), and the temperature parameter τ is set to 0.07. For the classification stage, we follow the default settings of the prototypical network. The experiments were conducted on eight GPUs: six NVIDIA TITAN X Pascal GPUs and two NVIDIA TITAN RTX GPUs.
We evaluate our approach using four metrics, the first being accuracy [...]. The results are given in Table 2. We find that the two-way, one-shot method achieves results very similar to those of the ResNet-50 trained on the COVID-19 dataset. In addition, a visualization of the learned features is provided in Fig. 3.
As mentioned above, our method is based on few-shot learning, so we are interested in how the number of shots affects the model's performance. We therefore conduct a few experiments to explore the relationship between performance and the number of shots. We use ResNet-50 as the baseline method. The results are given in Table 3.
In this section, we conduct ablation studies to demonstrate the effectiveness of each component. The default setting is two-way, one-shot, and ResNet-50 uses the same setting as in the previous section. In summary, we want to answer the following research questions: 1) What if pre-training is conducted on ImageNet? 2) How do data augmentation and the projection head affect the performance? 3) How important is the fine-tuning stage? First, we are interested in the selection of the pre-training dataset. We thus repeat the experiments using ImageNet for pre-training with the default setting. The results can be found in Table 4. As expected, we find that the performance is worse when the model is pre-trained on ImageNet.
As discussed before, an extra step may be required to transfer what is learned from common objects to lung CT images. Next, we examine the roles that data augmentation and the projection head play in our model. We conducted experiments on our model without the projection head and without augmentation. The results can be found in Table 5. As we can see, data augmentation has a significant effect on the results. Finally, the results for the fine-tuning procedure can be found in Table 6.
CT imaging is attracting more and more attention as a screening tool for diagnosing COVID-19. It provides a visualization that allows the community to monitor a patient's progression and can help to evaluate the severity of COVID-19 (Shan et al., 2020). However, the lack of annotated CT scans is the biggest challenge. In this study, we proposed a new deep learning based method that can be used for the automatic screening of COVID-19 with limited samples. The experiments show that this method achieves better performance than ResNet-50 when the number of available samples is larger than three. ResNet is a well-known and widely used supervised learning model in medical imaging. It is notable that a self-supervised method, which belongs to the unsupervised learning field, achieves better results than ResNet.
Correlation of chest CT and RT-PCR testing in coronavirus disease 2019 (COVID-19) in China: a report of 1014 cases
Deep learning for brain MRI segmentation: state of the art and future directions
Learning representations by maximizing mutual information across views
Chest CT findings in coronavirus disease-19 (COVID-19): relationship to duration of infection
Deep learning system to screen coronavirus disease 2019 pneumonia
Deep learning-based model for detecting 2019 novel coronavirus pneumonia on high-resolution computed tomography: a prospective study
A simple framework for contrastive learning of visual representations
Improved baselines with momentum contrastive learning
Residual attention U-Net for automated multi-class segmentation of COVID-19 chest CT images
Multi-task self-supervised visual learning
DeCAF: A deep convolutional activation feature for generic visual recognition
Large scale adversarial representation learning
Improving data augmentation for medical image segmentation
Rapid AI development cycle for the coronavirus (COVID-19) pandemic: Initial results for automated detection & patient monitoring using deep learning CT image analysis
Dimensionality reduction by learning an invariant mapping
Momentum contrast for unsupervised visual representation learning
Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification
Deep residual learning for image recognition
Sample-efficient deep learning for COVID-19 diagnosis based on CT scans
Data-efficient image recognition with contrastive predictive coding
Learning deep representations by mutual information estimation and maximization
Clinical features of patients infected with 2019 novel coronavirus in Wuhan, China
Densely connected convolutional networks
Artificial intelligence distinguishes COVID-19 from community acquired pneumonia on chest CT
Coronavirus disease 2019 (COVID-19): Role of chest CT in diagnosis and management
Diagnosis of the coronavirus disease (COVID-19): rRT-PCR or CT?
Artificial intelligence-enabled rapid diagnosis of patients with COVID-19
Automatic detection of coronavirus disease (COVID-19) using X-ray images and deep convolutional neural networks
Augmenting data when training a CNN for retinal vessel segmentation: How to warp?
Representation learning with contrastive predictive coding
Cats and dogs
PyTorch: An imperative style, high-performance deep learning library
Brain tumor segmentation using convolutional neural networks in MRI images
Coronavirus disease 2019 (COVID-19): a systematic review of imaging findings in 919 patients
Grad-CAM: Visual explanations from deep networks via gradient-based localization
Lung infection quantification of COVID-19 in CT images with deep learning
Review of artificial intelligence techniques in imaging data acquisition, segmentation and diagnosis for COVID-19
Large-scale screening of COVID-19 from community acquired pneumonia using infection size-aware classification
Prototypical networks for few-shot learning
Improved deep metric learning with multi-class N-pair loss objective
Matching networks for one shot learning
A review of the 2019 novel coronavirus (COVID-19) based on current evidence
Unsupervised feature learning via non-parametric instance discrimination
Deep lesion graphs in the wild: relationship learning and organization of significant radiology image findings in a diverse large-scale lesion database
Data augmentation using learned transformations for one-shot medical image segmentation
Deep learning-based detection for COVID-19 from chest CT using weak label