title: Adaptive and Dynamic Knowledge Transfer in Multi-task Learning with Attention Networks
authors: Ma, Tao; Tan, Ying
date: 2020-07-11
journal: Data Mining and Big Data
DOI: 10.1007/978-981-15-7205-0_1

Multi-task learning has shown promising results in many applications of machine learning: given several related tasks, it aims to generalize better on the original tasks by leveraging the knowledge shared among tasks. The knowledge transfer mainly depends on the task relationships. Most existing multi-task learning methods guide the learning process with predefined task relationships. However, the associated relationships have not been fully exploited in these methods. Replacing predefined task relationships with adaptively learned ones may lead to superior performance, as it avoids the misguidance of improper pre-definitions. Therefore, in this paper, we propose Task Relation Attention Networks to adaptively model the task relationships and dynamically control the positive and negative knowledge transfer for different samples in multi-task learning. To evaluate the effectiveness of the proposed method, experiments on various datasets are conducted. The experimental results demonstrate that the proposed method outperforms both classical and state-of-the-art multi-task learning baselines.

Multi-task learning (MTL) aims to generalize better on the original tasks by leveraging the shared knowledge among tasks [15]. In MTL, positive knowledge transfer leads to improved performance because: 1) more related information is incorporated into the target tasks, which benefits the feature learning process and yields better feature representations; 2) the incorporated information of positively related tasks acts as a regularizer that reduces the risk of over-fitting. The knowledge transfer mainly depends on the task relationships. Therefore, how to appropriately model task relationships and how to control the knowledge transfer among tasks are crucial in MTL.

With the recent advances in deep learning, MTL with deep neural networks has been used widely and successfully across many applications of machine learning, from natural language processing [7] to computer vision [11]. In most of these existing methods, the task relationships guiding the learning process are predefined, and the knowledge transfer among tasks relies on the sharing of hidden layers. Although these methods have achieved promising results, there are still several challenges for further improving the performance of MTL methods.

The first challenge is adaptively learning the task relationships instead of relying on predefined relationships. For multi-task learning problems with complex task associations, pre-definition based on limited human knowledge requires costly human effort. Moreover, without adequate effort spent on sophisticated pre-definitions, negative knowledge transfer is likely to occur because of the misguidance of improperly predefined task relationships [8, 16]. Since an improper pre-definition of task relationships may result in negative knowledge transfer, it is essential for MTL methods to adaptively and appropriately model the task relationships with learning-based modules.

The second challenge is controlling the knowledge transfer with dynamically learned task relationships instead of fixed ones.
The relationships among tasks are not constantly fixed, but vary slightly across different samples. However, in most of the methods relying on pre-definition, the task relationships are fixed [9], even in some MTL methods equipped with adaptive learning modules [14]. As a concrete example, in the work of [21], the task relationships are determined by the inputs X and outputs Y of the training data. In the testing phase, the model directly performs predictions based on the learned relationships and the inputs $X_{\text{test}}$, i.e., the relationships are fixed for the testing data. However, the task relationships in different samples are not necessarily consistent. Therefore, dynamically modeling these relationships based on different inputs may lead to superior performance.

To address these challenges, in this paper, we propose Task Relation Attention Networks (TRAN) to adaptively capture the task relationships and dynamically control the knowledge transfer in MTL. Specifically, TRAN is an attention-based model that adaptively captures the task relationships via a task correlation matrix computed from the inputs and specifies the shared feature representations for different tasks. Since the task relationships are adaptively learned by TRAN during the learning process, they replace the predefined relationships. Moreover, because TRAN relies on the inputs, it can dynamically model the correlations for different samples. To evaluate the effectiveness of the proposed method, experiments on various datasets are conducted. The experimental results demonstrate that the proposed method can outperform both classical and state-of-the-art MTL baselines.

The contributions of this paper can be summarized as follows:
- This paper proposes Task Relation Attention Networks (TRAN) to adaptively learn the task relationships, replacing the predefined ones.
- The proposed method can dynamically control the knowledge transfer in multi-task learning based on the adaptively learned task relationships.
- This paper provides an explicable learning-based framework for multi-task learning that learns the shared feature representations for different tasks.

Multi-task Learning (MTL) provides an efficient framework for leveraging task relationships to achieve knowledge transfer and improved performance. The typical MTL architecture, proposed by [3], shares the bottom layers across all tasks and splits the top layers for different tasks. Afterwards, there have been attempts to design partly shared architectures between tasks, instead of only sharing the bottom hidden layers. There are several recent examples of leveraging task relationships in MTL. The cross-stitch networks [14] learn an optimal combination of task-specific representations for each task using linear units. The tensor factorization model [20] generates the hidden layer parameters for each task. The multi-gate mixture-of-experts model [13] uses a gating mechanism to capture task differences and implicitly model the task relationships. The work of [21] applied attention networks to statically capture the task relationships. Compared to the typical MTL methods, these works perform better feature learning for shared and task-specific representations and achieve better performance.
However, there are still some limitations: the method of [14] can hardly scale to a large number of tasks; the method of [20] relies on a specific application scenario; the method of [13] applies linear gates to implicitly model the task relationships and its efficiency degrades as the number of experts increases; the method of [21] captures the task relationships only statically.

The key idea of the attention mechanism is mainly inspired by human visual attention, and it has been successfully applied in various applications, such as neural machine translation [12] and text summarization [1]. Graph attention networks [18] were proposed for feature learning on graph-structured data, based on the self-attention mechanism. The work of [10] applied self-attention networks to time series warping. The Transformer [17] performs self-attention for sentence encoding, dispensing with recurrence and convolutions entirely. Based on the Transformer, a language representation model called BERT was proposed [5]. In this paper, we attempt to adaptively model the task relationships in MTL with self-attention networks. Based on the learned relationships, the knowledge transfer among tasks can be dynamically controlled and each task can obtain a better shared feature representation, leading the MTL method to better performance.

Given a single task, which can be either regression or classification, the goal is to learn a model $M(\cdot)$ that maps the inputs $X$ to predictions $\hat{Y}$ such that the mathematical expectation $E(\cdot)$ of the loss between $\hat{Y}$ and the ground-truth values $Y$ is minimized. In MTL, assuming that there are $K$ tasks, a multi-task learning method $\mathrm{MTL}(\cdot)$ jointly produces the predictions of all tasks from their inputs, where $X_i$, $Y_i$ and $\hat{Y}_i$ are respectively the inputs, labels and predictions of each task. In some real-world scenarios, the inputs of different tasks can be the same, i.e., $X_1 = X_2 = \cdots = X_K = X_s$.

We apply attention networks to model the task relationships and help shared feature learning, called Task Relation Attention Networks (TRAN). Given $K$ tasks, the inputs are a set of task features $\{x_1, x_2, \ldots, x_K\}$, $x_i \in \mathbb{R}^{n}$, where $n$ is the dimensionality of the features. Given tasks $i$ and $j$, a shared attention network $a_t$ measures the attention correlation $e_{ij}$ between the two tasks:

$$e_{ij} = a_t\big(W_i x_i \,\|\, W_j x_j\big),$$

where $W_i \in \mathbb{R}^{n \times d}$ and $W_j \in \mathbb{R}^{n \times d}$ represent the encoding networks that map the original inputs into high-level latent representations of dimensionality $d$, and $\|$ is the concatenation operation. The attention weights for task $i$ are normalized by the softmax function to obtain the associated relationships $(\alpha_{i1}, \alpha_{i2}, \ldots, \alpha_{iK})$ between the other tasks and task $i$:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{K} \exp(e_{ik})}.$$

The attention networks $a_t$ are implemented with fully-connected neural networks with the LeakyReLU activation function, so the learning process of the attention networks can be described as:

$$\alpha_{ij} = \frac{\exp\big(\mathrm{LeakyReLU}\big(a_t[\,W_i x_i \,\|\, W_j x_j\,]\big)\big)}{\sum_{k=1}^{K} \exp\big(\mathrm{LeakyReLU}\big(a_t[\,W_i x_i \,\|\, W_k x_k\,]\big)\big)}.$$

The learned attention weights for target task $i$ reflect the correlations between the other tasks and task $i$, and all attention weights compose the task correlation matrix $A$. The task-specific representation $s_i$ for task $i$ is the combination of all task latent representations weighted by their attention weights for task $i$:

$$s_i = \sum_{j=1}^{K} \alpha_{ij}\, W_j x_j.$$

We further perform a multi-head attention mechanism on TRAN, which allows $H$ independent attention networks to learn the attention weights in parallel and applies a linear transformation $W_H = (w_1, \ldots, w_H)$ to combine them. The final task-specific representation for task $i$ is:

$$s_i = W_H\big[\, s_i^{1} \,\|\, s_i^{2} \,\|\, \cdots \,\|\, s_i^{H} \,\big] = \sum_{h=1}^{H} w_h\, s_i^{h},$$

where $s_i^{h}$ is the representation produced by the $h$-th attention head. The illustration of feature learning is presented in Fig. 1.
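To make the feature-learning step concrete, below is a minimal PyTorch sketch of one attention head of the task-relation attention described above. It is our own illustration under this reading of the formulas, not the authors' released implementation; the class name, tensor layout, and implementation details (e.g., single linear layers for the encoders and for $a_t$) are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TaskRelationAttention(nn.Module):
    """Sketch of one TRAN attention head: per-task encoders W_i, a shared
    attention network a_t, softmax-normalised task correlations, and the
    attention-weighted task-specific representations s_i."""

    def __init__(self, num_tasks, in_dim, latent_dim):
        super().__init__()
        # One encoding network W_i per task (original inputs -> d-dim latents).
        self.encoders = nn.ModuleList(
            [nn.Linear(in_dim, latent_dim) for _ in range(num_tasks)]
        )
        # Shared attention network a_t over the concatenated pair of latents.
        self.attn = nn.Linear(2 * latent_dim, 1)
        self.num_tasks = num_tasks

    def forward(self, inputs):
        # inputs: list of K tensors, each of shape (batch, in_dim).
        latents = [enc(x) for enc, x in zip(self.encoders, inputs)]   # K x (B, d)
        h = torch.stack(latents, dim=1)                               # (B, K, d)

        B, K, d = h.shape
        # Build all (i, j) pairs [W_i x_i || W_j x_j] and score them with a_t.
        h_i = h.unsqueeze(2).expand(B, K, K, d)
        h_j = h.unsqueeze(1).expand(B, K, K, d)
        e = F.leaky_relu(self.attn(torch.cat([h_i, h_j], dim=-1))).squeeze(-1)  # (B, K, K)

        # Row-wise softmax gives the per-sample task correlation matrix A.
        alpha = F.softmax(e, dim=-1)                                  # (B, K, K)

        # s_i = sum_j alpha_ij * (W_j x_j): shared representation for task i.
        s = torch.bmm(alpha, h)                                       # (B, K, d)
        return s, alpha
```

A multi-head variant would run H such modules in parallel and combine their task-specific representations with a learned linear transformation, as in the formulation above.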
For multi-task prediction, each task is equipped with a feed-forward sub-layer that converts the final task-specific representation into the predicted values. Each feed-forward sub-layer consists of two parts: the first, $W_e^i$, is a fully-connected neural network with ReLU activation and a skip-connection for embedding the final representation; the second, $W_p^i$, is a linear transformation for prediction. For classification tasks, a softmax function is additionally applied to the output to produce class probabilities.

For multi-task learning, all tasks are jointly trained by optimizing a joint loss function $L_{\text{joint}}$. Given the inputs $X = \{X_1, X_2, \ldots, X_K\}$ and labels $Y = \{Y_1, Y_2, \ldots, Y_K\}$, the joint loss function consists of three terms: the first is the combination of the task-specific losses $L_j$ weighted by their loss weights $\lambda_j$; the second is the regularization over all trainable parameters $W$; the third is the regularization over the learned attention correlation matrix $A$ to ensure the auto-correlations of tasks. For each task, the task-specific loss is the mean squared error (MSE) for regression tasks and the cross entropy for classification tasks.

The performance of the proposed method is evaluated on three datasets: the Census-income dataset, the FashionMnist dataset and the Sarcos dataset.
- Census-income dataset: The Census-income dataset is from the UCI Machine Learning Repository [2]. It is extracted from the 1994 Census database and contains 299,285 instances of demographic information about American adults. We construct two multi-task learning problems based on 40 features.
  • Group I: Task 1 predicts whether the income exceeds $50K; Task 2 predicts whether this person is never married.
  • Group II: Task 1 predicts whether the education level of this person is at least college; Task 2 predicts whether this person is never married.
- FashionMnist dataset: The samples in FashionMnist are 28 × 28 grayscale images with 10-class labels [19], similar to Mnist. We construct a multi-task learning problem: Task 1 is the original 10-class classification task; Task 2 is to predict whether the objects are shoes, female products, or another type. All tasks share the same inputs.
- Sarcos dataset: This is a regression dataset [4] where the goal is to predict the torque measured at each joint of a 7 degrees-of-freedom robotic arm, given the current state, velocity, and acceleration measured at each joint (7 torques from 21-dimensional inputs). Following the procedure of [4], we have 7 regression tasks, where each task is to predict one torque.
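To make the training objective concrete before turning to the baselines, the following is a minimal sketch of how the joint loss described above might be assembled for tasks like these. The regularization coefficients and the exact form of the correlation-matrix term are our own illustrative assumptions; the paper only describes these terms qualitatively.

```python
import torch

def joint_loss(task_losses, loss_weights, model, attn_matrix,
               l2_coef=1e-4, attn_coef=1e-3):
    """Illustrative joint objective: weighted task-specific losses plus
    regularization on the trainable parameters and on the learned task
    correlation matrix. The correlation-matrix regularizer below is an
    assumption; the paper only states that it encourages the tasks'
    auto-correlations."""
    # Weighted sum of the K task-specific losses (MSE or cross entropy).
    loss = sum(w * l for w, l in zip(loss_weights, task_losses))

    # L2 regularization over all trainable parameters W.
    loss = loss + l2_coef * sum(p.pow(2).sum() for p in model.parameters())

    # Encourage strong self-correlation: penalize deviation of the diagonal
    # of the batch-averaged attention matrix A from 1.
    diag = torch.diagonal(attn_matrix.mean(dim=0))
    loss = loss + attn_coef * (1.0 - diag).pow(2).sum()
    return loss
```

In each training step, the per-task losses computed from the feed-forward sub-layers would be passed to such a function before back-propagation.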
The baseline methods to be compared with are as follows:
- LASSO: This is the classic linear method, learning each task independently with L1-norm regularization.
- Shared-Bottom: This is the typical MTL architecture [3], in which all tasks share the bottom hidden layers and have top sub-layers for prediction. In this method, the task relationships are predefined and fixed.
- L2-Constrained MTL: This is a classical MTL method [6], where the parameters of different tasks are shared softly by an L2-constraint. Given two tasks, the prediction of each task can be described as $\hat{Y}_1 = f(X_1; \theta_1)$ and $\hat{Y}_2 = f(X_2; \theta_2)$, where $\theta_1$, $\theta_2$ are the parameters of each task. The objective function of multi-task learning is

  $$L = L_1(Y_1, \hat{Y}_1) + L_2(Y_2, \hat{Y}_2) + \alpha \lVert \theta_1 - \theta_2 \rVert_2^2,$$

  where $\alpha$ is a hyper-parameter. This method models the task relationships with the magnitude of $\alpha$.
- Cross-stitch Networks (CSN): This is a deep-learning-based MTL method [14]. Knowledge is shared between tasks by linear units called cross-stitch units. Given two tasks, $h_1$ and $h_2$ are the outputs of the previous hidden layers of each task, and the outputs of the cross-stitch unit are described as

  $$\begin{bmatrix} \tilde{h}_1 \\ \tilde{h}_2 \end{bmatrix} = \begin{bmatrix} \alpha_{11} & \alpha_{12} \\ \alpha_{21} & \alpha_{22} \end{bmatrix} \begin{bmatrix} h_1 \\ h_2 \end{bmatrix},$$

  where $\alpha_{ij}$, $i, j = 1, 2$ are trainable linear parameters representing the knowledge transfer.
- Multi-gate Mixture-of-Experts (MMoE): This method adapts the multi-gate mixture-of-experts structure to MTL [13]. A group of expert networks forms the bottom layers, and the top task-specific layers utilize the experts differently through a gating mechanism.
- Multiple Relational Attention Networks (MRAN): This recently proposed method [21] applies attention networks to model multiple types of relationships in MTL. However, the task relationships in this method are modeled statically: they are determined in the training phase and remain fixed in the testing phase. Our method dynamically captures the task relationships, and we discuss the differences between this method and ours in the analysis below.

The proposed method is Task Relation Attention Networks (TRAN), and MH-TRAN denotes the variant equipped with the multi-head mechanism. Note that, because the code and some datasets of the baseline MRAN have not been released yet, we only compare against its performance on the available Sarcos dataset, as reported in their paper.

Overall Comparison. The performance comparisons on the Census-income, FashionMnist and Sarcos datasets are presented in Tables 1, 2 and 3, respectively.

Task Relationships and Knowledge Transfer. We illustrate the task correlations learned by TRAN in Fig. 2. Overall, all tasks are strongly correlated with themselves, and for different target tasks, the contributions of the other tasks vary considerably, e.g., the relationships in the Sarcos dataset in Fig. 2(c). We can observe the differences between traditional methods and TRAN: in traditional methods, the task correlations are predefined and equal for each task, whereas TRAN captures their differences. In the Sarcos dataset, for task 7, the contribution of task 1 is noticeably smaller than that of the others. If a method relies on a pre-definition of equal task correlations, negative knowledge transfer may occur and hurt the performance. From the performance comparison, we observe that TRAN outperforms the traditional methods with pre-definition, which demonstrates the effectiveness of adaptively capturing the task correlations.

For the Census-income dataset, we have two multi-task learning problems, and the marital task appears in both group I and group II, accompanied by different tasks. As shown in the performance comparison in Table 1, the marital task of TRAN in group II performs better than the one in group I. From the illustration, we can also observe that the task correlations in group II are stronger than those in group I. This indicates that there is more positive knowledge transfer in group II, which contributes to the improved performance. To verify this observation, we assess the practical strength of the task relationships in groups I and II, because in general, stronger task correlations imply more positive knowledge transfer. Following [13], the Pearson correlation of the labels of different tasks can be used as a quantitative indicator of task relationships, because the Pearson correlation of labels is positively correlated with the strength of task relationships. The Pearson correlation in group I is 0.1784, and the one in group II is 0.2396. This indicates that there is supposed to be more positive knowledge transfer in group II, corresponding to our observation. This demonstrates that TRAN does capture the practical task correlations and controls the positive knowledge transfer to help improve the performance.
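The label-correlation check used above is straightforward to reproduce. The following sketch, with hypothetical label arrays and our own function name, computes the Pearson correlation between the binary labels of two tasks with NumPy.

```python
import numpy as np

def task_label_correlation(labels_a, labels_b):
    """Pearson correlation between the label vectors of two tasks,
    used as a quantitative indicator of task relatedness (cf. [13])."""
    labels_a = np.asarray(labels_a, dtype=float)
    labels_b = np.asarray(labels_b, dtype=float)
    # np.corrcoef returns the 2x2 correlation matrix; the off-diagonal
    # entry is the Pearson correlation between the two label vectors.
    return np.corrcoef(labels_a, labels_b)[0, 1]

# Example with hypothetical binary labels for two tasks (income, marital):
income = np.array([1, 0, 0, 1, 1, 0])
marital = np.array([1, 0, 1, 1, 0, 0])
print(task_label_correlation(income, marital))
```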
Dynamically Control the Knowledge Transfer. The task relationships are not fixed, but vary slightly across different samples, and we aim to dynamically capture them from different samples using TRAN. We randomly select 8 samples from the testing samples of the Sarcos dataset and provide an illustration of their dynamically learned task correlations in Fig. 3. From the correlations, we can observe a slight variation in the task relationships across different samples. This demonstrates that TRAN does capture the dynamic task relationships, and the performance comparison in Tables 1, 2 and 3 indicates that TRAN uses the dynamically learned relationships to control the knowledge transfer and improve the performance.

Fig. 3. Illustration of the dynamic task relationships on the Sarcos dataset. We randomly select 8 samples from the testing dataset and visualize their attention correlation matrices.

In this paper, we propose Task Relation Attention Networks to adaptively capture the task relationships, replacing the predefined ones in traditional MTL methods. Based on the learned relationships, the positive and negative knowledge transfer can be dynamically balanced for different samples. As a result, a better task-specific representation is obtained, leading to improved performance. In addition, the learned correlation matrix presents the dynamic transfer pattern, making the MTL method more explicable. To evaluate its performance, we conduct experiments on various datasets, including regression and classification tasks. Both classical and state-of-the-art MTL methods are employed to provide benchmarks. The experimental results and analyses demonstrate the effectiveness of our method and its advantages over the other methods.

References
[1] A convolutional attention network for extreme summarization of source code
[2] UCI machine learning repository
[3] Multitask learning
[4] Consistent multitask learning with nonlinear output relations
[5] BERT: pre-training of deep bidirectional transformers for language understanding
[6] Low resource dependency parsing: crosslingual parameter sharing in a neural network parser. In: Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing
[7] A joint many-task model: growing a neural network for multiple NLP tasks
[8] One model to learn them all
[9] Multi-task representation learning for travel time estimation
[10] Temporal transformer networks: joint learning of invariant and discriminative time warping
[11] Fully-adaptive feature sharing in multi-task networks with applications in person attribute classification
[12] Effective approaches to attention-based neural machine translation
[13] Modeling task relationships in multi-task learning with multi-gate mixture-of-experts
[14] Cross-stitch networks for multitask learning
[15] An overview of multi-task learning in deep neural networks
[16] Transfer learning
[17] Attention is all you need
[18] Graph attention networks
[19] Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms
[20] Deep multi-task representation learning: a tensor factorisation approach
[21] Multiple relational attention network for multi-task learning