key: cord-0751728-724br56j
authors: Chen, Xiaocong; Yao, Lina; Zhou, Tao; Dong, Jinming; Zhang, Yu
title: Momentum contrastive learning for few-shot COVID-19 diagnosis from chest CT images
date: 2021-01-16
journal: Pattern Recognit
DOI: 10.1016/j.patcog.2021.107826
sha: fcf5356bb3bab66418e25da5372fb866b25e3d67
doc_id: 751728
cord_uid: 724br56j
The current pandemic, caused by the outbreak of a novel coronavirus (COVID-19) in December 2019, has led to a global emergency that has significantly impacted economies, healthcare systems and personal wellbeing all around the world. Controlling the rapidly evolving disease requires highly sensitive and specific diagnostics. While RT-PCR is the most commonly used, it can take up to eight hours, and requires significant effort from healthcare professionals. As such, there is a critical need for a quick and automatic diagnostic system. Diagnosis from chest CT images is a promising direction. However, current studies are limited by the lack of sufficient training samples, as acquiring annotated CT images is time-consuming. To this end, we propose a new deep learning algorithm for the automated diagnosis of COVID-19, which only requires a few samples for training. Specifically, we use contrastive learning to train an encoder which can capture expressive feature representations on large and publicly available lung datasets and adopt the prototypical network for classification. We validate the efficacy of the proposed model in comparison with other competing methods on two publicly available and annotated COVID-19 CT datasets. Our results demonstrate the superior performance of our model for the accurate diagnosis of COVID-19 based on chest CT images.
The latest coronavirus, COVID-19, was initially reported in Wuhan, China toward the end of 2019 and has since spread rapidly around the globe, leading to a worldwide crisis.
As an infectious lung disease, COVID-19 can lead to severe acute respiratory distress syndrome (ARDS) and is accompanied by a series of symptoms that include a dry cough, fever, tiredness and shortness of breath. As of October 18th 2020, more than 39 million individuals around the world have been confirmed as having COVID-19, with a roughly 6.3% case fatality rate, according to the World Health Organization. So far, no effective treatment for COVID-19 has been found. One of the major hurdles is the lack of efficient diagnostic methods. Therefore, an accurate and rapid diagnosis platform is urgently required to conduct COVID-19 screening and prevent its further spread. Currently, most tests are based on real-time reverse transcriptase polymerase chain reaction (RT-PCR). However, each RT-PCR test can take several hours to produce results, which is not acceptable given the current spread rate of COVID-19. Further, the limited number of test kits exacerbates the situation [1] [2] [3]. Recent studies also show that RT-PCR suffers from low sensitivity and accuracy, often requiring repeated tests [4, 5]. This prevents patients from being confirmed in a timely manner, increasing the potential risk of spreading.
In order to address these challenges, scientists around the world are trying to develop new diagnostic systems. Some studies [6, 7] have demonstrated that chest computed tomography (CT) imaging can help in diagnosing COVID-19 rapidly. Salehi et al. [8] concluded that chest CT imaging is sensitive for diagnosing COVID-19 even when patients do not have clinical symptoms. Specifically, three typical radiographic features, including consolidation, pleural effusion and ground-glass opacification, can be easily observed from the CT images of COVID-19 patients [9, 10]. With this in mind, several methods based on chest CT images have been developed for diagnosing COVID-19. For instance, some studies used a 3D CNN to diagnose COVID-19 from chest CT scans [11].
Mei et al. [12] adopted ResNet to rapidly identify COVID-19. Besides diagnosis, several works also used segmentation techniques for detection [13, 14]. However, all existing methods are trained using the limited samples available from a small number of patients and may not generalize well to new patients. It is well known that a lack of labelled training data is a common challenge, since deep learning methods generally require a large volume of data for accurate training. Significant research efforts have been dedicated to alleviating this problem through, for example, data augmentation or generative adversarial networks (GANs) [15] [16] [17] [18]. However, these methods are highly sensitive to parameter selection. Hand-tuned data augmentation methods like rotation may lead to overfitting [19], while images generated by a GAN cannot faithfully simulate real patient data, which may introduce unpredictable bias in the testing phase [15].
Fig. 1. The overall architecture of our approach. Top: The pre-training stage, which includes data augmentation and representation learning. The pretext task is an instance discrimination task. Bottom: Few-shot classification with a 2-way, 1-shot example. For classification, the support images and query image are encoded by the pre-trained encoding network. Query sample embeddings are compared with the centroid of training sample embeddings and used to fine-tune the pre-trained encoder.
Recently, few-shot learning has attracted significant attention in medical image analysis. In general, few-shot learning aims to leverage existing data to classify new tasks from similar domains. The basic workflow for few-shot learning is to pre-train an embedding network on a large dataset (e.g. ImageNet), then fine-tune the weights of this network, and finally apply it to a small unseen dataset [20, 21]. However, the performance is only marginally improved this way.
One reason is that ImageNet contains a broad range of categories, so pre-training on this dataset often introduces irrelevant information that does not help in learning effective embeddings for lung-specific feature representation. In addition, pre-training on ImageNet incurs a high computational cost; for example, ImageNet-1B typically required over 50 GPU days.
To address this challenge, we develop an end-to-end trainable deep few-shot learning framework that can provide accurate predictions with minimal training on chest CT images. Specifically, we first use the instance discrimination task to force the model to discriminate whether or not two images are the same instance. Different views of the same images are then generated to augment the original dataset. As our goal at this stage is to increase variance rather than to discriminate, we are able to effectively avoid the disadvantages of data augmentation mentioned previously. We then deploy a self-supervised strategy [22] powered with momentum contrastive training to further boost the performance. The key idea is to build a dynamic dictionary to perform (key, query) look-up, where the keys are sampled from data and encoded by the encoder. However, the keys in the dictionary are noisy and inconsistent due to back-propagation [23]. We apply the momentum mechanism to mitigate this effect by updating the key and query encoders at different rates. Finally, we utilize two public lung datasets to pre-train an embedding network and employ the prototypical network [24] to conduct the few-shot classification, which learns a metric space for classification by measuring the distances to the derived prototypical representation of each class. Extensive experiments on two new datasets demonstrate that our model provides a promising tool for quick COVID-19 diagnosis with very limited available training data.
Due to the shortage of annotated COVID-19 CT images, normal classification methods may not work properly.
As such, we formulate COVID-19 diagnosis as a few-shot classification problem. Few-shot learning is designed for cases in which only a few samples of a new class are available for classifier training. It can be defined as an M-way, C-shot episodic task [25], where M represents the number of classes and C represents the number of samples available for each class. The training set, which has never been seen before, can be denoted as $D = \{(x_1, y_1), \dots, (x_d, y_d)\}$, where $d$ is the number of samples in the dataset. We randomly select the support set and query set from $D$: (i) The support set $S$ can be partially or fully made up of M classes but contains only C + 1 samples. (ii) We randomly select one sample from the C + 1 samples to form the test set (query set). Hence, COVID-19 diagnosis can be represented as a two-way, C-shot learning problem.
In this section, we will introduce our proposed self-supervised COVID-19 diagnosis method. The overall flowchart is illustrated in Fig. 1. We will describe the three major components of our method, which include data augmentation, representation learning and the few-shot classification.
Fig. 2. Three possibilities for random cropping. Dashed boxes are augmented views. Crops A, C, E will have random color distortions applied, while B, D, F will not change if method (2) is chosen. All the cropped sections will be resized back to the original input image size. For the instance discrimination task, the goal, given B, D, F, is to determine whether or not A, C, E are in the same instance.
Data augmentation has been widely used in unsupervised representation learning and supervised learning [26, 27]. A few existing approaches define the contrastive classification task as changing the structure of images. For instance, Hjelm et al. [28] and Bachman et al. [29] used a global-to-local view for contrastive learning, as shown in the first example in Fig. 2. Meanwhile, Oord et al. [22] and Henaff et al. [30] achieved neighbor prediction using the adjacency view (middle example, Fig. 2).
An overlapping view of the two approaches can be seen in the third image of Fig. 2. In this study, we apply a stochastic data augmentation $T$ which randomly transforms a given example image $x$ into two different views, denoted as $x_i$, $x_j$. We consider the pair $x_i$, $x_j$ as positive. Further, we use two simple augmentation strategies: (1) random cropping, followed by a resizing operation back to the original size with random flipping; or (2) random cropping with color distortions followed by a resizing operation. When a new image is fed into the model, one of the above methods is randomly selected for augmentation. This process is repeated twice to generate two different views. Note that we implement color distortions using the torchvision package in PyTorch [31].
Using contrastive learning to learn visual embeddings was first explored by Hadsell et al. [32]. Given an image set $I = \{i_1, \dots, i_p\}$, $x_i \in \mathbb{R}^d$, the goal of the task is to find a mapping function $G$ such that
$s(G(x), G(x^+)) \gg s(G(x), G(x^-))$,
where $s(\cdot, \cdot)$ is a function used to measure the similarity between two inputs. $G$ is designed for dimension reduction and representation learning. Finally, $x^+$, $x^-$ represent the positive and negative samples, where $x^+$ is similar to $x$ and $x^-$ is dissimilar. It is worth mentioning that contrastive learning is a type of unsupervised learning. A simple framework for contrastive learning was proposed by Chen et al. [33]. Specifically, the representations are learned by maximizing the agreement between differently augmented views $x_i$, $x_j$ of the same data example $x$ via a contrastive loss in the latent space. We adopt this framework in our model. Our representation learning stage consists of three modules: the encoder, the projection head and the contrastive loss function.
Encoder The neural network based encoder $f(\cdot)$ extracts representations from the augmented images. This framework is flexible and can adopt any type of network architecture without constraints.
In this study, we adopt ResNet [34] to obtain the representation $h_i = f(x_i)$, taken as the output of the average pooling layer.
Projection head The projection head $g(\cdot)$ is a function that maps the resulting representation into the space where the contrastive loss is applied. The most common projection head is the multilayer perceptron (MLP) with one hidden layer [33]. In this case, we can express $z_i$ (as well as $z_j$) as:
$z_i = g(h_i) = W_2 \, \sigma(W_1 h_i)$,
where $W_1$, $W_2$ are the weights of the hidden layer and output layer, respectively. Here $\sigma(\cdot)$ is the non-linear ReLU activation function, which can be defined as:
$\sigma(x) = \max(0, x)$.
We will examine the effectiveness of this projection head in Section 4.
The contrastive loss function is defined for the contrastive pre-text task. It was first proposed by Hadsell et al. [32], and its value is low when the query is similar to its positive key and dissimilar to all other keys. In this manuscript, we only consider the instance discrimination task [35]. Given a set $\{x_k\}$ including a positive pair $x_i$, $x_j$, the contrastive task aims to identify $x_j$ in the set $\{x_k\}_{k \neq i}$ for a given $x_i$. We define the contrastive task on pairs of augmented images from a randomly selected minibatch with $N$ samples. The augmentation process results in $2N$ data points. To create the contrastive task, we need enough negative samples to construct the loss function. Similar to Doersch et al. [36], we treat the other $2N - 2$ examples as the negative samples. The similarity function $s(\cdot, \cdot)$ can be defined as the cosine similarity:
$s(u, v) = \frac{u^\top v}{\|u\| \, \|v\|}$,
where $v$, $u$ are two vectors. Based on this, we can define the loss function for a pair of positive samples $(i, j)$ as:
$\ell(i, j) = -\log \frac{\exp(s(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp(s(z_i, z_k) / \tau)}$. (5)
Here $\mathbb{1}_{k \neq i} \in \{0, 1\}$ is the indicator, which has a value of 1 when $k \neq i$ and 0 otherwise, and $\tau$ is the temperature parameter. This loss is known as the normalized temperature-scaled cross entropy loss [22, 29]. However, Eq. (5) only considers the positive samples and ignores negative samples.
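The cosine similarity and the pairwise loss of Eq. (5) described above can be illustrated with a small sketch. This is a minimal illustration in plain Python rather than the authors' implementation: the function names `cosine_similarity` and `nt_xent_pair_loss` and the toy 2-D projections are hypothetical, and the temperature of 0.07 is taken from the experimental settings reported later in the paper.

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity s(u, v) = u.v / (||u|| ||v||)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def nt_xent_pair_loss(z, i, j, tau=0.07):
    """Loss for the positive pair (i, j) among the 2N projections in z."""
    numerator = math.exp(cosine_similarity(z[i], z[j]) / tau)
    # Sum over every k != i: the positive j is included, the anchor i is not.
    denominator = sum(
        math.exp(cosine_similarity(z[i], z[k]) / tau)
        for k in range(len(z))
        if k != i
    )
    return -math.log(numerator / denominator)

# Toy minibatch (hypothetical values) with N = 2 images, hence 2N = 4 views.
# Views 0 and 1 come from one image, views 2 and 3 from another.
z = [[1.0, 0.1], [0.9, 0.2], [-0.2, 1.0], [-0.1, 0.8]]
loss_matched = nt_xent_pair_loss(z, 0, 1)     # views of the same image
loss_mismatched = nt_xent_pair_loss(z, 0, 2)  # views of different images
```

In this toy setting the loss for the two views of the same image is much smaller than the loss obtained when a view of a different image is treated as the positive, which is the behaviour the pre-text task relies on.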
Note that the margin-based contrastive loss function [32] has the same problem of considering only the positive keys. This may lead to potential bias. To avoid this, we introduce the momentum mechanism into our model.
Contrastive learning can also be expressed as training an encoder to conduct a dictionary look-up task. Consider an encoded query $q$ and encoded samples $x_i, \dots, x_k$, which are the keys of the dictionary. If the query $q$ is similar to the sample $x^+$, there is a match. For the negative samples $x^-$, there is no match in the dictionary. Based on this definition, He et al. proposed MoCo [23], an unsupervised learning framework built on contrastive learning. Under this view, the goal of contrastive learning is to build a discrete dictionary for high-dimensional continuous inputs. The core of MoCo is to maintain the dictionary as a queue. The benefit of this is that the encoder can reuse the encoded keys from previous mini-batches. In addition, the dictionary can be much larger than the mini-batch and is easy to adjust. As the number of samples that can be included in the dictionary is fixed, once the dictionary is full, it progressively removes the oldest records. In this way, the consistency of the dictionary can be maintained, as the oldest samples are often out-of-date and inconsistent with the new entries.
Another approach, called Memory Bank [35], stores the historical records of the encoded samples. This approach maintains a bank of all the representations of the dataset. The dictionary then randomly samples from the memory bank directly for each mini-batch without back-propagation. However, this method leads to inconsistency when sampling. To overcome this, back-propagation should be conducted to keep the sampling step up-to-date. A simple solution is to copy the key encoder $f_k$ from the query encoder $f_q$ without the gradient.
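The queue behaviour described above (fixed capacity, newest keys in, oldest keys out) can be sketched as follows. This is only an illustration under simplifying assumptions: the class name `KeyQueue` is hypothetical, and string labels stand in for the encoded key vectors.

```python
from collections import deque

class KeyQueue:
    """Fixed-size FIFO dictionary of encoded keys (hypothetical helper)."""

    def __init__(self, capacity):
        # A deque with maxlen drops its oldest item when a new one is
        # appended to a full queue.
        self.keys = deque(maxlen=capacity)

    def enqueue_batch(self, batch_keys):
        """Add the keys of the current mini-batch, evicting the oldest."""
        for key in batch_keys:
            self.keys.append(key)

    def negatives(self):
        """Return the stored keys, oldest first."""
        return list(self.keys)

queue = KeyQueue(capacity=4)
queue.enqueue_batch(["k0", "k1", "k2"])  # first mini-batch
queue.enqueue_batch(["k3", "k4", "k5"])  # second mini-batch evicts k0 and k1
```

Because the capacity is independent of the mini-batch size, the number of stored keys can be set much larger than a single batch, as the text notes.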
However, the encoder changes constantly, which can lead to a noisy key representation and poor results. The momentum contrast was introduced to address this problem, using a different rule to update the parameters of $f_k$:
$\theta_k \leftarrow m \, \theta_k + (1 - m) \, \theta_q$, (6)
where $\theta_k$ is the parameter for $f_k$, $\theta_q$ is the parameter for $f_q$ and $m \in [0, 1)$ is the momentum coefficient. We use back-propagation to update the parameter $\theta_q$ and use Eq. (6) to update $\theta_k$. Benefiting from the momentum coefficient, the update of $\theta_k$ is smoother than that of $\theta_q$. Because of the different update strategies, the query and key will eventually be encoded by different encoders. Based on the above discussion, we use the dictionary as a queue to allow the encoder to reuse the previously encoded samples. The loss function for the pre-trained model can be written as:
$\mathcal{L}_q = -\log \frac{\exp(q \cdot k^+ / \tau)}{\sum_{k \in q_k} \exp(q \cdot k / \tau)}$. (7)
Different from Eq. (5), here we need to consider the queue and the negative cases, so we slightly modify the loss function to fulfil this requirement by introducing the positive examples $k^+$ and negative examples $k^-$, where $q_k = k^+ \cup k^-$. In the instance discrimination pre-text task, a positive pair is formed when a query $q$ and a key $k$ are augmented from the same sample; otherwise, a negative pair is created. Once the pre-training step is finished, we extract the pre-trained encoder $f(\cdot)$ and integrate it into our classification module.
It is worth mentioning that the triplet loss [37] is another popular loss function that considers both positive and negative examples. However, the triplet loss does not converge easily and is time-consuming. Hence, it is normally only used in identification [38, 39] or fine-grained image classification tasks. For our task, using the contrastive loss is adequate, as we are dealing with an instance discrimination task.
Another step in our workflow is classification. In this stage, meta-learning is applied to fine-tune the pre-trained encoder to fit the class changes required by few-shot learning.
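As a brief aside on the pre-training stage, the momentum update of Eq. (6) and the queue-based contrastive loss can be sketched in a few lines. This is a minimal sketch under stated assumptions, not the authors' code: parameters are flattened into plain lists, the query and keys are assumed to be unit-length vectors compared by dot product, and the helper names `momentum_update` and `queue_contrastive_loss` are hypothetical.

```python
import math

def momentum_update(theta_k, theta_q, m=0.999):
    """Momentum update: theta_k <- m * theta_k + (1 - m) * theta_q."""
    return [m * k + (1.0 - m) * q for k, q in zip(theta_k, theta_q)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def queue_contrastive_loss(q, k_pos, k_negs, tau=0.07):
    """Negative log softmax weight of the positive key among all keys."""
    logits = [dot(q, k_pos) / tau] + [dot(q, k) / tau for k in k_negs]
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum - logits[0]

# The key encoder moves only a small step toward the query encoder.
theta_k = [0.0, 1.0]
theta_q = [1.0, 0.0]
theta_k_new = momentum_update(theta_k, theta_q, m=0.9)

# Hypothetical unit-length query and keys.
q = [1.0, 0.0]
loss_good = queue_contrastive_loss(q, k_pos=[1.0, 0.0],
                                   k_negs=[[0.0, 1.0], [-1.0, 0.0]])
loss_bad = queue_contrastive_loss(q, k_pos=[0.0, 1.0],
                                  k_negs=[[1.0, 0.0], [-1.0, 0.0]])
```

With a momentum coefficient close to 1, such as the 0.999 reported in the experimental settings, each update changes the key encoder only slightly, which is what keeps the keys stored in the queue comparable across mini-batches.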
We then use Prototypical Networks [24] for few-shot classification. The prototypical network learns an embedding in which each class is represented by a mean vector $c$ in the latent space. The goal of the pre-trained encoder is to ensure that similar images are close and dissimilar images are separate in the latent space. The prototypical network has a similar goal, so it is used to fine-tune our pre-trained encoder. For class $m$, the centroid embedding features can be written as:
$c_m = \frac{1}{|S_m|} \sum_{x_i \in S_m} \psi(x_i)$, (8)
where $S_m$ denotes the support samples of class $m$ and $\psi(\cdot)$ is the embedding function from the prototypical network. As the prototypical network is a metric-based learning method, we use the Euclidean distance $d(\cdot, \cdot)$ to produce the distribution over all classes for a query $q$:
$p(y = m \mid q) = \frac{\exp(-d(\psi(q), c_m))}{\sum_{m'} \exp(-d(\psi(q), c_{m'}))}$. (9)
Eq. (9) is based on the softmax function over the distance between a query's embedding and the features of each class. The loss function for this stage can be defined as:
$\mathcal{L} = -\log p(y = m \mid q)$. (10)
Algorithm 1 shows the whole pre-training workflow of our method.
Algorithm 1: Training algorithm for the pre-training. Select two data augmentation functions from $T$: $t$, $t'$;
We evaluated our proposed model using two publicly available annotated COVID-19 CT slice datasets: (1) COVID-19 CT (https://github.com/UCSD-AI4H/COVID-CT) and (2) a dataset provided by the Italian Society of Medical and Interventional Radiology (https://www.sirm.org/category/senza-categoria/covid-19/) and preprocessed by MedSeg (http://medicalsegmentation.com/covid19/). It is worth mentioning that there is no overlap between COVID-19 CT and MedSeg, as they come from different countries. When dividing the support and query sets for classification, we divided the datasets at the patient level instead of the CT level to avoid any possible overlap. The basic statistics for the COVID-19 CT dataset and MedSeg are summarized in Table 1. We combined the two datasets for testing. Note that all CT slices were resized to 512 × 512 using OpenCV.
A proper pre-training is required for our proposed model.
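Returning briefly to the classification stage, the prototype computation and the distance-based softmax described above can be sketched as follows. This is an illustrative sketch only: the 2-D embeddings, the class labels and the helper names are hypothetical, and the plain Euclidean distance is used because that is what the text states.

```python
import math

def centroid(embeddings):
    """Mean vector of a class's support embeddings (the class prototype)."""
    n = len(embeddings)
    return [sum(col) / n for col in zip(*embeddings)]

def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def class_probabilities(query, prototypes):
    """Softmax over negative distances from the query to each prototype."""
    scores = {label: -euclidean(query, c) for label, c in prototypes.items()}
    top = max(scores.values())  # subtract the max for numerical stability
    exps = {label: math.exp(s - top) for label, s in scores.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}

# Hypothetical 2-D embeddings for a two-way, two-shot episode.
support = {
    "covid": [[1.0, 1.0], [1.2, 0.8]],
    "non_covid": [[-1.0, -1.0], [-0.8, -1.2]],
}
prototypes = {label: centroid(vecs) for label, vecs in support.items()}
probs = class_probabilities([0.9, 1.1], prototypes)
loss = -math.log(probs["covid"])  # negative log-probability of the true class
```

A query embedded near the prototype of its own class receives most of the probability mass, so the loss is small; fine-tuning the encoder with this loss pulls query embeddings toward the correct prototype.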
Different from other existing methods, such as Self-Trans [20], that used ImageNet to pre-train the model, we adopted DeepLesion [40] and the Lung Image Database Consortium Image Collection (LIDC-IDRI). DeepLesion contains over 32,000 lung CT images, while LIDC-IDRI includes 244,617. Both datasets are public and focus on lung diseases. We used the two datasets without labels to pre-train the encoder network.
For pre-training, we used the SGD optimizer with a weight decay of 0.0001 and momentum of 0.9. The momentum update coefficient was 0.999. The mini-batch size was set to 256 across eight GPUs. The number of epochs was 200. The initial learning rate was 0.03, which was then multiplied by 0.1 after 120 and 160 epochs, as described in [35]. ResNet-50 was used as the encoder. The two-layer MLP projection head included a 2048-D hidden layer with a ReLU activation function. The weights were initialized using He initialization [41], and the temperature parameter $\tau$ was set to 0.07. For the classification stage, we followed the default settings of the prototypical network. The experiments were conducted on eight GPUs, comprising six NVIDIA TITAN X Pascal GPUs and two NVIDIA TITAN RTX GPUs.
We evaluated model performance using four metrics: (i) Accuracy, which measures the percentage of correctly classified samples over the whole dataset; (ii) Precision, which measures the percentage of true positives (TPs) over all predicted positive samples; (iii) Recall, which measures the percentage of TPs over all positive samples; and (iv) Area-under-the-curve (AUC), which measures the relation between false positives (FPs) and TPs. We trained and tested each of the compared methods on the COVID-19 CT and MedSeg datasets using 10-fold cross-validation at the patient level with the cross-entropy loss function. The experimental results are summarized in Table 2. We found that the designed two-way, one-shot model yielded very similar performance to ResNet-50.
In addition, we found that the classification performance is worse when the model is pre-trained on ImageNet. As discussed previously, an extra step may be required to conduct transfer learning from common objects to lung CT slices.
As previously mentioned, our method uses few-shot learning. We were thus interested to see how varying the number of shots would affect the model performance. Accordingly, we conducted an experimental analysis to explore the relationship between the classification performance and the number of shots. The results are shown in Table 3, where ResNet-50 is used as a baseline method for the comparison. As can be seen, the classification performance of our model gradually improves as the number of shots increases. Specifically, our model achieved significantly improved performance when using four shots compared with one shot and outperformed ResNet-50, but no obvious further improvement was observed when using more than five shots. These results indicate that the pre-trained encoder effectively captured the features from unknown images to improve the classification performance.
Additionally, we provide visualizations of the features learned by different methods, including our method, pre-training with ImageNet, DenseNet-121 and ResNet-50, in Fig. 3. Here, both DenseNet-121 and ResNet-50 were directly trained on the COVID-19 dataset described in Table 1. As can be seen, our method learned more features that focused on the lung area, improving the classification accuracy in comparison with the other approaches. Table 4 summarizes the training time taken by different methods for a comparison of the computational cost. As our method was not trained on the COVID-19 dataset, the corresponding training time was not available. The method trained on ImageNet took about 150 h.
In this section, we conducted extensive ablation studies to demonstrate the importance of each component in our model.
The default setting of our method was two-way, one-shot, and ResNet-50 used the same setting as in the previous section. We investigated the following research questions: (1) How would data augmentation and the projection head affect the performance? (2) How important is the fine-tuning stage? (3) How would the resizing operation affect the performance? (4) Is the result significantly affected when using a different encoder?
First, we explored the role of the data augmentation and projection head. We conducted the experiments on our model without augmentation and without the projection head, respectively. The results summarized in Table 5 show that data augmentation had a significant effect on the model performance, while the projection head yielded only a slight improvement. One possible reason why the projection head did not provide an obvious improvement is that it was only used to extract the most important information from the similarity vector. As such, it simply worked as a filter without introducing additional useful features. In addition, we also investigated the impact of different data augmentation strategies. Specifically, we compared the classification results between using all three random-cropping strategies and using only a single one. All the compared data augmentation methods are shown in Fig. 2. In Table 6, we summarize the classification results and use AB-cropping, CD-cropping, and EF-cropping to represent each of the three strategies. The results demonstrate that the model performance de-
Table 2 Performance comparison between our proposed model and other methods. Our model and the method pretrained on ImageNet use a two-way, one-shot strategy.
Table 5 The effect of data augmentation and projection head.
Table 9 Effects of using different encoders on model performance.
Moreover, we also examined the effects of the fine-tuning process on the model performance.
To do so, we first pre-trained the embedding network and modified the few-shot classification stage by replacing the prototypical network with a linear classifier with frozen features. We directly applied the linear classifier to the learned embedding network without any update of the weights. The results are summarized in Table 7, which shows that the fine-tuning process can significantly improve the performance.
In addition, we also investigated the impact of the resizing operation during the data augmentation process. To this end, we compared the model performance with and without the resizing operation (see Table 8). As can be seen, the resizing operation slightly affected the performance. Moreover, as the cropping operation may generate different-sized images, the resizing operation is necessary to ensure that all these generated images can be fed into the neural network for model training.
Finally, we examined the effect of using different encoders on model performance. We used the same settings as mentioned in Section 4.2, but changed the encoder network from ResNet-50 to ResNet-152, DenseNet-161, and VGG-16 for performance comparison. The results reported in Table 9 show that ResNet-50 achieved the best performance. This justified the use of ResNet-50 as the encoder in our proposed model.
CT imaging is attracting increasing attention as a screening tool for COVID-19. It provides visualization for monitoring disease progression and can help to evaluate the severity. However, the lack of annotated CT scans is a significant challenge in CT imaging-based medical studies. In this work, we proposed a new deep learning-based method that can be used for the automatic diagnosis of COVID-19 with limited samples. Moreover, we demonstrated that our method achieved superior performance over ResNet-50 when the number of available samples is larger than three. ResNet-50 is a well-known and widely used supervised learning model for medical image analysis.
As our model uses a self-supervised strategy based on unsupervised learning, the fact that it can outperform ResNet-50 is remarkable. We expect that our method will be useful for other medical imaging analysis tasks facing the same data shortage problem. In the future, we plan to apply the proposed method to more COVID-19 datasets to validate its generalizability. Moreover, we will also investigate how to use knowledge distillation to reduce the size of the learned embedding and further increase the classification accuracy.
[1] INF-Net: automatic COVID-19 lung infection segmentation from CT images
[2] Automatic detection of coronavirus disease (COVID-19) using X-ray images and deep convolutional neural networks
[3] Automatically discriminating and localizing COVID-19 from community-acquired pneumonia on chest X-rays
[4] Diagnosis of the coronavirus disease (COVID-19): RRT-PCR or CT?
[5] Correlation of chest CT and RT-PCR testing in coronavirus disease 2019 (COVID-19) in China: a report of 1014 cases
[6] Chest CT findings in coronavirus disease-19 (COVID-19): relationship to duration of infection
[7] COVID-19): role of chest CT in diagnosis and management
[8] Coronavirus disease 2019 (COVID-19): a systematic review of imaging findings in 919 patients
[9] Clinical features of patients infected with 2019 novel coronavirus in
[10] A review of the 2019 novel coronavirus (COVID-19) based on current evidence
[11] Artificial intelligence distinguishes COVID-19 from community acquired pneumonia on chest CT
[12] Artificial intelligence-enabled rapid diagnosis of patients with COVID-19
[13] Deep learning-based model for detecting 2019 novel coronavirus pneumonia on high-resolution computed tomography
[14] Residual attention U-Net for automated multiclass segmentation of COVID-19 chest CT images
[15] Data augmentation using learned transformations for one-shot medical image segmentation
[16] Augmenting data when training a CNN for retinal vessel segmentation: how to warp
[17] Deep feature augmentation for occluded image classification
[18] Tackling mode collapse in multi-generator GANs with orthogonal vectors
[19] Improving data augmentation for medical image segmentation
[20] Sample-efficient deep learning for COVID-19 diagnosis based on CT scans
[21] Topological optimization of the densenet with pretrained-weights inheritance and genetic channel selection
[22] Representation learning with contrastive predictive coding
[23] Momentum contrast for unsupervised visual representation learning
[24] Prototypical networks for few-shot learning
[25] Matching networks for one shot learning
[26] Large scale adversarial representation learning
[27] Decaf: a deep convolutional activation feature for generic visual recognition
[28] Learning deep representations by mutual information estimation and maximization
[29] Learning representations by maximizing mutual information across views
[30] Data-efficient image recognition with contrastive predictive coding
[31] Pytorch: an imperative style, high-performance deep learning library
[32] Dimensionality reduction by learning an invariant mapping
[33] A simple framework for contrastive learning of visual representations
[34] Deep residual learning for image recognition
[35] Unsupervised feature learning via non-parametric instance discrimination
[36] Multi-task self-supervised visual learning
[37] FaceNet: a unified embedding for face recognition and clustering
[38] Deep features for person re-identification on metric learning
[39] Face re-identification challenge: are face recognition models good enough?
[40] Deep lesion graphs in the wild: relationship learning and organization of significant radiology image findings in a diverse large-scale lesion database
[41] Delving deep into rectifiers: surpassing human-level performance on imagenet classification
[42] Grad-CAM: visual explanations from deep networks via gradient-based localization
The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.