key: cord-0140580-id35gp12 authors: Motamed, Saman; Khalvati, Farzad title: Multi-class Generative Adversarial Nets for Semi-supervised Image Classification date: 2021-02-13 journal: nan DOI: nan sha: c1ab5a7758c23a67d11eb5579cf057484af2d82f doc_id: 140580 cord_uid: id35gp12

From generating never-before-seen images to domain adaptation, applications of Generative Adversarial Networks (GANs) span a wide range of vision and graphics problems. Given the remarkable ability of GANs to learn the distribution of a particular class and generate images of it, they can be used for semi-supervised classification tasks. However, if two classes of images share similar characteristics, the GAN might learn to generalize across them and hinder the classification of the two classes. In this paper, we use various images from the MNIST and Fashion-MNIST datasets to illustrate how similar images cause the GAN to generalize, leading to poor classification. We propose a modification to the traditional training of GANs that improves multi-class classification of similar classes of images in a semi-supervised learning framework.

Generative Adversarial Networks [1] are among the most exciting inventions in machine learning of the past decade. While applications of GANs span the field of computer vision, image classification using GANs remains relatively unexplored. One of the early uses of GANs in image classification was detecting anomalies in images, first introduced by Schlegl et al. [2] to detect and identify anomalies, in the form of retinal fluid or hyper-reflective foci, in optical coherence tomography (OCT) images of the retina. By defining a variation score V(x) (eq. 2), their Anomaly Detection GAN (AnoGAN) captured the characteristic and visual differences between two images: one generated by the GAN and one real image. The idea was, for instance, to train the GAN only on healthy images.
Once the GAN is trained, its generator (G) can generate images similar to those in the healthy class. During the test phase, V(x) should be low if the test image is healthy, because G can generate an image similar to it. If the test image is not healthy and contains anomalies, V(x) is larger, because the generated image looks visually different from the real image containing the anomalies. Images of the unknown class do not have labels, and hence we do not use any of them to train our model. The first step of multi-class classification is to distinguish images of the unknown class from images of the known classes (e.g., COVID-19 positive from negative images). The second step is to classify the images that fall into the known classes using one of the GANs trained on a known class (e.g., normal or pneumonia). Note that we explore multi-class classification in a setting with two known classes (C1 and C2) and one unknown class (C3) of images; the approach can be extended to more known classes and an unknown class.

We observed that, in some instances, training a GAN on images of a class C1 produced low variation scores not only for test images of the same class but also for test images of class C2, hindering the classification of C1 vs. C2. We hypothesized that this happens because the generator G, trained on C1 images, generalizes enough to generate images that look visually similar to C2 images. In this work, we carried out multiple experiments on different datasets to understand how visually similar images affect the performance of GAN-based image classification. We propose MCGAN, a GAN-based multi-class classifier, to overcome the challenge of classifying visually similar images using GANs.
By using all labeled images to train each GAN, we prevent G from generalizing to the point where it can generate images similar to those of other classes.

A GAN is a deep learning model composed of two main parts: a Generator (G) and a Discriminator (D). G can be seen as an art forger that tries to reproduce artwork and pass it off as the original. D, on the other hand, acts as an art authentication expert that tries to tell real art apart from forged art. Training a GAN is a contest between G and D; if it succeeds, G generates realistic images and D can no longer tell G's generated images apart from real images. G takes as input a random Gaussian noise vector z and generates images through transposed convolution operations. D is trained to distinguish the real images (x) from the generated fake images (G(z)). The optimization of D and G can be viewed as a two-player minimax game [1] with the value function V(G, D). During training, G is trained to minimize D's ability to distinguish between real and generated images, while D tries to maximize the probability of assigning the "real" label to real training images and the "fake" label to images generated by G. As a result, G improves at generating realistic images while D gets better at telling real and generated images apart. Today, the term GAN typically refers to the Deep Convolutional GAN (DCGAN) [4] architecture.

The goal of the proposed GAN-based multi-class classifier (MCGAN) is to distinguish three classes of data (C1, C2, C3) from one another, where only two classes (C1 and C2) are labeled and the third class (C3) is unknown and unlabeled. To do so, we first classify C1 and C2 vs. C3, and then classify C1 vs. C2. Since we have labels only for C1 and C2, the first step is to learn the distribution of these two classes using two GANs.
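For reference, the minimax game with value function V(G, D) described above is shown below in the standard form given by Goodfellow et al. [1]. Here p_data denotes the distribution of real images and p_z the Gaussian noise prior from which z is drawn. The notation is restated from [1] rather than recovered from this paper.

```latex
\min_{G} \max_{D} V(G, D) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_{z}(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

D maximizes V by pushing D(x) toward 1 for real images and D(G(z)) toward 0 for generated ones. G only appears in the second term and minimizes it by pushing D(G(z)) toward 1.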
Here, we describe the architecture for learning the distribution of C1 images and then classifying images of C1 vs. C2. The same procedure can be repeated to learn the distribution of C2 images and classify C2 vs. C1. To learn the distribution of C1 images, the discriminator of a traditional GAN (DCGAN, AnoGAN, etc.) takes two inputs: the generator's output (labeled Fake) and real C1 images (labeled Real). This forces the generator to learn the distribution of images from the real class. If the images of C1 and C2 share similar characteristics, training the GAN on C1 images could cause G to generalize well enough to generate images similar to C2, hindering the classification of the two known classes (C1 vs. C2). To overcome this challenge, we feed a third input to the discriminator: images of C2. Although these are real images from C2, we label them as Fake. This prevents the generator from generalizing to the similar class (C2) while it learns the characteristics of C1. When G generates an image that could pass as belonging to C2, the discriminator flags it as fake, and G adjusts its learning in those instances. Figure 1 shows the architecture of the Multi-class GAN (MCGAN). For imbalanced datasets, where C1 and C2 do not contain the same number of images, we use batches of randomly selected C2 images at each training iteration.

The variation score V(x) for a query image x, proposed by Schlegl et al. [2], is a weighted sum of a residual loss and a discriminator loss. Before V(x) can be calculated at test time, a latent point z_i has to be found through back-propagation such that G(z_i) is as similar as possible to x. The loss function used to find z_i is likewise based on the residual and discriminator losses. λ sets the weighting of the two terms in both the overall loss and the variation score. We used λ = 0.2 for both our proposed MCGAN and AnoGAN [2]. Both architectures were trained with the same initial conditions so that their performance could be compared.
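The two mechanisms in this paragraph, the MCGAN discriminator labeling and the variation score with its latent search, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation:

- Images are hypothetical flat lists of floats.
- The number of C2 images per batch is assumed equal to the number of C1 images.
- The score follows the AnoGAN formulation of [2]: V(x) = (1 - λ)·R(x) + λ·D_f(x). In this formula, R(x) = Σ|x - G(z)| is the residual loss, and D_f(x) = Σ|f(x) - f(G(z))| is the discriminator loss computed on intermediate discriminator features f.
- The search for z uses finite-difference gradient descent as a dependency-free stand-in for back-propagation.
- The helper names `mcgan_discriminator_batch`, `variation_score` and `find_latent` are hypothetical.

```python
import random
from typing import Callable, List, Sequence, Tuple

Image = List[float]   # hypothetical: a flattened gray-scale image
Latent = List[float]  # hypothetical: a latent vector z


def mcgan_discriminator_batch(
    real_c1: Sequence[Image],
    generated: Sequence[Image],
    c2_pool: Sequence[Image],
    rng: random.Random,
) -> List[Tuple[Image, int]]:
    """One MCGAN discriminator batch for a GAN learning class C1.

    Label 1 = Real, label 0 = Fake. Real C1 images are Real; generated
    images and *real* C2 images are both labeled Fake.
    """
    # Assumption: draw as many C2 images as C1 images, resampled every
    # iteration, so a larger (imbalanced) C2 pool is covered over time.
    k = min(len(real_c1), len(c2_pool))
    c2_batch = rng.sample(list(c2_pool), k)
    return ([(x, 1) for x in real_c1]
            + [(x, 0) for x in generated]
            + [(x, 0) for x in c2_batch])


def variation_score(x: Image, g_z: Image,
                    features: Callable[[Image], List[float]],
                    lam: float = 0.2) -> float:
    """AnoGAN-style score: (1 - lam) * residual loss + lam * discriminator loss."""
    residual = sum(abs(a - b) for a, b in zip(x, g_z))
    disc = sum(abs(a - b) for a, b in zip(features(x), features(g_z)))
    return (1 - lam) * residual + lam * disc


def find_latent(x: Image,
                generator: Callable[[Latent], Image],
                features: Callable[[Image], List[float]],
                z0: Latent, lam: float = 0.2, lr: float = 0.05,
                steps: int = 200, eps: float = 1e-4) -> Latent:
    """Search for z such that G(z) is close to x by minimizing the weighted loss.

    Forward finite differences stand in for back-propagation through G and D.
    """
    z = list(z0)
    for _ in range(steps):
        base = variation_score(x, generator(z), features, lam)
        grad = []
        for i in range(len(z)):
            z_step = list(z)
            z_step[i] += eps
            grad.append(
                (variation_score(x, generator(z_step), features, lam) - base) / eps)
        z = [zi - lr * gi for zi, gi in zip(z, grad)]
    return z
```

As a toy check, use an identity function for both G and f. Starting from z = [0.0, 0.0], `find_latent` then moves z toward x = [0.5, -0.3] and drives V(x) from 0.8 to nearly zero. With a real DCGAN, G and f would be the trained networks, and automatic differentiation would replace the finite differences.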
We used images from two datasets. The MNIST [5] dataset contains 60,000 training images and 10,000 test images of handwritten digits. Fashion-MNIST [6] is a dataset of Zalando's article images, consisting of a training set of 60,000 examples and a test set of 10,000 examples. In experiments where the test sets are not balanced, we randomly select images from the larger test set of class b, matching the number of images in the smaller test set of class a. All gray-scale images were resized to 64 × 64 pixels, with pixel intensities scaled to the range [-1, 1].

To pick a subset of similar classes from the MNIST and Fashion-MNIST (F-MNIST) datasets that could cause generalization in GANs, we used metric learning [7]. The goal of metric learning is to train models that embed inputs into a high-dimensional space such that "similar" inputs are located close to each other. To bring images of the same class closer together in the embedding, we built the training data from randomly selected pairs of images of each class, matched to the label of that class. This replaces traditional (X, y) pairs, in which y is the label of a single image X. We embedded the images using a shallow three-layer CNN and computed the similarity within each pair as the cosine similarity of the two embeddings. We used these similarities as logits for a softmax, which moves pairs of images of the same class closer together. After training was complete, we sampled 10 examples from each of the 10 classes and used their near neighbours as a form of prediction: we checked whether each example and its near neighbours share the same class. Figure 3 visualizes the result as a confusion matrix. The numbers on the diagonal represent correct classifications. The numbers off the diagonal count neighbours whose class differs from the true label.
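The pair-based objective and the near-neighbour evaluation described above can be sketched in plain Python. This is a minimal sketch under assumptions, not the authors' code:

- The three-layer CNN is replaced by precomputed embedding vectors.
- `anchors[i]` and `positives[i]` are hypothetical embeddings of two different images of class i.
- The optional `temperature` divisor is an assumption that the text does not mention.

Each anchor's cosine similarities to all positives act as softmax logits whose correct target is its own class. Minimizing the cross-entropy therefore pulls same-class pairs together. All function names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norms


def pair_softmax_loss(anchors: List[Vector], positives: List[Vector],
                      temperature: float = 1.0) -> float:
    """Mean cross-entropy where anchors[i] should match positives[i].

    Row i of the anchor-positive cosine-similarity matrix is used as the
    softmax logits, and column i is the correct target.
    """
    n = len(anchors)
    total = 0.0
    for i in range(n):
        logits = [cosine_similarity(anchors[i], positives[j]) / temperature
                  for j in range(n)]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_sum_exp - logits[i]  # -log softmax(logits)[i]
    return total / n


def neighbour_confusion(embeddings: List[Vector], labels: List[int],
                        num_classes: int, k: int = 1) -> List[List[int]]:
    """Row = true class, column = class of each of the k most similar examples."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for i, e in enumerate(embeddings):
        others = [j for j in range(len(embeddings)) if j != i]
        others.sort(key=lambda j: cosine_similarity(e, embeddings[j]), reverse=True)
        for j in others[:k]:
            matrix[labels[i]][labels[j]] += 1
    return matrix
```

With well-aligned pairs, for example anchors and positives both equal to [[1, 0], [0, 1]], the loss is log(1 + e) - 1. It is larger when the positives are swapped. Large off-diagonal counts in `neighbour_confusion` flag the class pairs that embed close together, which is how candidate pairs such as (9, 4) can be spotted.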
We intentionally used a shallow three-layer CNN to induce some misclassification, since CNNs easily achieve near-perfect classification on datasets such as MNIST. Using the information from figure 3, we picked the class pairs (9, 4) and (8, 3) from the MNIST dataset and (Coat, Shirt), (Coat, Pullover), and (Boot, Sandal) from the F-MNIST dataset. For semi-supervised multi-class classification, we have a pair of known (labeled) classes and an unknown class that we introduce only when testing our models. We trained two GANs, one to generate images of each known class. For instance, for the pair (9, 4), one DCGAN was trained on 9s and one on 4s (AnoGAN and DCGAN share the same architecture). Similarly, one MCGAN was trained with 9s labeled Real and 4s labeled Fake, and another MCGAN with 9s labeled Fake and 4s labeled Real. The models were trained on an NVIDIA GeForce RTX 2080 Ti with 11 GB of memory.

To classify images from the known classes C1 and C2 and the unknown class C3, we first separate the unknown images from the known images using a final variation score. This score is the sum of the variation scores from the two trained GANs (the GAN trained on C1 and the GAN trained on C2). After identifying the C3 images, we classify the remaining images as C1 or C2. For this, we have two options: 1) use the GAN trained on C1 images to classify the remaining images as C1 or C2, or 2) use the GAN trained on C2 images. To pick the better option for this second step, we look at the performance of supervised binary classification of the known classes C1 and C2. The GAN that better classifies C1 vs. C2 is used for the second step of multi-class classification in the semi-supervised setting, once images of class C3 are introduced. Figure 4 shows this process. We calculated variation scores for both DCGANs and MCGANs.
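The two-step procedure can be sketched as follows, assuming the variation scores of each test image under both GANs have already been computed. The thresholds `unknown_threshold` and `known_threshold` are hypothetical parameters, since the text does not state how decision thresholds were chosen. The flag `use_c1_gan` stands for the choice of whichever GAN classified C1 vs. C2 better in the supervised binary experiment.

```python
from typing import List, Sequence


def two_step_classify(scores_c1: Sequence[float], scores_c2: Sequence[float],
                      unknown_threshold: float, known_threshold: float,
                      use_c1_gan: bool) -> List[str]:
    """Label each test image as 'C1', 'C2', or 'C3' (unknown).

    scores_c1[i] / scores_c2[i]: variation score of image i under the GAN
    trained on C1 / C2 (lower = more similar to that GAN's class).
    """
    labels = []
    for v1, v2 in zip(scores_c1, scores_c2):
        # Step 1: a high combined score means neither GAN reconstructs the image well.
        if v1 + v2 > unknown_threshold:
            labels.append("C3")
        # Step 2: a single selected GAN separates the two known classes.
        elif use_c1_gan:
            labels.append("C1" if v1 <= known_threshold else "C2")
        else:
            labels.append("C2" if v2 <= known_threshold else "C1")
    return labels
```

Take `scores_c1 = [0.1, 0.9, 2.0]`, `scores_c2 = [0.8, 0.2, 2.5]`, `unknown_threshold = 2.0` and `known_threshold = 0.5`. With either setting of `use_c1_gan`, the function returns `["C1", "C2", "C3"]`.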
A lower variation score means that the test image is more likely to belong to the class the GAN was trained to generate, while a larger variation score means it is less likely. We calculated the area under the ROC curve (AUC) of each model. For semi-supervised multi-class classification, we added images of class 9 to (3, 8), 8 to (4, 9), Dress to (Coat, Pullover), Bag to (Boot, Sandal), and Pullover to (Coat, Shirt). We first calculated the accuracy of classifying the unknown class vs. the known classes. Then, using the results from table 1, we picked the GAN that performed better in classifying the known pair. We used it to label the remaining test images that were not categorized as the unknown class C3. Figure 4 shows the two-step process of multi-class classification using the two trained GANs. In step 1, the combined variation score from the two GANs trained on the known classes is higher for images of the unknown class than for images of the known classes. This allows the unknown class to be separated from the known classes. In step 2, either GAN can be used to classify the known classes; we picked the one with better performance according to table 1.

In semi-supervised classification, MCGAN outperformed DCGAN in classifying the unknown class in all but one experiment, (Boot, Sandal - Bag). There, the accuracy of detecting the unknown class (Bag) dropped from 80% to 79% with MCGAN, while the classification of the known classes improved from 77% to 81%. Overall, MCGAN enabled better classification of the known classes (Tables 1 and 2) while improving the classification of the unknown class in all but one experiment (Table 2). Sometimes we do not have enough labeled images for a class. In such settings, semi-supervised training that does not require images of that class is valuable.
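A lower variation score indicates membership in the GAN's training class, so the AUC can be computed by treating the negated score -V(x) as the decision value. The sketch below is minimal and assumes the rank-based (Mann-Whitney) formulation of AUC rather than any particular library. It estimates the probability that a randomly chosen in-class test image receives a lower score than a randomly chosen out-of-class one, counting ties as one half. The function name is hypothetical.

```python
from typing import Sequence


def auc_from_variation_scores(in_class: Sequence[float],
                              out_of_class: Sequence[float]) -> float:
    """ROC AUC when lower variation scores indicate the GAN's own class.

    Equivalent to the AUC of the decision value -V(x), computed as the
    fraction of (in-class, out-of-class) pairs that are ranked correctly.
    """
    correct = 0.0
    for a in in_class:
        for b in out_of_class:
            if a < b:
                correct += 1.0
            elif a == b:
                correct += 0.5  # ties count as half a correct ranking
    return correct / (len(in_class) * len(out_of_class))
```

The double loop costs O(n·m) comparisons. For large test sets, a sort-based rank computation gives the same value, as does a library routine such as scikit-learn's `roc_auc_score` applied to -V(x).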
While GANs can be used to classify images, we showed that when labeled classes share similar characteristics, the generalization ability of GANs can hinder classification performance. Using images from the MNIST and Fashion-MNIST datasets, we showed, for instance, how a GAN trained to generate images of handwritten digit 8 can also generate images similar to digit 3. When such a GAN is used to classify 8s vs. 3s, this generalization results in low variation scores not only for images of digit 8 but also for images of digit 3. We proposed MCGAN, which uses both classes to train the GAN's discriminator. By labeling the 3s as fake, we steered the generator away from generating images that could pass as 3s while it learned to generate images of class 8. This improved multi-class classification, both of the unknown class vs. the known classes and of the known classes vs. one another. However, as table 2 shows for (Coat, Shirt - Pullover) and (3, 8 - 9), MCGAN still does not perform well on multi-class classification, although it works better than DCGAN. When the unknown class (Pullover, 9) shares similar characteristics with some or all of the known classes, classification performance suffers, even though MCGAN improves on DCGAN. The goal of this study was not to achieve state-of-the-art classification results on the two datasets. Rather, we used a simple GAN architecture to show how the proposed modification to discriminator training can improve classification in settings where over-generalization is possible. Combined with more complex GAN architectures, such as RANDGAN [3] for detecting COVID-19 in chest X-rays, this modification could further improve model accuracy.

In this work, we demonstrated how GANs can learn to generalize to other classes of images when those classes share similar characteristics with the class of training images.
This generalization can hinder the ability of GANs to perform image classification. We proposed using all labeled images to train the discriminator in order to penalize this generalization. This multi-class discriminator training improved the accuracy of semi-supervised image classification.

[1] Generative adversarial nets.
[2] Unsupervised anomaly detection with generative adversarial networks to guide marker discovery.
[3] RANDGAN: Randomized generative adversarial network for detection of COVID-19 in chest X-ray.
[4] Unsupervised representation learning with deep convolutional generative adversarial networks.
[5] MNIST handwritten digit database.
[6] Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms.
[7] Metric learning: A survey. Foundations and Trends in Machine Learning.

This research was funded by the Chair in Medical Imaging and Artificial Intelligence, a joint Hospital-University Chair between the University of Toronto, The Hospital for Sick Children, and the SickKids Foundation.