key: cord-0546015-kcs3le51 authors: Jiang, Jue; Tyagi, Neelam; Tringale, Kathryn; Crane, Christopher; Veeraraghavan, Harini title: Self-supervised 3D anatomy segmentation using self-distilled masked image transformer (SMIT) date: 2022-05-20 journal: nan DOI: nan sha: 0f0c7ec7a2badf4f9b6f92442ca39cf86fecbce1 doc_id: 546015 cord_uid: kcs3le51

Vision transformers, with their ability to model long-range context more efficiently, have demonstrated impressive accuracy gains in several computer vision and medical image analysis tasks, including segmentation. However, such methods need large labeled datasets for training, which are hard to obtain for medical image analysis. Self-supervised learning (SSL) has demonstrated success in medical image segmentation using convolutional networks. In this work, we developed a self-distillation learning with masked image modeling method to perform SSL for vision transformers (SMIT), applied to 3D multi-organ segmentation from CT and MRI. Our contribution is a dense pixel-wise regression within masked patches, called masked image prediction, which we combined with masked patch token distillation as a pretext task to pre-train vision transformers. We show that our approach is more accurate and requires fewer fine-tuning datasets than other pretext tasks. Unlike prior medical image methods, which typically used image sets arising from the disease sites and imaging modalities corresponding to the target tasks, we used 3,643 CT scans (602,708 images) arising from head and neck, lung, and kidney cancers as well as COVID-19 for pre-training, and applied the pre-trained model to abdominal organ segmentation from MRI of pancreatic cancer patients as well as to the segmentation of 13 different abdominal organs from publicly available CT scans. Our method showed a clear accuracy improvement (average DSC of 0.875 on MRI and 0.878 on CT) with a reduced requirement for fine-tuning datasets compared with commonly used pretext tasks.
Extensive comparisons against multiple current SSL methods were performed. Code will be made available upon acceptance for publication.

Vision transformers (ViT) [1] efficiently model long-range contextual information using a multi-head self-attention mechanism, which makes them robust to occlusions, image noise, and domain and image contrast differences. ViTs have shown impressive accuracy gains over convolutional neural networks (CNN) in medical image segmentation [2, 3]. However, ViT training requires a large number of expert-labeled training datasets, which are not commonly available in medical image applications. Self-supervised learning (SSL) overcomes the requirement for large labeled training datasets by using large unlabeled datasets through pre-defined, annotation-free pretext tasks. The pretext tasks are based on modeling the visual information contained in images and provide a surrogate supervision signal for feature learning [4, 5, 6]. Once pre-trained, the model can be re-purposed for a variety of tasks and requires relatively few labeled sets for fine-tuning. The choice of pretext task is crucial to successfully mining useful image information with SSL. Pretext tasks in medical image applications typically focus on learning denoising autoencoders, constructed with CNNs, that recover images in the input space from corrupted versions of the original images [7, 8]. Various data augmentation strategies have been used to corrupt images, including jigsaw puzzles [9, 10], transformation of image contrast and local texture [7], image rotations [5], and masking of whole image slices [11]. Learning strategies include pseudo labels [8, 12, 13] and contrastive learning [9, 14, 15, 16]. However, CNNs are inherently less capable than transformers of modeling long-range context, which may reduce their robustness to imaging variations and contrast differences.
Hence, we combined ViT with SSL using masked image modeling (MIM) and self-distillation of concurrently trained teacher and student networks. MIM has been successfully applied to transformers to capture local context while preserving global semantics in natural image analysis tasks [17, 18, 19, 20, 21]. Knowledge distillation with a concurrently trained teacher has also been used for medical image segmentation by leveraging datasets from different imaging modalities (CT and MRI) [22, 23]. Self-distillation, on the other hand, uses different augmented views of the same image [24] and has been combined with contrastive learning in convolutional encoders for medical image classification [13]. Self-distillation learning with MIM, using a pair of online teacher and student transformer encoders, has been used for natural image classification and segmentation [19, 24]. However, the pretext task in [24] focused only on extracting a global image embedding as the class token [CLS], which was later improved with global and local patch token embeddings [19]. These methods still ignored dense pixel dependencies, which are essential for dense prediction tasks such as segmentation. Hence, we introduced a masked image prediction (MIP) pretext task that predicts pixel-wise intensities within masked patches, combined it with local and global embedding distillation, and applied it to medical image segmentation. Our contributions include: (i) SSL using an MIM and self-distillation approach that combines masked image prediction, masked patch token distillation, and global image token distillation for CT and MRI organ segmentation using transformers. (ii) A simple linear projection layer for medical image reconstruction to speed up pre-training, which we show is more accurate than a multi-layer decoder.
(iii) SSL pre-training using a large set of 3,643 3D CTs arising from a variety of disease sites, including head and neck, chest, and abdomen, with different cancers (lung, naso/oropharynx, kidney) and COVID-19, applied to CT and MRI segmentation. (iv) Evaluation of various pretext tasks using transformer encoders with respect to fine-tuning data size requirements and segmentation accuracy.

Goal: Extract a universal representation of images for dense prediction tasks, given an unlabeled dataset of Q images.

Approach: A visual tokenizer $f_s(\theta_s)$, implemented as a transformer encoder, is learned via self-distillation using MIM pretext tasks in order to convert an image $x$ into image tokens $\{x_i\}_{i=1}^{N}$, with $N$ being the sequence length. Self-distillation is performed by concurrently training an online teacher tokenizer model $f_t(\theta_t)$, which has the same network structure as $f_s(\theta_s)$, the latter serving as the student model. A global image token distillation (ITD) is performed as a pretext task to match the global tokens extracted by $f_t$ and $f_s$, as done previously [24]. The MIM pretext tasks include masked image prediction (MIP) and masked patch token distillation (MPD). Suppose $\{u, v\}$ are two augmented views of a 3D image $x$. $N$ image patches are extracted from the images to create a sequence of image tokens [1]. The image tokens are then corrupted by randomly masking them according to a binary vector $m = \{m_i\}_{i=1}^{N}$, $m_i \in \{0, 1\}$, with masking probability $p$, and replacing the masked tokens with a mask token $e_{[MASK]}$ [20], such that $\tilde{u} = m \odot u$, with $\tilde{u}_i = e_{[MASK]}$ where $m_i = 1$ and $\tilde{u}_i = u_i$ where $m_i = 0$. The second augmented view $v$ is also corrupted, but with a different mask vector instance $m'$, as $\tilde{v} = m' \odot v$.

Dense pixel dependency modeling using MIP: MIP involves recovering the original image view $u$ from the corrupted view $\tilde{u}$ as $\hat{u} = h^{Pred}_s(f_s(\tilde{u}; \theta_s))$, where $h^{Pred}_s$ decodes the visual tokens produced by the visual tokenizer $f_s(\theta_s)$ into images (see Fig. 1).
MIP involves dense pixel regression of image intensities within masked patches using the context of unmasked patches. The MIP loss (dotted green arrow in Fig. 1) is computed over the masked patches, where $h^{Pred}_s$ is a one-layer linear projection for dense pixel regression. A symmetrized loss using $v$ and $\tilde{v}$ is combined to compute the total loss $L_{MIP}$.

Masked patch token self-distillation (MPD): MPD is accomplished by optimizing a teacher $f_t(\theta_t)$ and a student visual tokenizer $f_s(\theta_s)$ such that the student network predicts the tokens of the teacher network. The student network $f_s$ tokenizes the corrupted version of an image, $\tilde{u}$, to generate visual tokens. The teacher network $f_t$ tokenizes the uncorrupted version of the same image, $u$, to generate visual tokens $\phi = \{\phi_i\}_{i=1}^{N}$. Similar to MIP, MPD is only concerned with predicting the masked patch tokens. Therefore, the loss is computed only over the masked positions (i.e., $m_i = 1$) using the cross-entropy of the predicted patch tokens (dotted red arrow in Fig. 1), where $P^{Patch}_s$ and $P^{Patch}_t$ are the patch token distributions of the student and teacher networks. They are computed by applying a softmax to the outputs of $h^{Patch}_s$ and $h^{Patch}_t$. The sharpness of the token distribution is controlled using temperature terms $\tau_s > 0$ and $\tau_t > 0$ for the student and teacher networks, respectively. Mathematically, such sharpening can be expressed (using the notation of the student network) as:

$$P^{Patch}_s(\tilde{u})^{(k)} = \frac{\exp\left(h^{Patch}_s(\tilde{u})^{(k)} / \tau_s\right)}{\sum_{k'} \exp\left(h^{Patch}_s(\tilde{u})^{(k')} / \tau_s\right)}$$

where $(k)$ indexes the output dimensions of $h^{Patch}_s$. A symmetrized cross-entropy loss corresponding to the other view pair, $v$ and $\tilde{v}$, is also computed and averaged to compute the total loss $L_{MPD}$. For ITD, sharpening transforms are applied to $P^{[CLS]}_t$ and $P^{[CLS]}_s$ similarly to Equation 4. A symmetrized cross-entropy loss corresponding to the corrupted view $\tilde{v}$ and the other view $u$ is also computed and averaged to compute the total loss $L_{ITD}$.
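The token masking and the two masked-patch losses described above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the MIP and MPD loss equations did not survive in this text, so the L1 regression loss, the helper names (`mask_tokens`, `linear_head`, `mip_loss`, `sharpen`, `mpd_loss`), the feature and token-class sizes, and the temperature values are all hypothetical. The transformer encoder, teacher centering, and symmetrization over the two views are omitted.

```python
import numpy as np

def mask_tokens(tokens, mask_ratio, mask_token, rng):
    """Replace each patch token with mask_token independently with probability mask_ratio.

    tokens: (N, D) patch tokens u_i; mask_token: (D,) stand-in for e_[MASK].
    Returns the corrupted tokens u~ and the binary mask m (1 = masked).
    """
    m = (rng.random(tokens.shape[0]) < mask_ratio).astype(np.int64)
    corrupted = np.where(m[:, None] == 1, mask_token[None, :], tokens)
    return corrupted, m

def linear_head(features, weight, bias):
    """One-layer linear projection (stand-in for h^Pred_s): (N, D) features -> (N, P) voxels per patch."""
    return features @ weight + bias

def mip_loss(pred, target, m):
    """Masked image prediction: mean absolute error over masked patches only (L1 is assumed)."""
    masked = m.astype(bool)
    if not masked.any():
        return 0.0
    return float(np.abs(pred[masked] - target[masked]).mean())

def sharpen(logits, tau):
    """Temperature-scaled softmax over the last axis; a smaller tau gives a sharper distribution."""
    z = logits / tau
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mpd_loss(student_logits, teacher_logits, m, tau_s, tau_t, eps=1e-12):
    """Masked patch token distillation: cross-entropy H(P_t, P_s) averaged over masked tokens.

    student_logits: (N, K) student patch-head outputs for the corrupted view.
    teacher_logits: (N, K) teacher patch-head outputs for the uncorrupted view.
    """
    masked = m.astype(bool)
    if not masked.any():
        return 0.0
    p_t = sharpen(teacher_logits[masked], tau_t)
    log_p_s = np.log(sharpen(student_logits[masked], tau_s) + eps)
    return float(-(p_t * log_p_s).sum(axis=-1).mean())

# Toy run with hypothetical sizes: 216 tokens (a 6x6x6 grid), D features, P voxels, K token classes
rng = np.random.default_rng(0)
N, D, P, K = 216, 32, 8, 16
u = rng.standard_normal((N, D))
u_tilde, m = mask_tokens(u, mask_ratio=0.7, mask_token=np.zeros(D), rng=rng)
# Second augmented view with its own, independent mask instance m'
v_tilde, m_prime = mask_tokens(rng.standard_normal((N, D)), 0.7, np.zeros(D), rng)
# Encoder f_s omitted: the head is applied to the corrupted tokens directly for illustration
pred = linear_head(u_tilde, rng.standard_normal((D, P)) * 0.1, np.zeros(P))
target = rng.standard_normal((N, P))  # placeholder for the true voxel intensities of each patch
l_mip = mip_loss(pred, target, m)
student_logits = rng.standard_normal((N, K))
teacher_logits = rng.standard_normal((N, K))
l_mpd = mpd_loss(student_logits, teacher_logits, m, tau_s=0.1, tau_t=0.1)
```

In a full pipeline, each loss would be computed for both view pairs, (u, ũ) and (v, ṽ), and averaged, and the teacher outputs would carry no gradient because the teacher is updated only through its moving average.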
Online teacher network update: The teacher network parameters were updated using an exponential moving average (EMA) with momentum, which has been shown to be feasible for SSL [24, 19], as $\theta_t = \lambda_m \theta_t + (1 - \lambda_m)\theta_s$, where $\lambda_m$ is the momentum, which was increased from 0.996 to 1 during training using a cosine schedule. The total loss was $L_{total} = L_{MIP} + \lambda_{MPD} L_{MPD} + \lambda_{ITD} L_{ITD}$.

Implementation details: All networks were implemented using the PyTorch library and trained on 4 Nvidia V100 GPUs. SSL optimization was done using AdamW with a cosine learning rate scheduler for 400 epochs, with an initial learning rate of 0.0002 and a warmup of 30 epochs. $\lambda_{MPD} = 0.1$ and $\lambda_{ITD} = 0.1$ were set experimentally. A default mask ratio of 0.7 was used. Centering and sharpening operations reduced the chances of degenerate solutions [24]. $\tau_s$ was … A SWIN transformer backbone [25] with an embedding size of 768, a window size of 4 × 4 × 4, and a patch size of 2 was used. The 1-layer decoder was implemented as a linear projection layer with the same number of output channels as the input image size. The network had 28.19M parameters. Following pre-training, only the student network was retained for fine-tuning and testing.

Training dataset: SSL pre-training was performed using 3,643 CT patient scans containing 602,708 images. Images were sourced from patients with head and neck (N=837) and lung cancers (N=1455) from internal and external [26] sources, as well as from patients with kidney cancers [27] (N=710) and COVID-19 [28] (N=650). GPU memory limitations were addressed for training, fine-tuning, and testing by image resampling (1.5×1.5×2 mm voxel size) and cropping (128×128×128) to enclose the body region. Augmented views for SSL training were produced from randomly cropped 96×96×96 volumes, which resulted in 6×6×6 image patch tokens. A sliding window strategy with half-window overlap was used for testing [2, 3].
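The online teacher update, the schedules, and the loss weighting given above can be sketched in plain Python. The cosine momentum formula below is a DINO-style choice that moves λ_m from 0.996 at the first step to 1 at the last; that exact formula, the linear warmup shape, and the final learning rate of zero are assumptions, since the text gives only the momentum endpoints, the 30-epoch warmup, the 0.0002 learning rate, and the 400 epochs. Representing network parameters as a dictionary of floats is purely illustrative, and all function names are hypothetical.

```python
import math

def momentum_at(step, total_steps, base=0.996, final=1.0):
    """Cosine schedule for the EMA momentum lambda_m: base at step 0, final at the last step."""
    progress = step / max(total_steps - 1, 1)
    return final - (final - base) * (math.cos(math.pi * progress) + 1) / 2

def lr_at(epoch, total_epochs=400, base_lr=2e-4, warmup_epochs=30, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr (min_lr is an assumed value)."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs - 1, 1)
    return min_lr + (base_lr - min_lr) * (math.cos(math.pi * progress) + 1) / 2

def ema_update(teacher, student, lam):
    """theta_t <- lam * theta_t + (1 - lam) * theta_s, applied to every named parameter."""
    return {name: lam * teacher[name] + (1 - lam) * student[name] for name in teacher}

def total_loss(l_mip, l_mpd, l_itd, lam_mpd=0.1, lam_itd=0.1):
    """L_total = L_MIP + lambda_MPD * L_MPD + lambda_ITD * L_ITD, with the weights from the text."""
    return l_mip + lam_mpd * l_mpd + lam_itd * l_itd

# Hypothetical single-parameter "networks"
teacher = {"w": 1.0}
student = {"w": 0.0}
teacher = ema_update(teacher, student, momentum_at(0, 400))
```

Calling `momentum_at` once per optimization step and passing its value to `ema_update` captures the intent of the schedule: early in training the teacher follows the student relatively quickly, and as λ_m approaches 1 the teacher effectively stops changing.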
CT abdomen organ segmentation (Dataset I): The pre-trained networks were fine-tuned to generate volumetric segmentations of 13 different abdominal organs from contrast-enhanced CT (CECT) scans using the publicly available Beyond the Cranial Vault (BTCV) [32] dataset. Twenty-one randomly selected images were used for training and the remaining images for validation. Furthermore, results of blinded testing on 20 CECTs, evaluated on the grand challenge website, are also reported.

MRI upper abdominal organ segmentation (Dataset II): The SSL network was evaluated for segmenting abdominal organs at risk for pancreatic cancer radiation treatment, which included the stomach, small and large bowel, liver, and kidneys. No MRI or pancreatic cancer scans were used for SSL pre-training. Ninety-two 3D T2-weighted MRIs (TR/TE = 1300/87 ms, voxel size of 1×1×2 mm³, FOV of 400×450×250 mm³), acquired with a pneumatic compression belt to suppress breathing motion, were analyzed. Fine-tuning used five-fold cross-validation, and results from the validation folds not used in training are reported.

Experimental comparisons: SMIT was compared against representative SSL medical image analysis methods. Results of representative published methods on the BTCV testing set [30, 2, 3] are also reported. The SSL methods were chosen to evaluate the impact of the pretext task on segmentation accuracy and included (a) local texture and semantics modeling using Models Genesis [7], (b) jigsaw puzzles [10], and (c) contrastive learning [16], all implemented on a CNN backbone, as well as (d) self-distillation using whole image reconstruction [24], (e) masked patch reconstruction without self-distillation [18], and (f) MIM using self-distillation [19], all implemented on a SWIN transformer backbone. Random initialization results are shown for benchmarking purposes using both CNN and SWIN backbones. Identical training and testing sets were used, with hyper-parameters adopted from the default implementations.

CT segmentation accuracy: As shown in Table 1, our method SMIT outperformed representative published methods, including transformer-based segmentation [31, 3, 2]. SMIT was also more accurate than all evaluated SSL methods (Table 2) for most organs. Preservational contrastive representation learning (PCRL) [16] was more accurate than SMIT for the gall bladder (0.797 vs. 0.787). SMIT was more accurate than self-distillation with MIM [19] (average DSC of 0.848 vs. 0.833) as well as masked image reconstruction without distillation [18] (0.848 vs. 0.830). Fig. 2 shows a representative case with multiple organ segmentations produced by the various methods. SMIT was the most accurate method, including for organs with highly variable appearance and size such as the stomach and esophagus.

MRI segmentation accuracy: SMIT was more accurate than all other SSL-based methods for all evaluated organs, including the stomach and bowels, which depict highly variable appearance and sizes (Table 2). SMIT was least accurate for the small bowel compared with other organs, although its small bowel accuracy was still higher than that of all other methods. Fig. 2 shows a representative case with multiple organ segmentations produced by the various methods.

Ablation experiments: All ablation and design experiments (1-layer decoder vs. multi-layer, or ML, decoder) were performed using the BTCV dataset with the same SWIN backbone as used for SMIT. The ML decoder was implemented with five transpose convolution layers for up-sampling back to the input image resolution. Fig. 4 shows accuracy comparisons of networks pre-trained with different tasks, including full image reconstruction, contrastive losses, pseudo labels [33], and various combinations of the losses ($L_{MIP}$, $L_{MPD}$, $L_{ITD}$). As shown, the accuracies of all methods were similar for large organs with good contrast, namely the liver, spleen, and left and right kidneys (Fig. 4(I)).
On the other hand, organs with low soft-tissue contrast and high variability (Fig. 4(II)) and small organs (Fig. 4(III)) showed larger differences in accuracy between methods, with SMIT achieving more accurate segmentations. Major blood vessels (Fig. 4(IV)) also showed segmentation accuracy differences across methods, albeit smaller than for small organs and organs with low soft-tissue contrast. Importantly, both full image reconstruction and multi-layer decoder based MIP (ML-MIP) were less accurate than SMIT, which uses masked image prediction with a 1-layer linear projection decoder (Fig. 4(II, III, IV)). MPD was the least accurate for organs with low soft-tissue contrast and high variability (Fig. 4(II)), which was improved slightly by adding global image distillation (ITD). MIP alone (using the 1-layer decoder) was similarly accurate to SMIT and more accurate than the other pretext task based segmentations, including ITD [24] and MPD+ITD [19].

Impact of pretext tasks on sample size for fine-tuning: SMIT was more accurate than all other SSL methods irrespective of the sample size used for fine-tuning (Fig. 3(a)) and achieved faster convergence (Fig. 3(c)). It outperformed iBOT [19], which uses MPD and ITD, indicating the effectiveness of MIP for SSL.

Impact of mask ratio on accuracy: As shown in Fig. 3, increasing the mask ratio initially increased accuracy, which then stabilized. Image reconstruction error increased slightly with increasing mask ratio. Fig. 5 shows representative CT and MRI reconstructions produced using the default and multi-layer decoders, wherein our method was more accurate even in highly textured portions of the images containing multiple organs (additional examples are shown in Supplementary Fig. 1). Quantitative comparisons showed that our method was more accurate than the multi-layer decoder for both CT (MSE of 0.061 vs. 0.32; N=10 cases) and MRI (MSE of 0.062 vs. 0.34; N=92 cases).

In this work, we demonstrated the potential of SSL with 3D transformers for medical image segmentation.
Our approach, which leverages CT volumes arising from highly disparate body locations and diseases, showed the feasibility of producing robustly accurate segmentations from CT and MRI scans and surpassed multiple current SSL-based methods, especially for hard-to-segment organs with high appearance variability and small sizes. Our masked image dense prediction pretext task improved the ability of self-distillation using MIM to segment a variety of organs from CT and MRI, with a lower requirement on fine-tuning dataset size. Our method shows feasibility for medical image segmentation.

References:
[1] An image is worth 16x16 words: Transformers for image recognition at scale
[2] CoTr: Efficiently bridging CNN and transformer for 3D medical image segmentation
[3] UNETR: Transformers for 3D medical image segmentation
[4] Unsupervised learning of visual representations by solving jigsaw puzzles
[5] Unsupervised representation learning by predicting image rotations
[6] Momentum contrast for unsupervised visual representation learning
[7] Models Genesis
[8] Learning semantics-enriched representation via self-discovery, self-classification, and self-restoration
[9] 3D self-supervised methods for medical imaging
[10] Rubik's cube+: A self-supervised feature learning framework for 3D medical image analysis
[11] Medical transformer: Universal brain encoder for 3D MRI analysis
[12] Self-supervised learning for medical image analysis using image context restoration
[13] Unsupervised representation learning meets pseudo-label supervised self-distillation: A new approach to rare disease classification
[14] Contrastive learning of global and local features for medical image segmentation with limited annotations
[15] Parts2Whole: Self-supervised contrastive learning via reconstruction
[16] Preservational learning improves self-supervised medical image models by reconstructing diverse contexts
[17] MST: Masked self-supervised transformer for visual representation
[18] SimMIM: A simple framework for masked image modeling
[19] Image BERT pre-training with online tokenizer
[20] BEiT: BERT pre-training of image transformers
[21] Masked autoencoders are scalable vision learners
[22] Towards cross-modality medical image segmentation with online mutual knowledge distillation
[23] Unpaired cross-modality educed distillation (CMEDL) for medical image segmentation
[24] Emerging properties in self-supervised vision transformers
[25] Swin transformer: Hierarchical vision transformer using shifted windows
[26] Data from NSCLC-Radiomics. The Cancer Imaging Archive
[27] Radiology data from The Cancer Genome Atlas kidney renal clear cell carcinoma [TCGA-KIRC] collection. The Cancer Imaging Archive
[28] Artificial intelligence for the detection of COVID-19 pneumonia on chest CT using multinational datasets
[29] Encoder-decoder with atrous separable convolution for semantic image segmentation
[30] nnU-Net: A self-configuring method for deep learning-based biomedical image segmentation
[31] TransUNet: Transformers make strong encoders for medical image segmentation
[32] MICCAI multi-atlas labeling beyond the cranial vault: Workshop and challenge
[33] An empirical study of training self-supervised vision transformers
[34] Revisiting Rubik's cube: Self-supervised learning with volume-wise transformation for 3D medical image segmentation