key: cord-0492065-7ldbo8e4 authors: Jiang, Yifan; Chen, Han; Han, David K.; Ko, Hanseok title: Few-shot Learning for CT Scan based COVID-19 Diagnosis date: 2021-02-01 journal: nan DOI: nan sha: e35cf6243d3ae7de1eeb174b48ce26cf603ffb6f doc_id: 492065 cord_uid: 7ldbo8e4
Coronavirus disease 2019 (COVID-19) is a Public Health Emergency of International Concern that has infected more than 40 million people across 188 countries and territories. Chest computed tomography (CT) imaging offers high diagnostic accuracy and robustness, and it has become an indispensable tool for COVID-19 mass testing. Recently, deep learning approaches have become an effective tool for automatic screening of medical images, and they are also being considered for COVID-19 diagnosis. However, the high infection risk associated with COVID-19 makes labeled data relatively sparse, which limits the performance of such methods. Moreover, accurately labeling CT images requires the expertise of radiologists, which makes the process expensive and time-consuming. To tackle these issues, we propose a supervised domain adaptation based COVID-19 CT diagnostic method that performs effectively when only a small number of labeled CT scans are available. To compensate for the sparseness of labeled data, the proposed method uses a large amount of synthetic COVID-19 CT images and adapts the network from the source domain (synthetic data) to the target domain (real data) with a cross-domain training mechanism. Experimental results show that the proposed method achieves state-of-the-art performance on few-shot COVID-19 CT imaging based diagnostic tasks.
Coronavirus disease 2019 (COVID-19) [1] is an ongoing global pandemic that was declared by the World Health Organization (WHO) on 11 March 2020. As of 20 October 2020, it had infected more than 40 million individuals and caused 1,119,369 deaths [2].
COVID-19 is highly contagious and spreads more readily than similar infectious diseases such as Middle East Respiratory Syndrome (MERS) or Severe Acute Respiratory Syndrome (SARS) [3]. To slow the rapid transmission of the disease, it is necessary to detect COVID-19 at an early stage of infection. With the emergence of deep learning, the medical imaging field has also benefited from the effective feature representation capability of deep learning techniques [4, 5, 6, 7, 8]. However, applying deep learning to COVID-19 diagnosis is challenging because sufficiently large labeled datasets, particularly of COVID-19 CT data, are lacking: collecting them involves a high infection risk, and labeling them requires experienced radiologists [9]. To enable effective deep learning based COVID-19 diagnosis, it is necessary to develop a novel approach capable of learning under few-shot conditions, in which only limited data are available. One promising approach to this issue is to use synthetic data for model training. However, a model trained on synthetic data may not perform satisfactorily when applied directly to real data. This is due to the domain shift problem: the synthetic data (source domain) do not necessarily follow a distribution similar to that of the real data (target domain). To handle the domain shift problem, several supervised domain adaptation (SDA) methods have recently been proposed. FADA [10] applied adversarial learning to learn embedding features that maximize the confusion between the two domains while aligning them on a semantic level. CCSA [11] proposed a series of loss functions to manage the domain gap in few-shot domain adaptation tasks. d-SNE [12] introduced an approach that exploits stochastic neighborhood embedding and a modified Hausdorff distance to improve few-shot classification performance.
Although much effort has been devoted to SDA and to few-shot COVID-19 diagnosis [13, 14], applying domain adaptation to CT images for COVID-19 diagnosis is a relatively new area, and our proposed method is one of the first attempts to use synthetic chest CT scans for a few-shot COVID-19 diagnostic task. In this paper, we propose a novel supervised domain adaptation based few-shot COVID-19 diagnostic method for CT scans. The proposed method is built on a Siamese structure, and the domain shift problem is addressed with a cross-domain training mechanism. The main idea is to learn a model that manages three distances at the domain level: (a) a classification loss $L_c$ that maximizes the distribution distance between samples from different categories; (b) a cross-domain pairing loss $L_{cp}$ that minimizes the distribution distance between samples from different domains but of the same category; and (c) a cross-domain detaching loss $L_{cd}$ that maximizes the distribution distance between samples from different domains and categories. Figure 1 illustrates the domain shift problem and the proposed solution. In Figure 1 (A), a COVID-19 CT based diagnostic model is pre-trained on a large amount of synthetic data and tested on real data. The model in Figure 1 (B), on the other hand, is trained on a large quantity of real data and tested in the same way as the model in Figure 1 (A). As expected, the model in Figure 1 (A) performs poorly compared to the model in Figure 1 (B) because of the domain shift problem. Figure 1 (C) shows the proposed method, which uses the same synthetic data as model (A) plus a few real samples to reduce the domain gap between synthetic and real data, so that it can reach a performance level similar to that of (B). The synthetic data can easily be generated with our previous work [15].
The main contributions of our work are as follows: (1) We propose a novel chest CT based COVID-19 diagnostic method designed for few-shot conditions, in which only a small quantity of COVID-19 CT data is available. To the best of our knowledge, the proposed method is the first domain adaptation method that uses synthetic COVID-19 CT data for a few-shot COVID-19 diagnostic task. (2) We propose a Siamese network structure trained with a novel cross-domain training mechanism, which enables effective domain transfer via three losses ($L_c$, $L_{cp}$ and $L_{cd}$) under few-shot conditions.
[Figure 2: Overview of the proposed method. The model consists of three parts: the source branch f(·) (blue), the target branch f(·) (orange), and the prediction branch g(·) (gray); the source and target branches share weights. The cross-domain losses ($L_{cp}$ and $L_{cd}$) and the classification loss $L_c$ are derived through the green arrow and the blue arrow, respectively.]
In this work, we propose a novel Siamese network based model for a few-shot COVID-19 CT diagnostic task, as illustrated in Figure 2. The Siamese network consists of three components: the source branch f(·), the target branch f(·) and the prediction branch g(·). The source and target branches have the same network structure, consisting of a feature extractor and two fully-connected (FC) layers. The prediction branch is a network of three FC layers. During training, the source and target branches share weights and take staggered synthetic and real inputs $(X_{s1}, X_{t1}, X_{s2}, X_{t2}, \dots, X_{sN}, X_{tN})$. Since the synthetic data greatly outnumber the real data, the real data are reused. Two embedded feature vectors, $f_s(X_s)$ and $f_t(X_t)$, are extracted through the two branches.
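The staggered input order and the weight sharing described above can be sketched in plain Python. This is a minimal sketch, not the authors' implementation: it assumes that real samples are reused cyclically (the text says only that they are reused), and `staggered_inputs`, `shared_branch` and the toy weight matrix are hypothetical stand-ins for the actual feature extractor and FC layers.

```python
from itertools import cycle
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def staggered_inputs(synthetic: Sequence[T], real: Sequence[T]) -> List[Tuple[str, T]]:
    """Order samples as X_s1, X_t1, X_s2, X_t2, ..., X_sN, X_tN.

    ASSUMPTION: real (target) samples are reused cyclically, because the
    synthetic (source) set is much larger than the real set.
    """
    if not real:
        raise ValueError("at least one real (target) sample is required")
    real_stream = cycle(real)  # restart from the first real sample when exhausted
    ordered: List[Tuple[str, T]] = []
    for x_s in synthetic:
        ordered.append(("source", x_s))
        ordered.append(("target", next(real_stream)))
    return ordered


def shared_branch(x: Sequence[float], weights: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical stand-in for f(.): one linear layer whose single weight
    matrix is applied to both source and target inputs (weight sharing)."""
    return [sum(w * v for w, v in zip(row, x)) for row in weights]


if __name__ == "__main__":
    order = staggered_inputs(["s1", "s2", "s3"], ["t1", "t2"])
    print(order)
    W = [[1.0, 0.0], [0.5, 0.5]]  # toy weights shared by both branches
    print(shared_branch([2.0, 4.0], W))  # same call for a source or a target input
```

Cyclic reuse keeps every synthetic sample paired with a real one; a shuffled reuse would be an equally plausible reading of the text.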
Only $f_s(X_s)$ is passed to the prediction branch to compute the classification loss $L_c$, while both $f_s(X_s)$ and $f_t(X_t)$ are used to compute the cross-domain losses ($L_{cp}$ and $L_{cd}$). The classification loss and the cross-domain losses are combined into the overall loss used to update the network. During testing, a real CT image is passed through the network, which makes a binary diagnostic decision. To train the proposed classifier to label an input CT scan as positive or negative, we define the classification loss $L_c$ as follows:
$$L_c = -\left[ y \log g(f(x)) + (1 - y) \log\left(1 - g(f(x))\right) \right] \quad (1)$$
where $x$ denotes the input CT scan and $y \in \{0, 1\}$ denotes its binary label. This binary cross-entropy loss learns the difference between positive and negative cases and teaches the network to recognize the characteristics of lesion representations associated with COVID-19. Since the data distributions of the source and target domains differ, this domain gap can degrade the diagnostic performance of a model that is trained on the source domain but tested on the target domain. The classification loss alone is therefore not sufficient to handle the domain shift problem, and further cross-domain measures are required. We define a novel cross-domain pairing loss to manage the distance between features from different domains that share the same label:
$$L_{cp} = D\left(p(f(X_s^p)), p(f(X_t^p))\right) + D\left(p(f(X_s^n)), p(f(X_t^n))\right) \quad (2)$$
where $X_s^p$ represents a sample from the source domain with a positive label, and $X_t^n$ denotes a sample from the target domain with a negative label (the other terms follow the same convention). $D$ is the distance between two probability distributions $p(\cdot)$; it is computed as the average pairwise Euclidean distance between points of the same label from the two domains.
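The classification loss $L_c$ and the cross-domain pairing loss $L_{cp}$ can be illustrated with a small, dependency-free sketch. It assumes that features are plain lists of floats, that `prob` is the prediction branch's probability of a positive diagnosis, and that D is the average pairwise Euclidean distance described in the text. The function names are hypothetical, and a real implementation would operate on batched tensors inside a deep learning framework.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def bce_loss(prob: float, label: int, eps: float = 1e-7) -> float:
    """Binary cross-entropy for one scan; `prob` is the predicted P(positive)."""
    p = min(max(prob, eps), 1.0 - eps)  # clamp so log() never sees 0
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))


def avg_pairwise_distance(a: Sequence[Vector], b: Sequence[Vector]) -> float:
    """D(a, b): mean Euclidean distance over all pairs (u, v) with u in a, v in b."""
    if not a or not b:
        raise ValueError("both feature sets must be non-empty")
    return sum(math.dist(u, v) for u in a for v in b) / (len(a) * len(b))


def pairing_loss(src_pos: Sequence[Vector], src_neg: Sequence[Vector],
                 tgt_pos: Sequence[Vector], tgt_neg: Sequence[Vector]) -> float:
    """L_cp: distance between same-label features from the two domains (minimized)."""
    return avg_pairwise_distance(src_pos, tgt_pos) + avg_pairwise_distance(src_neg, tgt_neg)


if __name__ == "__main__":
    print(round(bce_loss(0.9, 1), 4))  # -log(0.9)
    print(pairing_loss([[0.0, 0.0]], [[1.0, 1.0]], [[3.0, 4.0]], [[1.0, 1.0]]))  # 5.0 + 0.0
```

`math.dist` requires Python 3.8 or later; the clamp in `bce_loss` mirrors the common practice of keeping probabilities away from exactly 0 or 1.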
By applying the cross-domain pairing loss $L_{cp}$, the model learns the pairwise (same-category) relationship between the two domains by minimizing the distance between their feature distributions. To further enhance cross-domain diagnostic performance, we propose a cross-domain detaching loss that maximizes the distance between the feature distributions of different classes. It is defined as follows:
$$L_{cd} = D\left(p(f(X_s^p)), p(f(X_t^n))\right) + D\left(p(f(X_s^n)), p(f(X_t^p))\right) \quad (3)$$
Similar to $L_{cp}$, the cross-domain detaching loss $L_{cd}$ uses the Euclidean distance to measure the difference between the two feature distributions. The learning objective is to maximize $L_{cd}$ so that the diagnostic model separates the class distributions well at the feature level. The overall learning objective combines the classification loss with the cross-domain losses, minimizing $L_c$ and $L_{cp}$ while maximizing $L_{cd}$, where the hyper-parameter α is a weight factor for the cross-domain losses. By applying both the classification loss $L_c$ and the cross-domain losses $L_{cp}$ and $L_{cd}$, our COVID-19 diagnostic model can not only classify positive and negative cases within a domain, but also transfer knowledge from the source domain to the target domain. The proposed combination of loss functions thus exploits the large amount of synthetic data for the COVID-19 CT diagnostic task when only a small amount of real data is available.
Dataset. We constructed our dataset using data from both the source and the target domains. The source domain data are generated with our previous work [15], and the target domain data come from a public COVID-19 CT dataset containing 29 individual cases [16]. All CT slices are split at the patient level into a training set (20 cases) and a test set (9 cases). Specifically, we combine 6,000 source domain slices (synthetic data) and 60 target domain slices (real data) to form the training set, and we use 600 real CT scans as the test set.
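The cross-domain detaching loss $L_{cd}$ and the combined objective described above can be sketched in the same dependency-free style. The text states only that $L_c$ and $L_{cp}$ are minimized, $L_{cd}$ is maximized, and α weights the cross-domain losses; the combination `l_c + alpha * (l_cp - l_cd)` used below is an assumed form, not necessarily the authors' exact objective, and the helper names are hypothetical.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def avg_pairwise_distance(a: Sequence[Vector], b: Sequence[Vector]) -> float:
    """D(a, b): mean Euclidean distance over all cross pairs."""
    if not a or not b:
        raise ValueError("both feature sets must be non-empty")
    return sum(math.dist(u, v) for u in a for v in b) / (len(a) * len(b))


def detaching_loss(src_pos: Sequence[Vector], src_neg: Sequence[Vector],
                   tgt_pos: Sequence[Vector], tgt_neg: Sequence[Vector]) -> float:
    """L_cd: distance between different-label features across the two domains."""
    return avg_pairwise_distance(src_pos, tgt_neg) + avg_pairwise_distance(src_neg, tgt_pos)


def overall_loss(l_c: float, l_cp: float, l_cd: float, alpha: float = 0.25) -> float:
    """ASSUMED form of the objective to minimize: L_cd enters with a minus
    sign so that minimizing the total maximizes L_cd. alpha = 0.25 follows
    the weight factor reported in the experiments."""
    return l_c + alpha * (l_cp - l_cd)


if __name__ == "__main__":
    l_cd = detaching_loss([[0.0, 0.0]], [[1.0, 0.0]], [[1.0, 3.0]], [[4.0, 0.0]])
    print(l_cd)  # 4.0 + 3.0
    print(overall_loss(l_c=0.5, l_cp=1.0, l_cd=l_cd))  # 0.5 + 0.25 * (1.0 - 7.0)
```

Because the subtracted term is unbounded, a practical variant might cap $L_{cd}$ or use a margin; the text does not say whether the authors do so.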
The COVID-19 diagnostic task is formulated here as binary classification, so there are only two possible categories: positive and negative. To evaluate the proposed model, we randomly select n positive and n negative cases from the target domain slices and pair them with a randomly selected source domain group of 600 samples, which yields $2n \cdot 600$ source-target pairs for the n-shot learning task.
Evaluation metrics. We report diagnostic performance with two metrics: accuracy and F1 score. We randomly resample 10 times to build 10 individual training sets, and report results as mean ± 95% confidence interval over the 10 folds.
Experimental details. All CT scans are converted to grayscale images on a Hounsfield unit (HU) scale of [-600, 1500] and resized to 512 × 512. The learning rate is set to 0.001 with a decay rate of 0.95. The weight factor α is 0.25.
In this sub-section, we compare the proposed method with other state-of-the-art supervised domain adaptation methods, including CCSA [11], FADA [10] and d-SNE [12]. To show how our domain adaptation improves cross-domain diagnostic performance, we also include a source-only baseline, a deep network trained with source domain data only and tested on the target domain test set. This experiment is a 5-shot learning task, and we use the Xception [17] network as the feature extractor. As shown in Table 1, the proposed method outperforms the other state-of-the-art supervised domain adaptation approaches on both accuracy and F1 score.
In this sub-section, we present an ablation study that examines the effectiveness of each component of the proposed loss function. Our overall loss consists of three terms: the classification loss $L_c$, the cross-domain pairing loss $L_{cp}$ and the cross-domain detaching loss $L_{cd}$.
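The preprocessing, n-shot pairing and reporting steps above can be sketched as follows. This is an illustrative sketch under stated assumptions: HU values are mapped linearly to 8-bit gray levels, the 95% confidence interval uses Student's t distribution over the folds (the text does not say how the interval is computed), and all function names are hypothetical. Resizing to 512 × 512 and model training are omitted.

```python
import math
import random
import statistics
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")

# Two-sided 95% Student's t critical values for 1..9 degrees of freedom.
T_975 = {1: 12.706, 2: 4.303, 3: 3.182, 4: 2.776, 5: 2.571,
         6: 2.447, 7: 2.365, 8: 2.306, 9: 2.262}


def hu_to_gray(hu: float, lo: float = -600.0, hi: float = 1500.0) -> int:
    """Clip a Hounsfield value to [lo, hi] and map it linearly onto 0..255."""
    clipped = min(max(hu, lo), hi)
    return round(255 * (clipped - lo) / (hi - lo))


def nshot_pairs(target_pos: Sequence[T], target_neg: Sequence[T], source: Sequence[T],
                n: int, group_size: int = 600, seed: int = 0) -> List[Tuple[T, T]]:
    """Draw n positive and n negative target slices plus one random source group,
    then pair every target shot with every source sample (2n * group_size pairs)."""
    rng = random.Random(seed)
    shots = rng.sample(list(target_pos), n) + rng.sample(list(target_neg), n)
    group = rng.sample(list(source), group_size)
    return [(s, t) for t in shots for s in group]


def accuracy_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[float, float]:
    """Accuracy and F1 score for binary labels (1 = COVID-19 positive)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    acc = sum(1 for t, p in pairs if t == p) / len(pairs)
    f1 = 2 * tp / (2 * tp + fp + fn) if (tp + fp + fn) else 0.0
    return acc, f1


def mean_ci95(values: Sequence[float]) -> Tuple[float, float]:
    """Return (mean, half-width of a 95% CI) over folds, e.g. the 10 resampled runs."""
    k = len(values)
    if k < 2:
        raise ValueError("need at least two folds")
    t = T_975.get(k - 1, 1.96)  # normal approximation beyond 9 degrees of freedom
    return statistics.mean(values), t * statistics.stdev(values) / math.sqrt(k)


if __name__ == "__main__":
    print(hu_to_gray(-1000), hu_to_gray(2000))  # clipped to 0 and 255
    pairs = nshot_pairs(range(30), range(30, 60), range(6000), n=5)
    print(len(pairs))  # 2 * 5 * 600 = 6000
    print(accuracy_f1([1, 1, 0, 0], [1, 0, 0, 1]))  # (0.5, 0.5)
```

With 10 folds there are 9 degrees of freedom, so the sketch uses t = 2.262 rather than the normal value 1.96.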
To explore the contribution of each loss term, we evaluated the proposed model under four conditions: $L_c$, $L_c + L_{cp}$, $L_c + L_{cd}$ and $L_c + L_{cp} + L_{cd}$ (ours). The results are summarized in Table 2. Comparing the contribution of each cross-domain loss term shows that both the cross-domain pairing and cross-domain detaching losses help to overcome the domain gap and improve cross-domain diagnostic performance.
We evaluate the proposed method in terms of its few-shot learning capability. We consider five cases, n = 1, 3, 5, 7 and 9, where n is the shot number. The results, shown in Figure 3, indicate that the proposed method handles few-shot diagnosis effectively under diverse n-shot conditions. As expected, diagnostic performance improves as the shot number n increases. Note, however, that even in the extreme case of n = 1, performance remains above 0.6.
In this paper, we proposed a supervised domain adaptation based few-shot COVID-19 diagnostic method for CT scans. The novelty of the proposed method lies in a cross-domain training architecture built on a Siamese network and in two cross-domain training losses introduced alongside a classification loss. The Siamese network based architecture and the proposed cross-domain losses are shown to be effective in handling the domain shift problem between the source and target domains. Experimental results on the public COVID-19 CT dataset show that the proposed method outperforms other state-of-the-art supervised domain adaptation methods on a few-shot COVID-19 CT diagnostic task. In future work, we plan to investigate channel attention based COVID-19 diagnostic methods using 3D CT volumes.
[1] Coronavirus disease (COVID-19) pandemic.
[2] Johns Hopkins Coronavirus Resource Center.
[3] Coronavirus: COVID-19 has killed more people than SARS and MERS combined, despite lower case fatality rate.
[4] Efficient multi-scale 3D CNN with fully connected CRF for accurate brain lesion segmentation.
[5] Deep learning segmentation of major vessels in X-ray coronary angiography.
[6] Inf-Net: Automatic COVID-19 lung infection segmentation from CT images.
[7] Artificial intelligence distinguishes COVID-19 from community acquired pneumonia on chest CT.
[8] Diagnosis of coronavirus disease 2019 (COVID-19) with structured latent multi-view representation learning.
[9] Correlation of chest CT and RT-PCR testing in coronavirus disease 2019 (COVID-19) in China: a report of 1014 cases.
[10] Few-shot adversarial domain adaptation.
[11] Unified deep supervised domain adaptation and generalization.
[12] d-SNE: Domain adaptation using stochastic neighborhood embedding.
[13] Momentum contrastive learning for few-shot COVID-19 diagnosis from chest CT images.
[14] Single-shot lightweight model for the detection of lesions and the prediction of COVID-19 from chest CT scans.
[15] COVID-19 CT image synthesis with a conditional generative adversarial network.
[16] COVID-19 CT segmentation dataset.
[17] Xception: Deep learning with depthwise separable convolutions.