More Is More - Narrowing the Generalization Gap by Adding Classification Heads

Roee Cates, The Hebrew University of Jerusalem, roee.cates@mail.huji.ac.il
Daphna Weinshall, The Hebrew University of Jerusalem, daphna@cs.huji.ac.il
arXiv:2102.04924v2 [cs.LG] 11 Feb 2021

Abstract

Overfitting is a fundamental problem in machine learning in general, and in deep learning in particular. To reduce overfitting and improve generalization in image classification, some methods employ invariance to a group of transformations, such as rotations and reflections. However, since not all objects necessarily exhibit the same invariance, it seems desirable to let the network learn the useful level of invariance from the data. To this end, motivated by self-supervision, we introduce an architecture enhancement for existing neural network models based on input transformations, termed 'TransNet', together with a training algorithm suited to it. Our model can be employed during training only and then pruned for prediction, resulting in an architecture equivalent to the base model. Thus pruned, we show that our model improves performance on various data-sets while exhibiting improved generalization. This is achieved in turn by enforcing soft invariance on the convolutional kernels of the last layer in the base model. Theoretical analysis is provided to support the proposed method.

Figure 1: Illustration of the TransNet architecture, which consists of 2 heads associated with 2 transformations, the identity and rotation by 90°. Each head classifies images transformed by its associated transformation, while both heads share the same backbone.

1. Introduction

Deep neural network models currently define the state of the art in many computer vision tasks, as well as in speech recognition and other areas. These expressive models are able to model complicated input-output relations. At the same time, models of such large capacity are often prone to overfitting, i.e. performing significantly better on the training set than on the test set. This phenomenon is also called the generalization gap.

We propose a method to narrow this generalization gap. Our model, called TransNet, is defined by a set of input transformations. It augments an existing Convolutional Neural Network (CNN) architecture by allocating a specific head for each input transformation (see Fig. 1). Each head is a fully-connected layer that receives the penultimate layer of the base CNN as input. The transformations associated with the model's heads are not restricted a priori.

The idea behind the proposed architecture is that each head can specialize in a different yet related classification task. We note that any CNN model can be viewed as a special case of the TransNet model, consisting of a single head associated with the identity transformation. The overall task is typically harder when training TransNet than when training the base CNN architecture. Yet by training multiple heads that share the convolutional backbone, we hope to reduce the model's overfitting by providing a form of regularization.

In Section 3 we define the basic model and the training algorithm designed to train it (see Alg. 1). We then discuss the types of transformations that can be useful when learning to classify images. We also discuss the model's variations: (i) the pruned version, which employs multiple heads during training and then keeps only the head associated with the identity transformation for prediction; (ii) the full version, where all heads are used in both training and prediction.

Theoretical investigation of this model is provided in Section 4, using the dihedral group of transformations ($D_4$) that includes rotations by multiples of 90° and reflections. We first prove
that under certain mild assumptions, instead of applying each dihedral transformation to the input, one can compile it into the CNN model's weights by applying the inverse transformation to the convolutional kernels. To obtain intuition about the inductive bias of the model's training algorithm in complex realistic frameworks, we then analyze this inductive bias in a simplified framework.

In Section 5 we describe our empirical results. We first introduce a novel invariance score (IS), designed to measure the invariance of the model's kernels under a given group of transformations. IS effectively measures the inductive bias imposed on the model's weights by the training algorithm. To achieve a fair comparison, we compare a regular CNN model, trained traditionally, to the same model trained like a TransNet model as follows: heads are added to the base model, it is trained as a TransNet model, and then the extra heads are pruned. We then show that training as TransNet improves test accuracy compared to the base model. This improvement was achieved while keeping the optimized hyper-parameters of the base CNN model, suggesting that further improvement by fine tuning may be possible. We also demonstrate the increased invariance of the model's kernels when trained with TransNet.

Our Contribution

   • Introduce TransNet, a model inspired by self-supervision for supervised learning that imposes partial invariance to a group of transformations.
   • Introduce an invariance score (IS) for CNN convolutional kernels.
   • Provide a theoretical investigation of the inductive bias implied by the TransNet training algorithm.
   • Demonstrate empirically how both the full and pruned versions of TransNet improve accuracy.

2. Related Work

Overfitting. A fundamental and long-standing issue in machine learning, overfitting occurs when a learning algorithm minimizes the training loss but generalizes poorly to the unseen test set. Many methods were developed to mitigate this problem, including early stopping, where training is halted as soon as the loss over a validation set starts to increase, and regularization, where a penalty term is added to the optimization loss. Other related ideas, which achieve similar goals, include dropout [27], batch normalization [14], transfer learning [25, 29], and data augmentation [3, 33].

Self-Supervised Learning. A family of learning algorithms that train a model using self-generated labels (e.g. the orientation of an image), in order to exploit unlabeled data as well as extract more information from labeled data. Self-training algorithms are used for representation learning, by training a deep network to solve pretext tasks where labels can be produced directly from the data. Such tasks include colorization [32, 16], placing image patches in the right place [22, 7], inpainting [23] and orientation prediction [10]. Typically, self-supervision is used in unsupervised learning [8], to impose some structure on the data, or in semi-supervised learning [31, 12]. Our work is motivated by RotNet, an orientation prediction method suggested by [10]. It differs from [31, 12] in that we allocate a specific classification head to each input transformation, rather than predicting the self-supervised label with a separate head.

Equivariant CNNs. Many computer vision algorithms are designed to exhibit some form of invariance to a transformation of the input, including geometric transformations [20], transformations of time [28], or changes in pose and illumination [24]. Equivariance is a more relaxed property, exploited for example by CNN models with respect to translation. Work on CNN models that enforce strict equivariance includes [26, 9, 1, 21, 2, 5]. Like these methods, our method seeks to achieve invariance by sharing the weights of the convolution layers between multiple heads. Unlike these methods, however, the invariance constraint is soft. Soft equivariance is also seen in works like [6], which employs a convolutional layer that simultaneously feeds rotated and flipped versions of the original image to a CNN model, or [30], which appends rotated and reflected versions of each convolutional kernel.

3. TransNet

Notations and definitions. Let $X = \{(x_i, y_i)\}_{i=1}^{n}$ denote the training data, where $x_i \in \mathbb{R}^d$ denotes the $i$-th data point and $y_i \in [K]$ its corresponding label. Let $D$ denote the data distribution from which the samples are drawn. Let $H$ denote the set of hypotheses, where $h_\theta \in H$ is defined by its parameters $\theta$ (we often write $h = h_\theta$ to simplify notation). Let $\ell(h, x, y)$ denote the loss of hypothesis $h$ on sample $(x, y)$. The overall loss is:

$$L(h, X) = \mathbb{E}_{(x,y)\sim D}\left[\ell(h, x, y)\right] \tag{1}$$

Our objective is to find the optimal hypothesis:

$$h^* := \arg\min_{h \in H} L(h, X) \tag{2}$$

For simplicity, whenever the underlying distribution of a random variable isn't explicitly defined, we use the uniform distribution, e.g. $\mathbb{E}_{a\in A}[a] = \frac{1}{|A|}\sum_{a\in A} a$.

3.1. Model architecture

The TransNet architecture is defined by a set of input transformations $T = \{t_j\}_{j=1}^{m}$, where each transformation
$t \in T$ operates on the inputs ($t : \mathbb{R}^d \to \mathbb{R}^d$) and is associated with a corresponding head of the model. Thus each transformation operates on data point $x$ as $t(x)$, and the transformed data-set is defined as:

$$t(X) := \{(t(x_i), y_i)\}_{i=1}^{n} \tag{3}$$

Given an existing NN model $h$, henceforth called the base model, we can split it into two components: all the layers except for the last one, denoted $f$, and the last layer $g$, assumed to be a fully-connected layer. Thus $h = g \circ f$. Next, we enhance model $h$ by replacing $g$ with $|T| = m$ heads, where each head is an independent fully-connected layer $g_t$ associated with a specific transformation $t \in T$. Formally, each head is defined by $h_t = g_t \circ f$, and it operates on the corresponding transformed input as $h_t(t(x))$.

Algorithm 1: Training the TransNet model
    input : TransNet model $h_T$, batch size $b$, maximum number of iterations MAX_ITER
    output: trained TransNet model
    1   for i = 1 ... MAX_ITER do
    2       sample a batch $B = \{(x_k, y_k)\}_{k=1}^{b} \overset{iid}{\sim} D^b$
    3       forward:
    4       for $t \in T$ do
    5           $L(h_t, B) = \frac{1}{b}\sum_{k=1}^{b} \ell(h_t, t(x_k), y_k)$
    6       end
    7       $L_T(h_T, B) = \frac{1}{m}\sum_{t \in T} L(h_t, B)$
    8       backward (SGD):
    9       update the model's weights by differentiating the sampled loss $L_T(h_T, B)$
    10  end
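Alg. 1 can be illustrated with a small, dependency-free Python sketch. It is not the authors' implementation and rests on assumptions that are not in the paper: the shared backbone $f$ is a single linear layer `W`, each head $g_t$ is a linear map, squared error against a one-hot target stands in for the loss $\ell$, the inputs are tiny 2×2 images, and $T = \{r^0, r^1\}$ as in Fig. 1. The helper names (`transnet_step`, `transformation_loss`, `head_loss`) are hypothetical. Gradients are written out by hand, where a real implementation would use automatic differentiation.

```python
import random


def rot90(img):
    """Rotate a square image (list of rows) 90 degrees counter-clockwise."""
    n = len(img)
    return [[img[c][n - 1 - r] for c in range(n)] for r in range(n)]


def identity(img):
    return img


def flatten(img):
    return [v for row in img for v in row]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def head_loss(W, V, x, y):
    """Forward pass of h_t = g_t o f on a flattened input (squared loss, an assumption)."""
    z = matvec(W, x)                        # shared backbone features f(x)
    out = matvec(V, z)                      # head g_t
    err = [o - (1.0 if k == y else 0.0) for k, o in enumerate(out)]
    return z, err, 0.5 * sum(e * e for e in err)


def transnet_step(W, heads, transforms, batch, lr):
    """One iteration of Alg. 1 (lines 3-9); updates W and the heads in place."""
    scale = 1.0 / (len(batch) * len(transforms))    # the 1/b and 1/m factors
    gW = [[0.0] * len(W[0]) for _ in W]
    grads, total = [], 0.0
    for t, V in zip(transforms, heads):             # line 4: for t in T
        gV = [[0.0] * len(V[0]) for _ in V]
        for img, y in batch:                         # line 5: L(h_t, B)
            x = flatten(t(img))                      # head t only sees t(x)
            z, err, loss = head_loss(W, V, x, y)
            total += loss
            back = [sum(V[k][j] * err[k] for k in range(len(V))) for j in range(len(z))]
            for k, e in enumerate(err):
                for j, zj in enumerate(z):
                    gV[k][j] += scale * e * zj
            for j, bj in enumerate(back):            # backbone gradient from every head
                for i, xi in enumerate(x):
                    gW[j][i] += scale * bj * xi
        grads.append(gV)
    for V, gV in zip(heads, grads):                  # line 9: one SGD step
        for k, row in enumerate(gV):
            for j, g in enumerate(row):
                V[k][j] -= lr * g
    for j, row in enumerate(gW):
        for i, g in enumerate(row):
            W[j][i] -= lr * g
    return total * scale                             # line 7: sampled loss L_T(h_T, B)


def transformation_loss(W, heads, transforms, data):
    """Eq. (4) on a finite data-set, averaging uniformly over T and X."""
    per_head = [sum(head_loss(W, V, flatten(t(img)), y)[2] for img, y in data) / len(data)
                for t, V in zip(transforms, heads)]
    return sum(per_head) / len(per_head)


random.seed(0)
data = [([[1, 1], [0, 0]], 0), ([[1, 0], [1, 0]], 1),
        ([[1, 0.8], [0, 0.2]], 0), ([[0.9, 0], [1, 0.1]], 1)]
transforms = [identity, rot90]                       # T = {r^0, r^1}, as in Fig. 1
W = [[random.uniform(-0.5, 0.5) for _ in range(4)] for _ in range(3)]
heads = [[[random.uniform(-0.5, 0.5) for _ in range(3)] for _ in range(2)]
         for _ in transforms]
before = transformation_loss(W, heads, transforms, data)
for _ in range(500):                                 # line 1: MAX_ITER = 500
    batch = random.choices(data, k=4)                # line 2: i.i.d. batch, b = 4
    transnet_step(W, heads, transforms, batch, lr=0.1)
after = transformation_loss(W, heads, transforms, data)
print(f"L_T before: {before:.3f}, after: {after:.3f}")
```

The backbone gradient `gW` accumulates contributions from every head before the single SGD step. This is how the heads jointly shape the shared weights, while each head's gradient only updates its own `V`.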
The full model, with its $m$ heads, is denoted by $h_T := \{h_t\}_{t\in T}$, and operates on the input as follows:

$$h_T(x) := \mathbb{E}_{t\in T}\left[h_t(t(x))\right]$$

The corresponding loss of the full model is defined as:

$$L_T(h_T, X) := \mathbb{E}_{t\in T}\left[L(h_t, t(X))\right] \tag{4}$$

Note that the resulting model (see Fig. 1) essentially represents $m$ models, which share via $f$ all the weights up to the last fully-connected layer. Each of these models can be used separately, as we do later on.

3.2. Training algorithm

Our method uses SGD with a few modifications to minimize the transformation loss (4), as detailed in Alg. 1. Relying on the fact that each batch is sampled i.i.d. from $D$, we can prove (see Lemma 1) the desirable property that the sampled loss $L_T(h_T, B)$ is an unbiased estimator of the transformation loss $L_T(h_T, X)$. This justifies using Alg. 1 to optimize the transformation loss.

Lemma 1. Given batch $B$, the sampled transformation loss $L_T(h_T, B)$ is an unbiased estimator of the transformation loss $L_T(h_T, X)$.

Proof.

$$\begin{aligned}
\mathbb{E}_{B\sim D^b}\left[L_T(h_T, B)\right]
&= \mathbb{E}_{B\sim D^b}\left[\mathbb{E}_{t\in T}\left[L(h_t, t(B))\right]\right] \\
&= \mathbb{E}_{t\in T}\left[\mathbb{E}_{B\sim D^b}\left[L(h_t, t(B))\right]\right] && \text{(linearity of expectation)} \\
&= \mathbb{E}_{t\in T}\left[L(h_t, t(X))\right] && (B \overset{iid}{\sim} D^b) \\
&= L_T(h_T, X)
\end{aligned} \tag{5}$$

3.3. Transformations

Which transformations should we use? Given a specific data-set, we distinguish between transformations that occur naturally in the data-set and transformations that do not. For example, a horizontal flip can occur naturally in the CIFAR-10 data-set, but not in the MNIST data-set. TransNet can only benefit from transformations that do not occur naturally in the target data-set, so that each head learns a well-defined and non-overlapping classification task. Transformations that occur naturally in the data-set are often used for data augmentation, since by definition they do not change the data domain.

Dihedral group $D_4$. As mentioned earlier, the TransNet model is defined by a set of input transformations $T$. We constrain $T$ to be a subset of the dihedral group $D_4$, which includes reflections and rotations by multiples of 90°. We denote a horizontal reflection by $m$ and a counter-clockwise 90° rotation by $r$. Using these two elements we can express all the elements of $D_4$ as $\{r^i, m \circ r^i \mid i \in \{0, 1, 2, 3\}\}$. These transformations were chosen because, as mentioned in [10], their application is relatively efficient and does not leave artifacts in the image (unlike scaling or a change of aspect ratio).

Note that these transformations can be applied to any 3D tensor by operating on the height and width dimensions, including an input image as well as the model's kernels. When applying a transformation $t$ to the model's weights $\theta$, denoted $t(\theta)$, the notation implies that $t$ operates on each of the model's kernels separately, without affecting other layers such as the fully-connected ones (see Fig. 2).
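The $D_4$ elements and the identity illustrated in Fig. 2 can be checked with a short pure-Python sketch. It assumes single-channel images and kernels stored as 2D lists, and a convolution layer that computes a 'valid' cross-correlation, as CNN libraries do. Each element $m^{f} \circ r^{k}$ is encoded as a hypothetical pair `(k, flip)`, and all helper names are illustrative only.

```python
def rot90(a):
    """Generator r: rotate a 2D array (list of rows) 90 degrees counter-clockwise."""
    rows, cols = len(a), len(a[0])
    return [[a[c][cols - 1 - r] for c in range(rows)] for r in range(cols)]


def mirror(a):
    """Generator m: horizontal reflection (mirror each row left-right)."""
    return [row[::-1] for row in a]


# Each D4 element m^flip o r^k is encoded as the pair (k, flip).
D4 = [(k, flip) for flip in (False, True) for k in range(4)]


def apply(t, a):
    """Apply t = m^flip o r^k: rotate k times first, then optionally mirror."""
    k, flip = t
    for _ in range(k):
        a = rot90(a)
    return mirror(a) if flip else a


def inverse(t):
    """r^k has inverse r^(4-k); every m o r^k is a reflection, hence self-inverse."""
    k, flip = t
    return t if flip else ((4 - k) % 4, False)


def correlate_valid(x, w):
    """'Valid' 2D cross-correlation, i.e. what a CNN convolution layer computes."""
    kh, kw = len(w), len(w[0])
    return [[sum(x[i + a][j + b] * w[a][b] for a in range(kh) for b in range(kw))
             for j in range(len(x[0]) - kw + 1)]
            for i in range(len(x) - kh + 1)]


x = [[1, 2, 0, 3], [4, 0, 5, 1], [2, 2, 1, 0], [0, 6, 1, 2]]
w = [[1, 0, -1], [2, 1, 0], [0, 3, 1]]

# An input with no symmetry of its own has 8 distinct images under D4.
print(len({tuple(map(tuple, apply(t, x))) for t in D4}))  # 8

# Fig. 2: conv(t(x), w) == t(conv(x, t^-1(w))) for every t in D4.
for t in D4:
    upper = correlate_valid(apply(t, x), w)
    lower = apply(t, correlate_valid(x, apply(inverse(t), w)))
    assert upper == lower, t
print("Fig. 2 identity holds for all 8 elements of D4")
```

Checking the identity for the two generators would suffice, since it is preserved under composition. Looping over all eight elements simply makes that explicit.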
Figure 2: The transformed input convolved with a kernel (upper path) equals the transformation applied to the output of the untransformed input convolved with the inversely transformed kernel (lower path).

3.4. Model variations

Once trained, the full TransNet model can be viewed as an ensemble of $m$ shared classifiers. Its time complexity is linear in the number of heads, almost equivalent to that of an ensemble of $m$ base CNN models, since the time needed to apply each of the $D_4$ transformations to the input is negligible compared to the time needed for the model to process the input. In contrast, its space complexity is almost equivalent to that of only one base CNN model.¹

¹ Each additional head adds 102K (∼0.45%) and 513K (∼0.90%) extra parameters to the basic ResNet18 model when training CIFAR-100 and ImageNet-200 respectively.

We note that one can prune any of the model's heads, leaving a smaller ensemble of up to $m$ classifiers. A useful reduction prunes all the model's heads except one, typically the head corresponding to the identity transformation. This yields a regular CNN that is equivalent in time and space complexity to the base architecture used to build the TransNet model. Having done so, we can evaluate the inductive bias of the TransNet architecture and its training algorithm, whose effect is now confined to the training procedure, by comparing the pruned TransNet to the base CNN model (see Section 5).

4. Theoretical Analysis

In this section we analyze the TransNet model theoretically. We consider the following basic CNN architecture:

$$h_\theta = g \circ l_{inv} \circ \prod_{i=1}^{k} c_i \tag{6}$$

where $g$ denotes a fully-connected layer, $l_{inv}$ denotes a layer that is invariant under the $D_4$ transformation group (e.g. a global average pooling layer, GAP), $\{c_i\}_{i\in[k]}$ denote convolutional layers,² and the product denotes composition of these layers. The TransNet model extends the basic model by appending additional heads:

$$h_{T,\theta} = \left\{g_t \circ l_{inv} \circ \prod_{i=1}^{k} c_i\right\}_{t\in T} \tag{7}$$

² While each convolutional layer may be followed by ReLU and Batch Normalization [14] layers, this doesn't change the analysis, so we omit the extra notation.

We denote the parameters of a fully-connected or a convolutional layer by the subscripts $w$ (weight) and $b$ (bias), e.g. $g(x) = g_w \cdot x + g_b$.

4.1. Transformation compilation

Transformations in the dihedral group $D_4$ satisfy another important property, expressed by the following proposition:

Proposition 1. Let $h_\theta$ denote a CNN model where the last convolutional layer is followed by a layer that is invariant under the $D_4$ group. Then any transformation $t \in D_4$ applied to the input image can be compiled into the model's weights $\theta$ as follows:

$$\forall t \in D_4 \;\; \forall x \in X: \quad h_\theta(t(x)) = h_{t^{-1}(\theta)}(x) \tag{8}$$

Proof. By induction on $k$ we can show that:

$$\prod_{i=1}^{k} c_i \circ t(x) = t \circ \prod_{i=1}^{k} t^{-1}(c_i)(x) \tag{9}$$

(see Fig. 2). Plugging (9) into (6), we get:

$$\begin{aligned}
h_\theta(t(x)) &= g \circ l_{inv} \circ \prod_{i=1}^{k} c_i \circ t(x) \\
&= g \circ l_{inv} \circ t \circ \prod_{i=1}^{k} t^{-1}(c_i)(x) \\
&= g \circ l_{inv} \circ \prod_{i=1}^{k} t^{-1}(c_i)(x) && (l_{inv} \circ t = l_{inv}) \\
&= h_{t^{-1}(\theta)}(x)
\end{aligned}$$

Implication. The ResNet model [11] used in our experiments satisfies the precondition of the proposition stated above, since it contains a GAP layer [19] after the last convolutional layer, and GAP is invariant under $D_4$.

4.2. Single vs. multiple headed model

To acquire intuition regarding the inductive bias implied by the training algorithm (Alg. 1), we consider two cases, a single and a double headed model, trained with the same training algorithm. A single headed model is a special case of the full multi-headed model in which all the heads share weights, $h_t(t(x)) = h(t(x))\ \forall t$, and the loss in line 5 of Alg. 1 becomes $L(h, B) = \frac{1}{b}\sum_{k=1}^{b} \ell(h, t(x_k), y_k)$.
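Proposition 1, together with the single headed loss just described, can be checked numerically on a toy instance of Eq. (6). This self-contained sketch assumes a single convolutional layer with integer kernels (so all comparisons are exact), GAP as $l_{inv}$, one fully-connected head $g$ shared by every transformation, and squared loss as a stand-in for $\ell$. The function names (`model`, `transform_weights`, `single_head_loss`, `compiled_loss`) are hypothetical. With a shared head, line 5 of Alg. 1 evaluates the same $h$ on every $t(x_k)$, and Eq. (8) turns that into evaluating inversely transformed kernels on the untransformed data.

```python
def rot90(a):
    rows, cols = len(a), len(a[0])
    return [[a[c][cols - 1 - r] for c in range(rows)] for r in range(cols)]


def apply(t, a):
    """D4 element t = (k, flip) acts as m^flip o r^k (m = left-right mirror)."""
    k, flip = t
    for _ in range(k):
        a = rot90(a)
    return [row[::-1] for row in a] if flip else a


def inverse(t):
    k, flip = t
    return t if flip else ((4 - k) % 4, False)


def correlate_valid(x, w):
    kh, kw = len(w), len(w[0])
    return [[sum(x[i + a][j + b] * w[a][b] for a in range(kh) for b in range(kw))
             for j in range(len(x[0]) - kw + 1)]
            for i in range(len(x) - kh + 1)]


def model(theta, x):
    """Toy instance of Eq. (6): h = g o GAP o c with one conv layer."""
    kernels, c_b, g_w, g_b = theta
    maps = [correlate_valid(x, k) for k in kernels]              # c_w
    pooled = [sum(map(sum, m)) / (len(m) * len(m[0])) + b       # GAP; bias commutes with averaging
              for m, b in zip(maps, c_b)]
    return [sum(wi * p for wi, p in zip(row, pooled)) + b         # g
            for row, b in zip(g_w, g_b)]


def transform_weights(t, theta):
    """t(theta): transform each conv kernel; biases and g are untouched."""
    kernels, c_b, g_w, g_b = theta
    return [apply(t, k) for k in kernels], c_b, g_w, g_b


def sq_loss(out, y):
    return 0.5 * sum((o - (1.0 if k == y else 0.0)) ** 2 for k, o in enumerate(out))


def single_head_loss(theta, data, group):
    """L_T with one shared head: mean over t of L(h, t(X))."""
    return sum(sum(sq_loss(model(theta, apply(t, x)), y) for x, y in data) / len(data)
               for t in group) / len(group)


def compiled_loss(theta, data, group):
    """The same quantity via Eq. (8): mean over t of L(h_{t^-1(theta)}, X)."""
    return sum(sum(sq_loss(model(transform_weights(inverse(t), theta), x), y)
                   for x, y in data) / len(data)
               for t in group) / len(group)


D4 = [(k, flip) for flip in (False, True) for k in range(4)]
theta = ([[[1, 0, 2], [0, -1, 1], [3, 0, 0]], [[0, 2, 0], [1, 1, -2], [0, 0, 1]]],
         [1, -1], [[0.5, -0.25], [0.1, 0.3]], [0.0, 0.2])
data = [([[1, 2, 0, 3], [4, 0, 5, 1], [2, 2, 1, 0], [0, 6, 1, 2]], 0),
        ([[0, 1, 1, 0], [2, 0, 0, 3], [1, 4, 0, 1], [0, 0, 2, 5]], 1)]

for t in D4:  # Eq. (8): h_theta(t(x)) == h_{t^-1(theta)}(x)
    for x, _ in data:
        assert model(theta, apply(t, x)) == model(transform_weights(inverse(t), theta), x)
print(single_head_loss(theta, data, D4) == compiled_loss(theta, data, D4))  # True
```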
Since non-convex deep neural networks are hard to analyze, we focus on a simplified framework and consider a convex optimization problem, where the loss function is convex w.r.t. the model's parameters $\theta$. We also assume that the model's transformations in $T$ form a group.³

³ $T$ being a group is a technical constraint needed for the analysis, not required by the algorithm.

Single headed model analysis. In this simplified case, we can prove the following strict proposition:

Proposition 2. Let $h_\theta$ denote a CNN model satisfying the precondition of Prop. 1, and let $T \subset D_4$ be a transformation group. Then the optimal transformation loss $L_T$ (see Eq. 4) is attained by model weights that are invariant under the transformations in $T$. Formally:

$$\exists \theta_0 : \left(\forall t \in T : \theta_0 = t(\theta_0)\right) \wedge \left(\theta_0 \in \arg\min_{\theta} L_T(\theta, X)\right)$$

Proof. To simplify notation, henceforth we let $\theta$ denote the model $h_\theta$.

$$\begin{aligned}
L_T(\theta, X) &= \mathbb{E}_{t\in T}\left[L(\theta, t(X))\right] \\
&= \mathbb{E}_{t\in T}\left[\mathbb{E}_{(x,y)\sim D}\left[\ell(\theta, t(x), y)\right]\right] \\
&= \mathbb{E}_{t\in T}\left[\mathbb{E}_{(x,y)\sim D}\left[\ell(t^{-1}(\theta), x, y)\right]\right] && \text{(by Prop. 1)} \\
&= \mathbb{E}_{(x,y)\sim D}\left[\mathbb{E}_{t\in T}\left[\ell(t^{-1}(\theta), x, y)\right]\right] \\
&\geq \mathbb{E}_{(x,y)\sim D}\left[\ell\left(\mathbb{E}_{t\in T}\left[t^{-1}(\theta)\right], x, y\right)\right] && \text{(Jensen's inequality)} \\
&= \mathbb{E}_{(x,y)\sim D}\left[\ell(\bar\theta, x, y)\right] && (\bar\theta := \mathbb{E}_{t\in T}[t(\theta)],\ T = T^{-1}) \\
&= L(\bar\theta, X) \\
&= \mathbb{E}_{t\in T}\left[L(t^{-1}(\bar\theta), X)\right] && (\bar\theta \text{ is invariant under } T) \\
&= \mathbb{E}_{t\in T}\left[L(\bar\theta, t(X))\right] && \text{(by Prop. 1)} \\
&= L_T(\bar\theta, X)
\end{aligned}$$

Above we use the fact that $\bar\theta$ is invariant under $T$. This holds since $T$ is a group, and thus $t_0 T = T$ for every $t_0 \in T$. Moreover, $t_0$ merely permutes kernel entries, so it commutes with averaging. Hence:

$$t_0(\bar\theta) = t_0\left(\mathbb{E}_{t\in T}[t(\theta)]\right) = \mathbb{E}_{t\in T}[t_0 \circ t(\theta)] = \mathbb{E}_{t\in T}[t(\theta)] = \bar\theta$$

Thus $L_T(\theta, X) \geq L_T(\bar\theta, X)$ for every $\theta$. Applying this to a minimizer $\theta^*$ shows that the $T$-invariant weights $\theta_0 := \bar\theta^*$ also minimize $L_T$.

Double headed model. In light of Prop. 2, we now present a counterexample showing that Prop. 2 does not hold for the general TransNet model.

Example 1. Let $T = \{t_1 = r^0, t_2 = m \circ r^2\} \subset D_4$ denote the transformation group consisting of the identity and the vertical reflection. Let $h_{T,\theta} = \{h_i = g_i \circ \mathrm{GAP} \circ c\}_{i=1}^{2}$ denote a double headed TransNet model. It comprises a single convolutional layer (1 channel in and 2 channels out), followed by a GAP layer and then 2 fully-connected layers $\{g_i\}_{i=1}^{2}$, one for each head. Each $g_i$ outputs a vector of size 2. The data-set $X = \{(x_1, y_1), (x_2, y_2)\}$ consists of 2 examples:

$$x_1 = \begin{pmatrix}1&1&1\\0&0&0\\0&0&0\end{pmatrix},\; y_1 = 1, \qquad x_2 = \begin{pmatrix}0&0&0\\0&0&0\\1&1&1\end{pmatrix},\; y_2 = 2$$

Note that $x_2 = t_2(x_1)$.⁴

⁴ This example may seem rather artificial, but in fact this isn't such a rare case. E.g., the airplane and ship classes, both found in the CIFAR-10 data-set, share a similar blue background.

Now, assume the model's convolutional layer $c$ is composed of 2 kernels that are invariant under $T$, and denote it by $c_{inv}$. Let $i \in \{1, 2\}$. Then:

$$\begin{aligned}
h_i(x_2) = h_i(t_2(x_1)) &= g_i \circ \mathrm{GAP} \circ c_{inv} \circ t_2(x_1) \\
&= g_i \circ \mathrm{GAP} \circ c_{inv}(x_1) = h_i(x_1)
\end{aligned} \tag{10}$$

In this case both heads predict the same output for both inputs, even though the inputs have different labels. Thus:

$$L(h_i, t_i(X)) > 0 \implies L_T(h_{T,\theta}, X) > 0$$

In contrast, set $c_w = (x_1, x_2)$ and $c_b = (0, 0)$, which is not invariant under $T$, together with:

$$g_{1,w} = \begin{pmatrix}1&0\\0&1\end{pmatrix},\; g_{1,b} = \begin{pmatrix}0\\0\end{pmatrix}, \qquad g_{2,w} = \begin{pmatrix}0&1\\1&0\end{pmatrix},\; g_{2,b} = \begin{pmatrix}0\\0\end{pmatrix}$$

With these weights we obtain:

$$L(h_i, t_i(X)) = 0 \implies L_T(h_{T,\theta}, X) = 0.$$

We may conclude that the optimal model's kernels are not invariant under $T$, contrary to the claim of Prop. 2 for the single headed model.

Discussion. The intuition we derive from the analysis above is that the training algorithm (Alg. 1) implies an inductive bias towards invariance of the model's kernels, as proved for the single headed model. However, it does not strictly enforce invariance, as shown by the counterexample of the double headed model.

5. Experimental Results

Data-sets. For evaluation we used the 5 image classification data-sets detailed in Table 1. These diverse data-sets allow us to evaluate our method across different image resolutions and numbers of predicted classes.

Implementation Details. We employed the ResNet18 [11] architecture for all the data-sets except for ImageNet-200, which was evaluated using the ResNet50 architecture (see Appendix A for more implementation details).

Notations.

   • "base CNN" - a regular convolutional neural network, identical to the TransNet model with only the head corresponding to the identity transformation.
   • "PTm-CNN" - a pruned TransNet model trained with m heads, of which a single head is kept and used for prediction⁵. It has the same space and time complexity as the base CNN.
   • "Tm-CNN" - a full TransNet model trained with m heads, all of which are used for prediction. It has roughly the same space complexity¹ as the base CNN and m times its time complexity.

To denote an ensemble of the models above, we append the ensemble size in parentheses; e.g., T2-CNN (3) is an ensemble of 3 T2-CNN models.

Name              Classes   Train/Test Samples   dim
CIFAR-10 [15]     10        50K/10K              32
CIFAR-100 [15]    100       50K/10K              32
ImageNette [13]   10        10K/4K               224
ImageWoof [13]    10        10K/4K               224
ImageNet-200      200       260K/10K             224
Table 1: The data-sets used in our experiments. Each example is a color image of dim×dim×3 pixels. ImageNette consists of 10 easy-to-classify classes from ImageNet [4], while ImageWoof consists of 10 hard-to-classify dog-breed classes from ImageNet. ImageNet-200 consists of 200 ImageNet classes (the same classes as in [17]) at full image size.

5.1. Model accuracy, comparative results

We now compare the accuracy of the "base-CNN", "PTm-CNN" and "Tm-CNN" models, where m = 2, 3, 4 denotes the number of heads of the TransNet model, as well as their ensembles, across all the data-sets listed in Table 1.

Models with the same space and time complexity. First, we evaluate the pruned TransNet model by comparing the "PTm-CNN" models with the "base-CNN" model, see Table 2. Essentially, we evaluate the effect of using the TransNet model only for training, as the final "PTm-CNN" models are architecturally identical to the "base-CNN" model regardless of m. The results reflect the inductive bias implied by the training procedure: TransNet training with 2 or 3 heads improves the accuracy of the final "base-CNN" classifier on all the evaluated data-sets.

Models with similar space complexity, different time complexity. Next, we evaluate the full TransNet model by comparing the "Tm-CNN" models with the "base-CNN" model, see Table 3. Although the full TransNet model processes the (transformed) input m times, compared to once for the "base-CNN" model, its architecture is not significantly larger: it adds only a negligible number of parameters to the "base-CNN", in the form of its multiple heads¹. The full TransNet model improves the accuracy compared to the "base-CNN" model, and also compared to the pruned TransNet model. Thus, if the additional runtime at test time is not an issue, it is beneficial to employ the full TransNet model for prediction. In fact, one can process the input image once, and then decide whether to continue processing it with the other heads to improve the prediction, all while keeping roughly the same space complexity.

Ensembles: models with similar time complexity, different space complexity. Here we evaluate ensembles of pruned TransNet models, and compare them to a single full TransNet model, which can be seen as a space-efficient ensemble: a full TransNet generates m predictions with roughly 1/m of the parameters of an ensemble of m base CNNs, where m is the number of TransNet heads. Results are shown in Fig. 3. An ensemble of pruned TransNet models is superior to an ensemble of base CNN models, suggesting that the accuracy gain achieved by the pruned TransNet model is complementary to the accuracy gain achieved by ensembling classifiers. Furthermore, the full TransNet model with 2 and 3 heads achieves accuracy competitive with an ensemble of 2 or 3 base CNN models, respectively, while using 1/2 and 1/3 as many parameters, respectively.

Figure 3: Model accuracy as a function of the number of instances processed during prediction (X-axis). Each instance requires a complete run from input to output. An ensemble includes: m independent base CNN classifiers for "CNN"; m pruned TransNet models trained with 2 heads for "PT2-CNN"; and one TransNet model with m heads, where m is the ensemble size, for "Tm-CNN".

5 In our experiments we chose the head associated with the identity transformation ($r_0$) when evaluating a pruned TransNet. Note, however, that we could have chosen the best head in terms of accuracy, since by Prop. 1 its transformation can be compiled into the model's weights.

MODEL      CIFAR-10       CIFAR-100      ImageNette     ImageWoof      ImageNet-200
base-CNN   95.57 ± 0.08   76.56 ± 0.16   92.97 ± 0.16   87.27 ± 0.15   84.39 ± 0.07
PT2-CNN    95.99 ± 0.07   79.33 ± 0.15   93.84 ± 0.14   88.09 ± 0.30   85.17 ± 0.10
PT3-CNN    95.87 ± 0.04   79.08 ± 0.06   94.15 ± 0.16   87.79 ± 0.11   84.97 ± 0.95
PT4-CNN    95.73 ± 0.05   77.98 ± 0.17   93.94 ± 0.06   85.81 ± 0.79   84.02 ± 0.71
Table 2: Accuracy (%) of models with the same space and time complexity, comparing the base CNN with pruned TransNet models "PTm-CNN", where m = 2, 3, 4 denotes the number of heads used in training. Mean and standard error over 3 repetitions are shown.

MODEL      CIFAR-10       CIFAR-100      ImageNette     ImageWoof      ImageNet-200
base-CNN   95.57 ± 0.08   76.56 ± 0.16   92.97 ± 0.16   87.27 ± 0.15   84.39 ± 0.07
T2-CNN     96.22 ± 0.10   80.35 ± 0.06   94.02 ± 0.13   88.36 ± 0.33   85.47 ± 0.14
T3-CNN     96.33 ± 0.06   80.92 ± 0.08   94.39 ± 0.07   88.79 ± 0.25   85.68 ± 0.20
T4-CNN     96.17 ± 0.01   79.94 ± 0.16   94.67 ± 0.06   87.05 ± 0.75   85.54 ± 0.11
Table 3: Accuracy (%) of models with similar space complexity and different time complexity, comparing the base CNN with full TransNet models. With m ∈ {2, 3, 4} denoting the number of heads, the prediction time complexity of the respective TransNet model "Tm-CNN" is m times that of the base CNN. Mean and standard error over 3 repetitions are shown.

Accuracy vs. generalization. In Fig. 3 we can see that 2 heads improve the model's performance across all data-sets, 3 heads improve it on most of the data-sets, and 4 heads actually reduce performance on most data-sets. We hypothesize that too many heads impose too strict an inductive bias on the model's kernels. Thus, although generalization is improved, test accuracy is reduced due to insufficient variance. Further analysis is presented in the next section.

5.2. Generalization

We have seen in Section 5.1 that the TransNet model, whether full or pruned, achieves better test accuracy than the base CNN model. This occurs despite the fact that the transformation loss $L_T(h_T, X)$ minimized by the TransNet model is more demanding than the loss $L(h, X)$ minimized by the base CNN, and appears harder to optimize. This conjecture is supported by the following lemma, which shows that any loss attained by a TransNet model can also be attained by a base CNN:

Lemma 2. Let $h_T$ denote a TransNet model that obtains a transformation loss of $a := L_T(h_T, X)$. Then there exists a reduction from $h_T$ to a base CNN model $h$ that obtains a loss of at most $a$, i.e. $L(h, X) \le a$.

Proof. $a = L_T(h_T, X) = \mathbb{E}_{t\in T}[L(h_{\theta_t}, t(X))]$. Since the smallest term of an average is at most the average, there must be a transformation $t \in T$ such that $L(h_{\theta_t}, t(X)) \le a$. Now, one can compile the transformation $t$ into $h_{\theta_t}$ (see Prop. 1) and get a base CNN $\tilde h = h_{t^{-1}(\theta_t)}$, which obtains

$$
L(\tilde h, X) = L(h_{t^{-1}(\theta_t)}, X) = L(h_{\theta_t}, t(X)) \le a.
$$

Why is it, then, that the TransNet model achieves overall better accuracy than the base CNN? The answer lies in its ability to generalize better.

In order to measure the generalization capability of a model w.r.t. a data-set, we use the ratio between the test-set loss and the train-set loss, where a lower ratio indicates better generalization. As illustrated in Fig. 4, the pruned TransNet models exhibit better generalization than the base CNN model. Furthermore, the generalization improvement increases with the number of TransNet model heads, which are only used for training and then pruned. The observed narrowing of the generalization gap occurs because, although the TransNet model slightly increases the training loss, it decreases the test loss more significantly, as compared to the base CNN.

Figure 4: CIFAR-100 results. Left panel: learning curves of the base CNN model ("base-CNN") and a pruned TransNet model ("PT2-CNN"). Right panel: generalization score (test-train loss ratio) measured for the base-CNN model and for pruned TransNet models with different numbers of heads.

We note that better generalization does not necessarily imply a better model. The "PT4-CNN" model generalizes better than any other model (see the right panel of Fig. 4), but its test accuracy is lower, as seen in Table 2.

5.3. Kernel invariance

What characterizes the beneficial inductive bias implied by the TransNet model and its training algorithm Alg. 1?

To answer this question, we investigate the emerging invariance of the kernels in the convolutional layers of the learned network, w.r.t. the set of TransNet transformations T.

We start by introducing the "Invariance Score" (IS), which measures how invariant a 3D tensor is w.r.t. a transformations group. Specifically, given a convolutional kernel $w$ (a 3D tensor) and a transformations group $T$, the IS is defined as follows:

$$
\mathrm{IS}(w, T) := \min_{u \in \mathrm{INV}_T} \|w - u\| \tag{11}
$$

where $\mathrm{INV}_T := \{u : u = t(u)\ \forall t \in T\}$ is the set of kernels (of the same shape as $w$) that are invariant under $T$.

Lemma 3. $\arg\min_{u \in \mathrm{INV}_T} \|w - u\| = \mathbb{E}_{t\in T}[t(w)]$.

Proof. Let $u \in \mathrm{INV}_T$ and define $f(u) := \|w - u\|^2$. Note that $\arg\min_{u \in \mathrm{INV}_T} \|w - u\| = \arg\min_{u \in \mathrm{INV}_T} f(u)$, and

$$
\begin{aligned}
f(u) &= \|w - u\|^2 \\
&= \mathbb{E}_{t\in T}\big[\|w - t(u)\|^2\big] && (u \text{ is invariant under } T) \\
&= \mathbb{E}_{t\in T}\big[\|t^{-1}(w) - u\|^2\big] && (t^{-1} \text{ is linear and norm-preserving}) \\
&= \mathbb{E}_{t\in T}\big[\|t(w) - u\|^2\big] && (T = T^{-1}) \\
&= \mathbb{E}_{t\in T}\Big[\sum_{i=1}^{\mathrm{size}(w)} (t(w)_i - u_i)^2\Big],
\end{aligned}
$$

where the index $i$ runs over all the tensor's elements. Finally, we differentiate $f$ to obtain its minimum:

$$
\frac{\partial f}{\partial u_i} = \mathbb{E}_{t\in T}\big[-2(t(w)_i - u_i)\big] = 0
\implies u_i = \mathbb{E}_{t\in T}[t(w)_i]
\implies u = \mathbb{E}_{t\in T}[t(w)].
$$

The last expression for $f$ is a convex quadratic defined for every tensor $u$, and its unconstrained minimizer $\mathbb{E}_{t\in T}[t(w)]$ is itself invariant under $T$ (by the same argument used for $\bar\theta$ in the proof of Prop. 2), so it is also the minimizer over $\mathrm{INV}_T$.

Lemma 3 gives a closed-form expression for the IS gauge:

$$
\mathrm{IS}(w, T) = \big\|w - \mathbb{E}_{t\in T}[t(w)]\big\| \tag{12}
$$

Equipped with this gauge, we can inspect the invariance level of the model's kernels w.r.t. a transformations group. Note that this measure allows us to compare the full TransNet model with the base CNN model, as both share the same convolution layers. Since the transformations of the TransNet model do not necessarily form a group, we use the minimal group containing these transformations: the group of all 90° rotations $\{r_i\}_{i=0}^{3}$.

In Fig. 5 we can see that the full TransNet model "T2-CNN" and the base CNN model demonstrate similar invariance levels in all the convolutional layers but the last one. In Fig. 6, which shows the full distribution of the IS over the last layer of 4 different models, we can see more clearly that the last convolutional layer of the full TransNet models exhibits a much higher invariance level than that of the base CNN. This phenomenon is robust to the metric used in the IS definition, with similar results when using "Pearson Correlation" or "Cosine Similarity". The invariance level of the last convolutional layer increases monotonically with the number of heads in the TransNet model, which is consistent with the generalization capability of these models (see Fig. 4).

Figure 5: CIFAR-100 results, plotting the distribution of the IS scores (mean and std) for the kernels in each layer of the different models. Invariance is measured w.r.t. the group of 90° rotations.

Figure 6: CIFAR-100 results, plotting the full distribution of the IS scores for the kernels in the last (17th) layer of the different models. Invariance is measured w.r.t. the group of 90° rotations.

The generalization improvement achieved by the TransNet model, as reported in Section 5.2, may be explained by this increased level of invariance, as highly invariant kernels have fewer degrees of freedom and should therefore be less prone to overfit.

5.4. Ablation Study

Our method consists of 2 main components: the TransNet architecture and the training algorithm Alg. 1. To evaluate the accuracy gain of each component we consider two variations:

MODEL        CIFAR-10       CIFAR-100      ImageNette     ImageWoof      ImageNet-200
base-CNN     95.57 ± 0.08   76.56 ± 0.16   92.97 ± 0.16   87.27 ± 0.15   84.39 ± 0.07
Alg. only    93.85 ± 0.63   76.64 ± 0.69   92.60 ± 0.07   87.64 ± 0.30   80.58 ± 0.08
Arch. only   95.68 ± 0.05   76.98 ± 0.13   93.49 ± 0.03   87.40 ± 0.74   84.47 ± 0.13
PT2-CNN      95.99 ± 0.07   79.33 ± 0.15   93.84 ± 0.14   88.09 ± 0.30   85.17 ± 0.10
Table 4: Accuracy (%) of the ablation study models, all with the same space and time complexity. These 4 models allow us to evaluate the effects of the TransNet architecture and of the TransNet algorithm separately. Mean and standard error over 3 repetitions are shown.

   • Architecture only: in this method we train the multi-headed architecture (2 heads in this case) by feeding each head the same un-transformed batch (equivalent to a TransNet model with the multi-set of {id, id} transformations). Prediction is retrieved from a single head (similar to PT2-CNN).
   • Algorithm only: in this method we train the base (one headed) model by the same algorithm Alg. 1. (This model, termed the single headed model, was also considered in the theoretical analysis of Section 4.2.)

We compare the two methods above to the regular "base-CNN" model and to the complete model "PT2-CNN", see Table 4. Using only one of the components does not yield any significant accuracy gain. This suggests that the complete model benefits from both components working together: the training algorithm increases the invariance of the model's kernels on the one hand, while the multi-head architecture encourages the model to capture meaningful orientation information on the other hand.

6. Summary

We introduced a model inspired by self-supervision, which includes a base CNN model attached to multiple heads, each corresponding to a different transformation from a fixed set of transformations. The self-supervised aspect of the model is crucial, as the chosen transformations must not occur naturally in the data. When the model is pruned back to match the base CNN, it achieves better test accuracy and improved generalization, which we attribute to the increased invariance of the model's kernels in the last layer. We observed that excess invariance, while improving generalization, eventually curtails the test accuracy.

We evaluated our model on various image data-sets, observing that each data-set has its own optimal level of kernel invariance, i.e. there is no single number of heads that is optimal for all data-sets. Finally, we introduced an invariance score gauge (IS), which measures the level of invariance achieved by the model's kernels. IS may be leveraged to determine the optimal invariance level, and could potentially also serve as an independent regularization term.

Acknowledgements

This work was supported in part by a grant from the Israel Science Foundation (ISF) and by the Gatsby Charitable Foundation.

References

[1] Christopher Clark and Amos Storkey. Training deep convolutional neural networks to play go. In International Conference on Machine Learning, pages 1766–1774, 2015.
[2] Taco Cohen and Max Welling. Group equivariant convolutional networks. In International Conference on Machine Learning, pages 2990–2999, 2016.
[3] Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning augmentation strategies from data. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 113–123, 2019.
[4] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 248–255. IEEE, 2009.
[5] Sander Dieleman, Jeffrey De Fauw, and Koray Kavukcuoglu. Exploiting cyclic symmetry in convolutional neural networks. arXiv preprint arXiv:1602.02660, 2016.
[6] Sander Dieleman, Kyle W. Willett, and Joni Dambre. Rotation-invariant convolutional neural networks for galaxy morphology prediction. Monthly Notices of the Royal Astronomical Society, 450(2):1441–1459, 2015.
[7] Carl Doersch, Abhinav Gupta, and Alexei A. Efros. Unsupervised visual representation learning by context prediction. In Proceedings of the IEEE International Conference on Computer Vision, pages 1422–1430, 2015.
[8] Alexey Dosovitskiy, Philipp Fischer, Jost Tobias Springenberg, Martin Riedmiller, and Thomas Brox. Discriminative unsupervised feature learning with exemplar convolutional neural networks. IEEE Transactions on Pattern Analysis and Machine Intelligence, 38(9):1734–1747, 2015.
[9] Robert Gens and Pedro M. Domingos. Deep symmetry networks. In Advances in Neural Information Processing Systems, pages 2537–2545, 2014.
[10] Spyros Gidaris, Praveer Singh, and Nikos Komodakis. Unsupervised representation learning by predicting image rotations. arXiv preprint arXiv:1803.07728, 2018.
[11] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[12] Dan Hendrycks, Mantas Mazeika, Saurav Kadavath, and Dawn Song. Using self-supervised learning can improve model robustness and uncertainty. In Advances in Neural Information Processing Systems, pages 15663–15674, 2019.
[13] Jeremy Howard. Imagewang.
[14] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167, 2015.
[15] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
[16] Gustav Larsson, Michael Maire, and Gregory Shakhnarovich. Learning representations for automatic colorization. In European Conference on Computer Vision, pages 577–593. Springer, 2016.
[17] Ya Le and Xuan Yang. Tiny ImageNet visual recognition challenge. CS 231N, 7, 2015.
[18] Chen-Yu Lee, Saining Xie, Patrick Gallagher, Zhengyou Zhang, and Zhuowen Tu. Deeply-supervised nets. In Artificial Intelligence and Statistics, pages 562–570, 2015.
[19] Min Lin, Qiang Chen, and Shuicheng Yan. Network in network. arXiv preprint arXiv:1312.4400, 2013.
[20] Joseph L. Mundy, Andrew Zisserman, et al. Geometric invariance in computer vision, volume 92. MIT Press, Cambridge, MA, 1992.
[21] Jiquan Ngiam, Zhenghao Chen, Daniel Chia, Pang W. Koh, Quoc V. Le, and Andrew Y. Ng. Tiled convolutional neural networks. In Advances in Neural Information Processing Systems, pages 1279–1287, 2010.
[22] Mehdi Noroozi and Paolo Favaro. Unsupervised learning of visual representations by solving jigsaw puzzles. arXiv preprint arXiv:1603.09246, 2016.
[23] Deepak Pathak, Philipp Krahenbuhl, Jeff Donahue, Trevor Darrell, and Alexei A. Efros. Context encoders: Feature learning by inpainting, 2016.
[24] Pascal Paysan, Reinhard Knothe, Brian Amberg, Sami Romdhani, and Thomas Vetter. A 3D face model for pose and illumination invariant face recognition. In 2009 Sixth IEEE International Conference on Advanced Video and Signal Based Surveillance, pages 296–301. IEEE, 2009.
[25] Ling Shao, Fan Zhu, and Xuelong Li. Transfer learning for visual categorization: A survey. IEEE Transactions on Neural Networks and Learning Systems, 26(5):1019–1034, 2014.
[26] Laurent Sifre and Stéphane Mallat. Rotation, scaling and deformation invariant scattering for texture discrimination. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1233–1240, 2013.
[27] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
[28] Pavan Turaga and Rama Chellappa. Locally time-invariant models of human activities using trajectories on the Grassmannian. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 2435–2441. IEEE, 2009.
[29] Karl Weiss, Taghi M. Khoshgoftaar, and DingDing Wang. A survey of transfer learning. Journal of Big Data, 3(1):9, 2016.
[30] Fa Wu, Peijun Hu, and Dexing Kong. Flip-rotate-pooling convolution and split dropout on convolution neural networks for image classification. arXiv preprint arXiv:1507.08754, 2015.
[31] Xiaohua Zhai, Avital Oliver, Alexander Kolesnikov, and Lucas Beyer. S4L: Self-supervised semi-supervised learning. In Proceedings of the IEEE International Conference on Computer Vision, pages 1476–1485, 2019.
[32] Richard Zhang, Phillip Isola, and Alexei A. Efros. Colorful image colorization. In European Conference on Computer Vision, pages 649–666. Springer, 2016.
[33] Zhun Zhong, Liang Zheng, Guoliang Kang, Shaozi Li, and Yi Yang. Random erasing data augmentation. In AAAI, pages 13001–13008, 2020.

Appendix

A. Implementation details

We employed the ResNet [11] architecture: ResNet18 for all the data-sets except ImageNet-200, which was evaluated using ResNet50. Importantly, we did not change the hyper-parameters used by the regular CNN architecture on which TransNet is based. This may strengthen the results, as one could further tune these hyper-parameters to best suit the TransNet model.

We used a weight decay of 0.0001 and a momentum of 0.9. The model was trained with a batch size of 64 for all the data-sets except ImageNet-200, for which we increased the batch size to 128. We trained the model for 300 epochs, starting with a learning rate of 0.1 that was divided by 10 at epochs 150 and 225; the ImageNet-200 model was instead trained for 120 epochs, starting with a learning rate of 0.1 that was divided by 10 at epochs 40 and 80. We normalized the images as usual by subtracting the image's mean and dividing by the image's standard deviation (per color channel).

We employed a mild data augmentation scheme: horizontal flip with probability 0.5. For the CIFAR data-sets we padded each spatial dimension by 4 pixels and cropped a uniformly random 32×32 patch from the enlarged image [18], while for the ImageNet family data-sets we cropped a uniformly random 224×224 patch from the original image.

At test time, we used the original image for the CIFAR data-sets and a center crop for the ImageNet family data-sets. The prediction of each model is the mean of the model's output on the original image and on a horizontally flipped version of it. Note that a horizontal flip occurs naturally in every data-set we use for evaluation, and therefore it is not associated with any of the TransNet model's heads that we evaluate.
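The learning-rate schedule and the test-time flip averaging described in Appendix A can be sketched in a few lines of Python. This is a minimal illustration under our own assumptions, not the authors' code: epochs are counted from 0, an image is a plain 2D list standing in for one channel of a tensor, and `model` is any hypothetical callable that returns a list of class scores.

```python
# Minimal sketch of two details from Appendix A (not the authors' code).
# Assumptions: epochs are counted from 0, an image is a plain 2D list standing in
# for one channel of a tensor, and `model` is any callable returning class scores.
from typing import Callable, List, Sequence

Image = List[List[float]]


def learning_rate(epoch: int, imagenet200: bool = False) -> float:
    """Step schedule: start at 0.1 and divide by 10 at every milestone reached.

    Milestones are epochs 150 and 225 (of 300) for most data-sets,
    and epochs 40 and 80 (of 120) for ImageNet-200.
    """
    milestones = (40, 80) if imagenet200 else (150, 225)
    lr = 0.1
    for milestone in milestones:
        if epoch >= milestone:
            lr /= 10
    return lr


def hflip(image: Image) -> Image:
    """Horizontal flip: mirror every row from left to right."""
    return [row[::-1] for row in image]


def predict_with_flip(model: Callable[[Image], Sequence[float]], image: Image) -> List[float]:
    """Average the model's outputs on the image and on its horizontal flip."""
    original = model(image)
    flipped = model(hflip(image))
    return [(a + b) / 2 for a, b in zip(original, flipped)]


if __name__ == "__main__":
    def toy_model(img: Image) -> List[float]:
        # Hypothetical toy "model": scores are the sums of the first and last columns.
        return [sum(r[0] for r in img), sum(r[-1] for r in img)]

    print(learning_rate(0), learning_rate(150), learning_rate(230))
    print(predict_with_flip(toy_model, [[1.0, 0.0], [2.0, 0.0]]))
```

In a real training pipeline the same step schedule would normally come from the framework's own scheduler; the explicit loop is kept here only to make the milestones visible.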
