Preservation of the Global Knowledge by Not-True Self Knowledge Distillation in Federated Learning

Gihun Lee, Yongjin Shin, Minchan Jeong, and Se-Young Yun
KAIST
{opcrisis, yj.shin, mcjeong, yunseyoung}@kaist.ac.kr

arXiv:2106.03097v1 [cs.LG] 6 Jun 2021

Preprint. Under review.

Abstract

In Federated Learning (FL), a strong global model is collaboratively learned by aggregating the clients' locally trained models. Although this removes the need to access clients' data directly, the global model's convergence often suffers from data heterogeneity. This paper suggests that forgetting could be the bottleneck of global convergence. We observe that fitting on a biased local distribution shifts the features on the global distribution and results in forgetting of global knowledge. We consider this phenomenon as an analogy to Continual Learning, which also faces catastrophic forgetting when fitted on a new task distribution. Based on our findings, we hypothesize that mitigating forgetting during local training relieves the data heterogeneity problem. To this end, we propose a simple yet effective framework, Federated Local Self-Distillation (FedLSD), which utilizes the global knowledge on locally available data. By following the global perspective on local data, FedLSD encourages the learned features to preserve global knowledge and to have consistent views across local models, thus improving convergence without compromising data privacy. Under our framework, we further extend FedLSD to FedLS-NTD, which only considers the not-true class signals to compensate for the noisy predictions of the global model. We validate that both FedLSD and FedLS-NTD significantly improve performance on standard FL benchmarks in various setups, especially in extreme data heterogeneity cases.

1   Introduction

Nowadays, massive data is collected from edge devices such as phones, vehicles, and facilities. By decoupling the ability to learn from the need to access private client data directly, Federated Learning (FL) proposes a distributed learning paradigm that enables learning a global model while preserving clients' data privacy [24, 25, 36, 37]. In FL, clients independently train local models using their private local data, and the global server aggregates them into a single global model. In this process, most of the computation is performed by the client devices, and the global server only updates the aggregated model parameters and distributes them to clients [1, 56].
Most FL algorithms are based on FedAvg [36], which aggregates the locally trained model parameters by weighted averaging, proportional to the amount of local data each client holds. While various FL algorithms have since been proposed, they all conduct parameter averaging in some manner [29, 34, 43, 51, 52, 58]. Although this aggregation scheme empirically works well and provides a conceptually ideal framework when all client devices are active and i.i.d. distributed (a.k.a. LocalSGD), the data heterogeneity problem [30, 63] remains one of the major challenges for FL applications to become widespread.
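To make the aggregation step concrete, below is a minimal sketch of FedAvg-style weighted parameter averaging; the function name and the assumption that each client returns a `state_dict` together with its local sample count are ours, not code from the paper.

```python
import torch

def fedavg_aggregate(client_states, client_sizes):
    """Average client state_dicts, weighted by local dataset size (FedAvg-style)."""
    total = float(sum(client_sizes))
    global_state = {}
    for key in client_states[0]:
        # Weighted sum of each client's tensor for this parameter.
        global_state[key] = sum(
            (n / total) * state[key].float()
            for state, n in zip(client_states, client_sizes)
        )
    return global_state
```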
Since each client generates its own data, the data is not identically distributed. More precisely, the local data across clients are drawn from heterogeneous underlying distributions; thereby, locally available data fail to represent the overall global distribution. Despite its inevitable occurrence in many real-world scenarios, the heterogeneity of local data not only makes theoretical analysis difficult

[19, 21, 30, 63], but also degrades many FL algorithms' performance [16, 28, 29, 33]. By resolving the data heterogeneity problem, learning becomes more robust against partial participation of clients [30, 36], and the communication cost is also reduced through faster convergence [51, 58].

Figure 1: An overview of forgetting in learning scenarios: (a) Continual Learning, (b) Federated Learning (FedAvg), (c) Federated Learning (FedLSD).
Interestingly, Continual Learning [44, 49] faces a similar challenge. In CL, a learner model is continuously updated on a sequence of tasks, with the objective of remembering all tasks. Unfortunately, since the data distribution of each task is heterogeneous, learning on different tasks often results in catastrophic forgetting [23, 32, 35, 40], whereby fitting on a new task interferes with the representations learned on the previous task. As a consequence, the model parameters drift away from the area where the past knowledge is desirably preserved (Figure 1a).
Our first conjecture is that such forgetting also exists in FL scenarios, as illustrated in Figure 1b. The global model is updated on randomly sampled local clients at each round, but the local datasets have different distributions from the previous round. We focus on this analogy to the FL scenario in the sense of catastrophic forgetting. To empirically verify this analogy, we examine the prediction of the global model before and after the communication round. We find that the prediction consistency between the two global models degrades as the data heterogeneity increases. We then analyze the representation change after local training and find that fitting on the biased local distribution causes the trained model's features to be shifted towards the locally available region, resulting in forgetting of global knowledge.
Based on our findings, we hypothesize that mitigating forgetting of global knowledge can relieve the data heterogeneity problem in FL. Inspired by prior works in CL [15, 32], which distill predictions on old tasks while learning new tasks, we propose a simple yet effective FL framework, Federated Local Self Distillation (FedLSD). FedLSD utilizes the global model's prediction on local data to preserve global knowledge, which the local distribution cannot represent (Figure 1c). More specifically, the local model self-distills the distributed global model's prediction on locally available data. This approach resembles distillation-based knowledge preservation in Continual Learning [15, 32], where new task data is evaluated on the old task and its prediction is maintained.

However, merely following the global prediction may yield sub-optimal performance in FL. Unlike CL, where each task is sufficiently learned before observing a new task, FL updates with only a few local steps due to its distributed setup. This induces the global model's prediction to be noisy (the predicted label differs from the true label), especially in the early communication rounds. To compensate for this, we further extend FedLSD to Federated Local Self Not-True Distillation (FedLS-NTD), which follows not-true class signals only. We demonstrate the preservation of global knowledge by both algorithms and their efficacy on standard FL benchmarks with various setups.
Our contributions can be summarized as follows:

      • Based on the analogy between FL scenarios and CL, we emphasize the forgetting of global knowledge in FL and suggest that feature shifting induced by fitting on biased local distributions causes the forgetting (Section 2).
      • We propose FedLSD to preserve global knowledge. To compensate for the noisy predictions of the global model in the early phase, we further extend FedLSD to FedLS-NTD. Unlike previous distillation-based approaches in FL, our methods do not require any additional local information or auxiliary data (Section 3).
      • We validate the efficacy of FedLSD and FedLS-NTD in various setups. Using Centered Kernel Alignment (CKA) and T-SNE visualization, we demonstrate that both algorithms enjoy remarkable feature consistency during local training (Section 4).

2   Forgetting global knowledge in FL

To confirm our conjecture on forgetting, we first consider how the prediction of the global model changes. If data heterogeneity induces forgetting, the global model's prediction after aggregation may be less consistent compared to before it is distributed. To examine this, we measure the Intersection over Union (IoU) of mutually correct samples between the Distributed Global (DG) model and the Aggregated Global (AG) model as

$$\mathrm{IoU}_{\mathrm{correct}} = \frac{|DG_{\mathrm{correct}} \cap AG_{\mathrm{correct}}|}{|DG_{\mathrm{correct}} \cup AG_{\mathrm{correct}}|}.$$

The result is summarized in Table 1. As expected, the prediction consistency degrades as data heterogeneity increases (fewer classes per client), along with growing variance. This implies that forgetting also occurs in FL and is closely related to data heterogeneity.^1
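For clarity, a minimal sketch of how this consistency metric could be computed is given below; the model and loader names are hypothetical, not code from the paper.

```python
import torch

@torch.no_grad()
def iou_correct(model_dg, model_ag, loader, device="cpu"):
    """IoU of the correctly-classified sample sets of two global models."""
    dg_hits, ag_hits = [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        dg_hits.append(model_dg(x).argmax(dim=1) == y)  # DG_correct membership
        ag_hits.append(model_ag(x).argmax(dim=1) == y)  # AG_correct membership
    dg, ag = torch.cat(dg_hits), torch.cat(ag_hits)
    intersection = (dg & ag).sum().item()
    union = (dg | ag).sum().item()
    return intersection / max(union, 1)
```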

Table 1: Prediction consistency of the CIFARCNN model by varying heterogeneity.

                          max class / client (CIFAR10)
                          2            3            5            8            10 (iid)
(DG vs AG) IoU_correct    0.270±0.092  0.326±0.095  0.375±0.076  0.405±0.074  0.475±0.014

                          max class / client (CIFAR100)
                          5            10           20           40           100 (iid)
(DG vs AG) IoU_correct    0.257±0.081  0.314±0.069  0.334±0.070  0.357±0.062  0.418±0.037

To trace the forgetting of knowledge in the global model, we now analyze how the representation on the global distribution changes during local training. To this end, we design a straightforward experiment that shows the change of features on the unit hypersphere. More specifically, we modify the network architecture to map CIFAR-10 input data to 2-dimensional vectors and normalize them to lie on the unit hypersphere $S^1 = \{x \in \mathbb{R}^2 : \|x\|_2 = 1\}$. We then estimate their probability density function. The global model is learned for 100 communication rounds on homogeneous (i.i.d.) locals and then distributed to heterogeneous locals with different local distributions. The result is in Figure 2.
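A small sketch of this projection step, assuming `encoder` denotes the network truncated to its 2-dimensional feature head; the resulting angles can then be fed to any Gaussian KDE routine.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def feature_angles(encoder, loader, device="cpu"):
    """Project 2-D features onto the unit circle S^1 and return their angles."""
    angles = []
    for x, _ in loader:
        z = encoder(x.to(device))          # (batch, 2) raw feature vectors
        z = F.normalize(z, dim=1)          # map onto S^1 so that ||z||_2 = 1
        angles.append(torch.atan2(z[:, 1], z[:, 0]))  # arctan(y, x)
    return torch.cat(angles)
```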

Figure 2: Features of CIFAR-10 test samples on $S^1$ for (a) FedAvg and (b) FedLSD (ours). We plot the feature distribution with Gaussian kernel density estimation (KDE) in $\mathbb{R}^2$ and $\arctan(y, x)$ for each point $(x, y) \in S^1$. The distributed global model (first column) is trained on heterogeneous locals (middle 3 columns) and aggregated by parameter averaging (last column). The values in the parentheses are test accuracy.

^1 We distribute CIFAR-10 data to 100 clients and sample 10 clients per round, where each device has the same number of samples but at most two classes. The experiment is conducted for 300 communication rounds on a CIFARCNN (2-conv + 3-fc layers) model, the same architecture used in FedAvg [36]. More details are in the Appendix.

As expected, the feature vectors from the global model learned with homogeneous locals are uniformly distributed over the hypersphere $S^1$. However, after local training, the entire feature region is shifted towards the locally available classes, as in Figure 2a. Thus, when the local models are aggregated, the global model only reflects the feature regions that were dominant in the local models, and the knowledge of unseen classes is forgotten, significantly degrading the global test accuracy.

Our analysis shows that feature shifting on the global distribution, which the local set fails to represent, causes the forgetting of the global knowledge. Based on our findings, we hypothesize that preserving the knowledge on the global distribution prevents forgetting and relieves the data heterogeneity problem in FL.
3   Federated Local Self Distillation
In this section, we first introduce the proposed Federated Local Self Distillation (FedLSD) and its key features. We then further extend FedLSD to Federated Local Self Not-True Distillation (FedLS-NTD), which improves the performance further.

Algorithm 1 Federated Local Self-Distillation (FedLSD)

Input: initial weight $w^0$, total rounds $T$, local epochs $E$, dataset $\mathcal{X} \times \mathcal{Y}$, sampled client set $K^{(t)}$ in round $t$, learning rate $\alpha$
Initialize $w^0$ for the global server
for each communication round $t = 1, \cdots, T$ do
   Server samples clients $K^{(t)}$ and broadcasts $\tilde{w}^t \leftarrow w^t$ to all clients
   for each client $k \in K^{(t)}$ in parallel do
      Construct prediction pair $\mathcal{X}_k \times \mathcal{Y}_k$
      for local epochs $i = 1, \cdots, E$ do
         for batches $j = 1, \cdots, B$ do
            $\tilde{w}_k^t \leftarrow \tilde{w}_k^t - \alpha \nabla_w \mathcal{L}_{\mathrm{FedLSD}}\big([\mathcal{X}_k \times \mathcal{Y}_k]_j\big)(\tilde{w}_k^t)$   [Equation 1 or Equation 2]
         end for
      end for
      Upload $\tilde{w}_k^t$ to the server
   end for
   Server aggregation: $w^{t+1} \leftarrow \frac{1}{|K^{(t)}|} \sum_{k \in K^{(t)}} \tilde{w}_k^t$
end for
Server output: $w^T$

3.1 Local self distillation
The core idea of FedLSD is to preserve the global view on the local data during local training. The shifting of the learned features occurs by fitting on the biased local distribution, which cannot represent the global distribution. Therefore, preserving information about where local data should be located in the global distribution prevents the locally learned features from deviating from the feature space of the global model. To this end, FedLSD conducts local-side self-distillation: the prediction of the distributed global model on the local data is distilled during local training, using the following loss function $\mathcal{L}_{\mathrm{FedLSD}}$, a linear combination of the cross-entropy $\mathcal{L}_{CE}$ and the distillation loss $\mathcal{L}_{LSD}$:
$$\mathcal{L}_{\mathrm{FedLSD}} = (1 - \beta) \cdot \mathcal{L}_{CE}(q, p_y) + \beta \cdot \mathcal{L}_{LSD}(q_\tau, q_\tau^g) \quad (0 < \beta < 1), \qquad (1)$$

where $q$ and $q^g$ are the softmax probabilities of the local client model and the global model at the last round, respectively, and $p_y$ is the one-hot label. The distillation loss $\mathcal{L}_{LSD}$ is defined as follows:

$$q_\tau(c) = \frac{\exp(z_c / \tau)}{\sum_{i=1}^{C} \exp(z_i / \tau)}, \quad q_\tau^g(c) = \frac{\exp(z_c^g / \tau)}{\sum_{i=1}^{C} \exp(z_i^g / \tau)}, \quad \mathcal{L}_{LSD}(q_\tau, q_\tau^g) = -\sum_{c=1}^{C} q_\tau^g(c) \log \frac{q_\tau(c)}{q_\tau^g(c)}.$$

Here, $q(c)$ stands for the prediction probability on class $c \in [C]$ as a softmax function of logits $z$, $C$ is the number of classes, and $p_y(c)$ denotes the one-hot true label. We omit data $x$ and model weight $w$ and write $q(x, w)(c)$ and $p_y(x)(c)$ as $q(c)$ and $p_y(c)$ for simplicity. For the distillation loss, the logits are divided by $\tau$ to soften the prediction probability. Note that $\mathcal{L}_{LSD}$ is the KL-divergence loss between the local prediction and the global prediction.
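A minimal PyTorch sketch of this objective (the function layout is our own; β and τ follow the values reported in Appendix A.4):

```python
import torch
import torch.nn.functional as F

def fedlsd_loss(logits, global_logits, targets, beta=0.3, tau=3.0):
    """L_FedLSD = (1 - beta) * CE(q, p_y) + beta * KL(q^g_tau || q_tau), Eq. (1)."""
    ce = F.cross_entropy(logits, targets)
    log_q = F.log_softmax(logits / tau, dim=1)             # local soft prediction
    q_g = F.softmax(global_logits.detach() / tau, dim=1)   # frozen global soft target
    # kl_div(input=log-probs, target=probs) computes KL(target || local).
    lsd = F.kl_div(log_q, q_g, reduction="batchmean")
    return (1.0 - beta) * ce + beta * lsd
```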

FedLSD resembles memory-based CL algorithms [15, 32, 42], which exploit a distillation loss to avoid forgetting by preserving the knowledge on old tasks. Although CL scenarios generally assume a small amount of old-task data to replay (episodic memory), such an assumption is not allowed in FL due to the data privacy issue. Instead, the global model's predictions can be utilized as a reference for the previous data distribution, inducing an effect similar to that of using episodic memory in CL.

We analyze the effect of $\mathcal{L}_{LSD}$ in terms of weight divergence. Proposition 1 implies that increasing $\beta$ in the local objective reduces the weight divergence $\|w^t - w^t_{(GD)}\|$. The corresponding experiment is shown in Figure 3, where the local models are trained using $\mathcal{L}_{LSD}$ ($\tau = 1$) and varying $\beta$.

Figure 3: Weight divergence of CIFARCNN on the CIFAR-10 dataset (max class/client = 2).

Proposition 1. Consider the FedAvg [36] setting with uniformly weighted $K = |K^{(t)}|$ local clients holding class distributions $\mathcal{P}_k^{(t)}$. We train the global and local clients by $\mathcal{L}_{\mathrm{FedLSD}}$ ($\tau = 1$) with convex class landscape $w \mapsto E_{D_c}[\mathcal{L}_{\mathrm{FedLSD}}(\cdot, w)]$, where $D_c$ is the data distribution of the $c$-th class. Assume that $w \mapsto \frac{\partial z}{\partial w}(\cdot, w)$ is $\lambda_{J_c}$-Lipschitz / $B$-bounded and $w \mapsto q(\cdot, w)$ is $\lambda_{q_c}$-Lipschitz with respect to $\|\cdot\|_{L^2(D_c)}$. Then, if $0 < \eta < 2 \min_c (\sqrt{2}\lambda_{J_c} + B\lambda_{q_c})^{-1}$, the weight divergence between the averaged local $w^{t+1}$ and global $w^{t+1}_{(GD)}$ after one round is bounded by the following inequality:

$$\|w^{t+1} - w^{t+1}_{(GD)}\| \le \frac{\eta B}{K} \sum_{k \in K^{(t)}} \chi(\mathcal{P} \,\|\, \mathcal{P}_k^{(t)}) \sum_{i=0}^{E-1} \Big\| q(\cdot, \tilde{w}_k^{t,i}) - \big(\beta q^g + (1 - \beta) p_y\big) \Big\|_{L^2(D_k^{(t)})},$$

where $D_k^{(t)}$ is the data on the $k$-th local and $q^g = q(x, w^t)$ is the global prediction at the last round.

3.2 Local self not-true distillation

Although FedLSD reduces forgetting and effectively stabilizes weight divergence, merely following the global prediction is sub-optimal for preventing forgetting. First, unlike CL, where each task is sufficiently learned before observing a new task, FL updates with only a few local steps due to its distributed setup. This induces the global model's prediction to be noisy (the predicted label differs from the true label), especially in the early communication rounds. Second, to prevent forgetting, information not included in the local distribution should have a higher priority. When the teacher prediction is noisy, however, the KD loss cannot sufficiently enjoy the benefits of temperature softening. As discussed in [22], using a small temperature rather performs better when the teacher is noisy, neglecting the negative noise effect.

To tackle both the noisy global predictions and the local information priority issue, we further extend FedLSD by distilling the global view on the relationships among not-true classes only, as illustrated in Figure 4. By decoupling the signals for the true class and the not-true classes, the feature shifting problem can be relieved without compromising learning on the local data. More specifically, the logit belonging to the true class is not used when obtaining the softmax prediction probabilities for both the local and global models, and the distillation loss closes the gap between them.

Figure 4: Not-True Distillation

The proposed Federated Local Self Not-True Distillation (FedLS-NTD) loss is formed similarly to the FedLSD loss:

$$\mathcal{L}_{\mathrm{FedLS\text{-}NTD}} = (1 - \beta) \cdot \mathcal{L}_{CE}(q, p_y) + \beta \cdot \mathcal{L}_{NTD}(\tilde{q}_\tau, \tilde{q}_\tau^g) \quad (0 < \beta < 1), \qquad (2)$$

where $\tilde{q}_\tau$ and $\tilde{q}_\tau^g$ are the softmax of the not-true class logits and $\mathcal{L}_{NTD}$ is the KL-divergence loss between them:

$$\tilde{q}_\tau(c) = \frac{\exp(z_c / \tau)}{\sum_{i=1, i \neq y}^{C} \exp(z_i / \tau)}, \quad \tilde{q}_\tau^g(c) = \frac{\exp(z_c^g / \tau)}{\sum_{i=1, i \neq y}^{C} \exp(z_i^g / \tau)} \quad (\forall c \neq y), \quad \mathcal{L}_{NTD}(\tilde{q}_\tau, \tilde{q}_\tau^g) = -\sum_{c=1, c \neq y}^{C} \tilde{q}_\tau^g(c) \log \frac{\tilde{q}_\tau(c)}{\tilde{q}_\tau^g(c)}.$$
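A sketch of the not-true distillation term, assuming the same batch is fed to the local model and the frozen global model; dropping the true-class logit before the softmax is the key step.

```python
import torch
import torch.nn.functional as F

def not_true_kl(logits, global_logits, targets, tau=1.0):
    """KL between local and global softmax over the C-1 not-true classes."""
    num_classes = logits.size(1)
    true_mask = F.one_hot(targets, num_classes).bool()     # True at the label position
    # Drop the true-class logit, keeping a (batch, C-1) matrix for both models.
    nt_local = logits[~true_mask].view(-1, num_classes - 1)
    nt_global = global_logits.detach()[~true_mask].view(-1, num_classes - 1)
    log_q = F.log_softmax(nt_local / tau, dim=1)
    q_g = F.softmax(nt_global / tau, dim=1)
    return F.kl_div(log_q, q_g, reduction="batchmean")

def fedls_ntd_loss(logits, global_logits, targets, beta=0.3, tau=1.0):
    """L_FedLS-NTD = (1 - beta) * CE + beta * L_NTD, following Equation 2."""
    ce = F.cross_entropy(logits, targets)
    return (1.0 - beta) * ce + beta * not_true_kl(logits, global_logits, targets, tau)
```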

We analyze the advantage of FedLS-NTD over FedLSD using the following loss function:

$$\mathcal{L}_{LSD \to NTD} = (1 - \beta) \cdot \mathcal{L}_{CE}(q, p_y) + \beta \cdot \mathcal{L}(\lambda), \quad \text{where } \mathcal{L}(\lambda) = (1 - \lambda) \cdot \mathcal{L}_{LSD}(q_\tau, q_\tau^g) + \lambda \cdot \mathcal{L}_{NTD}(\tilde{q}_\tau, \tilde{q}_\tau^g).$$

Here, $\lambda$ interpolates the loss function between FedLSD (at $\lambda = 0$) and FedLS-NTD (at $\lambda = 1$), while the other hyperparameters are fixed ($\tau = 1$, $\beta = 0.3$).
To improve the performance of the global model, an aggregated global model should 1) preserve the correct predictions and 2) discard the incorrect predictions of the previous global model. We examine the prediction consistency $\mathrm{IoU}_{\mathrm{correct}}$ and $\mathrm{IoU}_{\mathrm{incorrect}}$ between the Distributed Global (DG) and Aggregated Global (AG) models (Figure 5). Since FedLS-NTD does not follow the true-class signal, its ability to discard incorrect predictions is enhanced while preserving correct predictions on par with FedLSD.

Figure 5: Prediction consistency of CIFARCNN on the CIFAR-10 dataset (max class/client = 2).

4   Experiment

We evaluate FedLSD and FedLS-NTD on various FL scenarios with data heterogeneity. We compare our results against the FedAvg [36] and FedProx [29] baselines and validate how the proposed methods improve FL performance. VggNet-11 [47] and ResNet-8 [13] models are used to test the algorithms. We use three datasets: FEMNIST [4], CIFAR-10, and CIFAR-100 [27]. For the CIFAR datasets, we introduce label distribution heterogeneity, where the clients have the same number of samples but a limited number of classes (2 for CIFAR-10 and 20 for CIFAR-100), as sketched below. The experimental details are in Appendix A.
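The following is a sketch of one shard-based partition that realizes this constraint, in the spirit of [36]; the exact partitioning code used in the paper may differ.

```python
import numpy as np

def partition_by_label(labels, num_clients=100, max_classes=2, seed=0):
    """Assign each client `max_classes` label-sorted shards of equal size."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    shards_per_client = max_classes
    num_shards = num_clients * shards_per_client
    # Sort indices by label, then cut into equal label-sorted shards.
    order = np.argsort(labels, kind="stable")
    shards = np.array_split(order, num_shards)
    shard_ids = rng.permutation(num_shards)
    # Each client receives `shards_per_client` random shards, hence samples
    # from at most `max_classes` classes when shards align with class boundaries.
    return [np.concatenate([shards[s] for s in
                            shard_ids[i * shards_per_client:(i + 1) * shards_per_client]])
            for i in range(num_clients)]
```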

4.1 Performances on data-heterogeneous FL scenarios

We compare the algorithms on different datasets to validate the efficacy of the proposed methods. The results are in Table 2. In the experiment, FedLSD and FedLS-NTD consistently outperform the baseline algorithms. We also investigate how they perform when combined with FedProx. Although FedProx does not always improve on the FedAvg baseline in our experiments, the combination with FedLSD or FedLS-NTD brings additional performance improvement.

Figure 6: Learning curves for the Res8 model corresponding to the results in Table 2.

Table 2: Test accuracy after target rounds and the number of rounds to reach the target accuracy.

                                           FEMNIST                CIFAR10 (δ = 2)        CIFAR100 (δ = 20)
Model   Algorithm                    test acc    round (80%)  test acc    round (60%)  test acc    round (40%)
Res8    FedAvg [36]                  86.0±0.69       63       54.2±5.62      466       50.4±0.90      154
        FedProx [29]                 85.9±0.30       68       55.8±3.23      406       51.5±0.52      139
        FedLSD (ours)                86.9±0.23       58       58.0±1.91      425       52.1±0.24      138
        FedLS-NTD (ours)             86.2±0.53       71       63.5±0.95      299       52.9±0.14      127
        FedProx [29] + FedLSD        85.4±0.06       94       60.5±2.52      397       52.1±0.11      134
        FedProx [29] + FedLS-NTD     85.5±0.09       89       64.8±1.15      265       53.3±0.07      128
VGG11   FedAvg [36]                  84.5±0.37      118       67.3±1.54      302       42.5±0.52      240
        FedProx [29]                 85.1±0.16      132       61.1±0.42      254       42.3±0.04      220
        FedLSD (ours)                85.6±0.16      129       68.6±1.15      246       46.3±0.16      193
        FedLS-NTD (ours)             84.8±0.70      129       69.6±2.01      190       45.4±0.23      193
        FedProx [29] + FedLSD        84.6±0.18      129       67.8±1.29      227       46.4±0.03      198
        FedProx [29] + FedLS-NTD     84.8±0.21      137       68.1±1.17      180       45.1±0.08      198

δ indicates the max class / client.

While FedLS-NTD generally performs better than FedLSD, the performance gap varies depending on the dataset and model architecture. Since FedLS-NTD learns not-true class relationships only, FedLSD may perform better when the prediction noise of the global model is not significant. We suggest that FedLS-NTD is more effective when the intensity of data heterogeneity is high, in which case following the global model carries a higher risk of a biased distribution. For example, in the FEMNIST dataset, almost all classes are included in the local data, but handwritten data belonging to the same class have different forms across clients. In this case, FedLSD and FedLS-NTD show slow convergence in the early phase while improving the final test accuracy.
We further investigate the performance by varying the data heterogeneity (Table 3) and the sampling ratio (Table 4). The results show that the proposed algorithms significantly outperform the baselines in data heterogeneity cases. Interestingly, while the baseline algorithms fail to learn and diverge when only one class is available locally, FedLSD and FedLS-NTD are able to learn even under such extreme heterogeneity. Although FL scenarios where only positive labels are accessible locally are not our main scope, we argue that this result verifies that our proposed algorithms effectively encourage the preservation of global knowledge during local training.
Table 3: Res8 on CIFAR10: Data heterogeneity (max class / client)

Algorithm        1        2      3      5     10 (iid)
FedAvg        failure    54.2   64.6   76.9    86.8
FedProx       failure    55.8   66.6   77.8    86.5
FedLSD         22.9      58.0   70.6   80.3    86.6
FedLS-NTD      21.3      63.5   74.9   81.7    86.0

Table 4: Res8 on CIFAR10: Sampling ratio (sampled clients / N)

Algorithm      0.05    0.1    0.3    0.5    1.0
FedAvg         44.0    54.2   65.5   67.8   68.7
FedProx        49.7    55.8   65.0   67.7   67.5
FedLSD         52.9    58.0   63.3   65.1   67.0
FedLS-NTD      58.1    63.5   69.4   71.0   72.7

4.2 Preservation of global knowledge

To empirically measure how well the global knowledge is preserved after local training, we examine the trained local models' test accuracy on the global distribution (Figure 7). If fitting on the local distribution preserved global knowledge, the trained local models would generalize well on the global distribution. However, since forgetting occurs during local training, the baseline algorithms show only around 20% test accuracy on the global distribution. Since at most two classes are available locally for each client, this indicates that the trained model forgets the knowledge of the other classes, which have not been observed during local training. On the other hand, FedLSD and FedLS-NTD show significantly higher local test accuracy while improving the final global accuracy. This implies that FedLSD and FedLS-NTD preserve global distribution information during local training, even if the local distribution cannot explicitly represent it.

Figure 7: A comparison of final global test accuracy and local test accuracy corresponding to Table 2. The local test accuracy is averaged over the sampled clients of the last five rounds.

4.3 Analysis on feature similarity

CKA analysis To demonstrate that FedLSD and FedLS-NTD enjoy feature consistency, we compare the CKA [26] between the global model and the locally trained models. More specifically, we measure the CKA of the Distributed Global (DG), Local (L), and Aggregated Global (AG) models in four combinations: (a) DG vs L, (b) L vs L, (c) L vs AG, and (d) DG vs AG, as plotted in Figure 8. As illustrated, the feature region is dramatically shifted during local training of the FedAvg algorithm. In FedProx, the drift of the local model is relieved relative to FedAvg, but it still suffers from inconsistency between the DG and AG features. On the other hand, FedLSD and FedLS-NTD keep the feature space from shifting during local training; hence the DG and AG models share the feature region. Notably, although the feature similarity of FedLS-NTD is lower in the early phase of learning, it reaches and even exceeds the feature similarity of FedLSD by the end of learning.

Figure 8: CKA [26] of the Res8 model on the CIFAR-10 dataset (max class/client = 2). The values in the parentheses indicate the test accuracy of the final global model.
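For reference, below is a minimal sketch of linear CKA as defined in [26]; it can be applied to the (n_samples × dim) feature matrices of two models on the same inputs.

```python
import torch

def linear_cka(X, Y):
    """Linear CKA between two (n_samples, dim) feature matrices [26]."""
    X = X - X.mean(dim=0, keepdim=True)  # center each feature dimension
    Y = Y - Y.mean(dim=0, keepdim=True)
    cross = torch.norm(Y.t() @ X, p="fro") ** 2     # ||Y^T X||_F^2
    norm_x = torch.norm(X.t() @ X, p="fro")
    norm_y = torch.norm(Y.t() @ Y, p="fro")
    return (cross / (norm_x * norm_y)).item()
```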

T-SNE visualization Contrary to FedAvg, our FedLSD effectively binds the features even when
the local distribution is heavily biased. In Figure 9a, the test set features from local models are
shifted when FedAvg is used (blue). However, the local features learned by FedLSD (red) share same
space with the global model’s features (green). When viewed by classes Figure 9a, these features are
well-aligned by their semantics. Note that the selected locals are identical for both algorithms.

            (a) colored by different algorithms                         (b) colored by classes
Figure 9: T-SNE visualization CIFAR-10 test set features on ResNet-8 model. The global model is learned for
100 communication round using FedAvg, and distributed to 10 locals to conduct local training. The T-SNE is
performed all together for the test sample features of global and 10 local models.

5   Related work

Federated Learning (FL) Federated Learning was proposed to update a global model by asynchronous SGD [9] while the local data is kept on users' devices [24, 25]. A standard algorithm in FL is FedAvg [36], which aggregates trained local models by averaging their parameters. Although its effectiveness has been largely discussed in i.i.d. settings [18, 48, 54, 55, 59], many algorithms become sub-optimal when the distributed data is heterogeneous [30, 63]. A wide range of FedAvg variants has been proposed to overcome this problem. One line of work is local-side modification. For example, FedProx [29] adds a proximal term to the local objective, and SCAFFOLD [19] uses control variates to correct the client drift. In FedNova [52], the locally trained models are normalized to tackle local objective inconsistency. Likewise, server-side modification is also considered in various works. For instance, AFL [38] optimizes a mixture of local distributions in a centralized manner. PFNM [60] and FedMA [51] match the layer-wise order of individual neurons before aggregation. Our work can be considered a local-side modification, as it focuses on the local objective.

Continual Learning (CL) Continual Learning [44, 49] is a learning paradigm that learns from a sequence of tasks instead of training on the whole dataset at once. In CL, the main challenge is to avoid catastrophic forgetting [11], whereby training on a new task interferes with the learned representations of the previous tasks. Prior works in CL largely focus on overcoming this catastrophic forgetting [5, 6, 15]. Existing methods can be mainly classified into two categories. First, parameter-based approaches [2, 6, 23, 39, 62] regularize the change of parameters so that task parameters do not interfere with each other. Second, memory-based approaches [3, 5, 15, 32, 42] maintain the knowledge of previous tasks by replaying a small episodic memory of data. In our work, we consider the latter to relieve forgetting in FL scenarios. Although the episodic data of previous tasks (local data from previous rounds) cannot be stored due to the data privacy issue, we regard following the distributed global model's prediction as having a similar effect to replaying memory.

Knowledge Distillation in FL Knowledge Distillation (KD) was proposed to transfer knowledge from a trained model to another [14]. KD has shown impressive success through its numerous variants [10, 50, 57, 64, 66]. In FL, most KD-based approaches aim to make the global model learn the ensemble of locally trained models [7, 12, 17, 34, 65]. For instance, FD [17] collects averaged local logits per label and distributes them to regularize local training. In FedDF [34], the averaged logits of local models are used for distillation at the server. FedAUX [45] extends the idea of FD by conducting self-supervised pre-training. FedBE [7] introduces a Bayesian view on sampling global models for aggregation. Another line of work uses KD to reduce the communication cost [12, 46, 53, 65]. In [12], one-round FL using the ensemble of local models is proposed. Instead of distilling from direct ensembles of clients, DOSFL [65] applies dataset distillation [53]. These previous works require either additional local information [17, 45, 65] or auxiliary data [7, 12, 34, 46]. However, as discussed in [29], globally sharing even partial local information violates privacy, and globally shared auxiliary data must be cautiously generated or collected. In contrast, our approach does not require any local information or auxiliary data.

6   Conclusion

In this paper, we suggest that the forgetting phenomenon occurs in Federated Learning scenarios, as in Continual Learning. By analyzing the features after local training, we find that fitting on heterogeneous local distributions shifts the features on the global distribution, thereby causing the forgetting of global knowledge. To relieve this problem, we propose FedLSD, which maintains the global view on local data. We further extend it to FedLS-NTD, which focuses on the not-true class relationships only. In the experiments, we validate the efficacy of the proposed algorithms and demonstrate their effect on global knowledge preservation and feature consistency.

7   Broader Impact

We believe that Federated Learning is an important learning paradigm that enables privacy-preserving ML. This paper raises the forgetting issue and introduces methods to relieve it without compromising data privacy. The insight behind this work may inspire new research. However, the proposed method maintains the knowledge outside the local distribution through the global model. This implies that if the global model is biased, the trained local model is more prone to have a similar tendency. This should be considered by ML practitioners.
References
 [1] Mohammed Aledhari, Rehma Razzak, Reza M Parizi, and Fahad Saeed. Federated learning: A
     survey on enabling technologies, protocols, and applications. IEEE Access, 8:140699–140725,
     2020.
 [2] Rahaf Aljundi, Francesca Babiloni, Mohamed Elhoseiny, Marcus Rohrbach, and Tinne Tuyte-
     laars. Memory aware synapses: Learning what (not) to forget. In Proceedings of the European
     Conference on Computer Vision (ECCV), pages 139–154, 2018.
 [3] Rahaf Aljundi, Punarjay Chakravarty, and Tinne Tuytelaars. Expert gate: Lifelong learning
     with a network of experts. In Proceedings of the IEEE Conference on Computer Vision and
     Pattern Recognition, pages 3366–3375, 2017.
 [4] Sebastian Caldas, Sai Meher Karthik Duddu, Peter Wu, Tian Li, Jakub Konečnỳ, H Brendan
     McMahan, Virginia Smith, and Ameet Talwalkar. Leaf: A benchmark for federated settings.
     arXiv preprint arXiv:1812.01097, 2018.
 [5] Arslan Chaudhry, Albert Gordo, Puneet K Dokania, Philip Torr, and David Lopez-Paz. Using
     hindsight to anchor past knowledge in continual learning. arXiv preprint arXiv:2002.08165,
     2020.
 [6] Arslan Chaudhry, Naeemullah Khan, Puneet K Dokania, and Philip HS Torr. Continual learning
     in low-rank orthogonal subspaces. arXiv preprint arXiv:2010.11635, 2020.
 [7] Hong-You Chen and Wei-Lun Chao. Fedbe: Making bayesian model ensemble applicable to
     federated learning, 2021.
 [8] Gregory Cohen, Saeed Afshar, Jonathan Tapson, and André van Schaik. EMNIST: an extension
     of MNIST to handwritten letters. arXiv preprint arXiv:1702.05373, 2017.
 [9] Jeffrey Dean, Greg Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Mark Mao, Marc'Aurelio
     Ranzato, Andrew Senior, Paul Tucker, Ke Yang, Quoc Le, and Andrew Ng. Large scale
     distributed deep networks. In F. Pereira, C. J. C. Burges, L. Bottou, and K. Q. Weinberger,
     editors, Advances in Neural Information Processing Systems, volume 25. Curran Associates,
     Inc., 2012.
[10] Tommaso Furlanello, Zachary Lipton, Michael Tschannen, Laurent Itti, and Anima Anandkumar.
     Born again neural networks. In International Conference on Machine Learning, pages 1607–
     1616. PMLR, 2018.
[11] Ian J Goodfellow, Mehdi Mirza, Da Xiao, Aaron Courville, and Yoshua Bengio. An empirical
     investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint
     arXiv:1312.6211, 2013.
[12] Neel Guha, Ameet Talwalkar, and Virginia Smith. One-shot federated learning. arXiv preprint
     arXiv:1902.11175, 2019.
[13] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image
     recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition,
     pages 770–778, 2016.
[14] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network.
     arXiv preprint arXiv:1503.02531, 2015.
[15] Saihui Hou, Xinyu Pan, Chen Change Loy, Zilei Wang, and Dahua Lin. Learning a unified
     classifier incrementally via rebalancing. In Proceedings of the IEEE/CVF Conference on
     Computer Vision and Pattern Recognition, pages 831–839, 2019.
[16] Tzu-Ming Harry Hsu, Hang Qi, and Matthew Brown. Measuring the effects of non-identical
     data distribution for federated visual classification. arXiv preprint arXiv:1909.06335, 2019.

[17] Eunjeong Jeong, Seungeun Oh, Hyesung Kim, Jihong Park, Mehdi Bennis, and Seong-Lyun
     Kim. Communication-efficient on-device machine learning: Federated distillation and augmen-
     tation under non-iid private data. arXiv preprint arXiv:1811.11479, 2018.
[18] Peng Jiang and Gagan Agrawal. A linear speedup analysis of distributed deep learning with
     sparse and quantized communication. In Proceedings of the 32nd International Conference on
     Neural Information Processing Systems, pages 2530–2541, 2018.
[19] Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian Stich, and
     Ananda Theertha Suresh. Scaffold: Stochastic controlled averaging for federated learning. In
     International Conference on Machine Learning, pages 5132–5143. PMLR, 2020.
[20] Hsieh Kevin, Phanishayee Amar, Mutlu Onur, and Gibbons Phillip B. The non-iid data quagmire
     of decentralized machine learning. In Proceedings of the 37th International Conference on
     Machine Learning, Online, PMLR 119, 2020.
[21] Ahmed Khaled, Konstantin Mishchenko, and Peter Richtárik. Tighter theory for local sgd on
     identical and heterogeneous data. In International Conference on Artificial Intelligence and
     Statistics, pages 4519–4529. PMLR, 2020.
[22] Taehyeon Kim, Jaehoon Oh, NakYil Kim, Sangwook Cho, and Se-Young Yun. Comparing
     kullback-leibler divergence and mean squared error loss in knowledge distillation. arXiv preprint
     arXiv:2105.08919, 2021.
[23] James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins,
     Andrei A Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, et al.
     Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of
     sciences, 114(13):3521–3526, 2017.
[24] Jakub Konečnỳ, H Brendan McMahan, Daniel Ramage, and Peter Richtárik. Federated optimiza-
     tion: Distributed machine learning for on-device intelligence. arXiv preprint arXiv:1610.02527,
     2016.
[25] Jakub Konečnỳ, H Brendan McMahan, Felix X Yu, Peter Richtárik, Ananda Theertha Suresh,
     and Dave Bacon. Federated learning: Strategies for improving communication efficiency. arXiv
     preprint arXiv:1610.05492, 2016.
[26] Simon Kornblith, Mohammad Norouzi, Honglak Lee, and Geoffrey Hinton. Similarity of neural
     network representations revisited. In International Conference on Machine Learning, pages
     3519–3529. PMLR, 2019.
[27] Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. CIFAR-10 and CIFAR-100 datasets. URL:
     https://www.cs.toronto.edu/kriz/cifar.html, 2009.
[28] Tian Li, Anit Kumar Sahu, Ameet Talwalkar, and Virginia Smith. Federated learning: Chal-
     lenges, methods, and future directions. IEEE Signal Processing Magazine, 37(3):50–60, 2020.
[29] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia
     Smith. Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127,
     2018.
[30] Xiang Li, Kaixuan Huang, Wenhao Yang, Shusen Wang, and Zhihua Zhang. On the convergence
     of fedavg on non-iid data. arXiv preprint arXiv:1907.02189, 2019.
[31] Yiying Li, Wei Zhou, Huaimin Wang, Haibo Mi, and Timothy M Hospedales. Fedh2l: Federated
     learning with model and statistical heterogeneity. arXiv preprint arXiv:2101.11296, 2021.
[32] Zhizhong Li and Derek Hoiem. Learning without forgetting. IEEE transactions on pattern
     analysis and machine intelligence, 40(12):2935–2947, 2017.
[33] Wei Yang Bryan Lim, Nguyen Cong Luong, Dinh Thai Hoang, Yutao Jiao, Ying-Chang Liang,
     Qiang Yang, Dusit Niyato, and Chunyan Miao. Federated learning in mobile edge networks: A
     comprehensive survey. IEEE Communications Surveys & Tutorials, 22(3):2031–2063, 2020.
[34] Tao Lin, Lingjing Kong, Sebastian U Stich, and Martin Jaggi. Ensemble distillation for robust
     model fusion in federated learning. arXiv preprint arXiv:2006.07242, 2020.
[35] Michael McCloskey and Neal J Cohen. Catastrophic interference in connectionist networks:
     The sequential learning problem. In Psychology of learning and motivation, volume 24, pages
     109–165. Elsevier, 1989.

[36] Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas.
     Communication-efficient learning of deep networks from decentralized data. In Artificial
     Intelligence and Statistics, pages 1273–1282. PMLR, 2017.
[37] Brendan McMahan and Daniel Ramage. Collaborative machine learning without centralized
     training data. https://research.googleblog.com/2017/04/federated-learning-collaborative.html,
     2017.
[38] Mehryar Mohri, Gary Sivek, and Ananda Theertha Suresh. Agnostic federated learning. In
     International Conference on Machine Learning, pages 4615–4625. PMLR, 2019.
[39] Cuong V Nguyen, Yingzhen Li, Thang D Bui, and Richard E Turner. Variational continual
     learning. arXiv preprint arXiv:1710.10628, 2017.
[40] German I Parisi, Ronald Kemker, Jose L Part, Christopher Kanan, and Stefan Wermter. Continual
     lifelong learning with neural networks: A review. Neural Networks, 113:54–71, 2019.
[41] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan,
     Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas
     Kopf, Edward Yang, Zachary DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy,
     Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. Pytorch: An imperative style, high-
     performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-
     Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems 32,
     pages 8024–8035. Curran Associates, Inc., 2019.
[42] Sylvestre-Alvise Rebuffi, Alexander Kolesnikov, Georg Sperl, and Christoph H Lampert. icarl:
     Incremental classifier and representation learning. In Proceedings of the IEEE conference on
     Computer Vision and Pattern Recognition, pages 2001–2010, 2017.
[43] Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečnỳ,
     Sanjiv Kumar, and H Brendan McMahan. Adaptive federated optimization. arXiv preprint
     arXiv:2003.00295, 2020.
[44] Mark B Ring. Child: A first step towards continual learning. In Learning to learn, pages
     261–292. Springer, 1998.
[45] Felix Sattler, Tim Korjakow, Roman Rischke, and Wojciech Samek. Fedaux: Leveraging
     unlabeled auxiliary data in federated learning. arXiv preprint arXiv:2102.02514, 2021.
[46] Felix Sattler, Arturo Marban, Roman Rischke, and Wojciech Samek. Communication-efficient
     federated distillation. arXiv preprint arXiv:2012.00632, 2020.
[47] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale
     image recognition. arXiv preprint arXiv:1409.1556, 2014.
[48] Sebastian U Stich. Local sgd converges fast and communicates little.            arXiv preprint
     arXiv:1805.09767, 2018.
[49] Sebastian Thrun. Lifelong learning algorithms. In Learning to learn, pages 181–209. Springer,
     1998.
[50] Yonglong Tian, Dilip Krishnan, and Phillip Isola. Contrastive representation distillation. arXiv
     preprint arXiv:1910.10699, 2019.
[51] Hongyi Wang, Mikhail Yurochkin, Yuekai Sun, Dimitris Papailiopoulos, and Yasaman Khazaeni.
     Federated learning with matched averaging. arXiv preprint arXiv:2002.06440, 2020.
[52] Jianyu Wang, Qinghua Liu, Hao Liang, Gauri Joshi, and H Vincent Poor. Tackling the
     objective inconsistency problem in heterogeneous federated optimization. arXiv preprint
     arXiv:2007.07481, 2020.
[53] Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba, and Alexei A Efros. Dataset distillation.
     arXiv preprint arXiv:1811.10959, 2018.
[54] Blake Woodworth, Kumar Kshitij Patel, Sebastian Stich, Zhen Dai, Brian Bullins, Brendan
     Mcmahan, Ohad Shamir, and Nathan Srebro. Is local sgd better than minibatch sgd? In
     International Conference on Machine Learning, pages 10334–10343. PMLR, 2020.
[55] Blake Woodworth, Jialei Wang, Adam Smith, Brendan McMahan, and Nathan Srebro. Graph
     oracle models, lower bounds, and gaps for parallel stochastic optimization. arXiv preprint
     arXiv:1805.10222, 2018.

[56] Qiang Yang, Yang Liu, Tianjian Chen, and Yongxin Tong. Federated machine learning: Concept
     and applications. ACM Transactions on Intelligent Systems and Technology (TIST), 10(2):1–19,
     2019.
[57] Junho Yim, Donggyu Joo, Jihoon Bae, and Junmo Kim. A gift from knowledge distillation:
     Fast optimization, network minimization and transfer learning. In Proceedings of the IEEE
     Conference on Computer Vision and Pattern Recognition, pages 4133–4141, 2017.
[58] Tehrim Yoon, Sumin Shin, Sung Ju Hwang, and Eunho Yang. Fedmix: Approximation of
     mixup under mean augmented federated learning. In International Conference on Learning
     Representations, 2021.
[59] Hao Yu, Rong Jin, and Sen Yang. On the linear speedup analysis of communication efficient
     momentum sgd for distributed non-convex optimization. In International Conference on
     Machine Learning, pages 7184–7193. PMLR, 2019.
[60] Mikhail Yurochkin, Mayank Agarwal, Soumya Ghosh, Kristjan Greenewald, Nghia Hoang,
     and Yasaman Khazaeni. Bayesian nonparametric federated learning of neural networks. In
     International Conference on Machine Learning, pages 7252–7261. PMLR, 2019.
[62] Friedemann Zenke, Ben Poole, and Surya Ganguli. Continual learning through synaptic
     intelligence. In International Conference on Machine Learning, pages 3987–3995. PMLR,
     2017.
[63] Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, and Vikas Chandra. Federated
     learning with non-iid data. arXiv preprint arXiv:1806.00582, 2018.
[64] Helong Zhou, Liangchen Song, Jiajie Chen, Ye Zhou, Guoli Wang, Junsong Yuan, and Qian
     Zhang. Rethinking soft labels for knowledge distillation: A bias-variance tradeoff perspective.
     arXiv preprint arXiv:2102.00650, 2021.
[65] Yanlin Zhou, George Pu, Xiyao Ma, Xiaolin Li, and Dapeng Wu. Distilled one-shot federated
     learning. arXiv preprint arXiv:2009.07999, 2020.
[66] Xiatian Zhu, Shaogang Gong, et al. Knowledge distillation by on-the-fly native ensemble. In
     Advances in neural information processing systems, pages 7517–7527, 2018.

A    Experimental setup details
A.1 Models
Here we provide details on the models used in our experiments. All the following models are implemented in PyTorch [41] with slight modifications.
CifarCNN This model architecture is from the CIFAR experiments of [36] and is composed of two convolutional layers followed by two fully connected layers, and a linear transformation layer for the last logit outputs.
VggNet-11 VggNet was proposed by [47]. The architecture consists of stacks of convolutional layers and max-pooling, followed by three fully connected layers, and a final linear transformation layer to produce logits. In our experiments, we use VggNet-11, which is composed of 11 cascaded convolutional layers with max-pooling layers. The original architecture implemented in PyTorch [41] has batch normalization; however, we skip batch normalization following [20].

ResNet-8 ResNet-8 also uses a sequence of convolutional layers like VggNet. However, ResNet makes use of shortcut connections to resolve the vanishing gradient problem and to improve performance [13]. We implement ours by removing the batch normalization layers.
A.2 Datasets

In the pre-processing of all the image datasets, we follow the standard normalization process and data augmentation, such as random cropping and horizontal random flipping, as sketched below.
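A sketch of such a pipeline with torchvision; the padding size and normalization statistics are common CIFAR defaults (assumptions), not values reported in the paper.

```python
import torchvision.transforms as T

# Typical CIFAR training pipeline; padding and normalization statistics are
# common defaults (assumptions), not values reported in the paper.
train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
```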
FEMNIST Federated Extended MNIST was built by partitioning the EMNIST handwritten image dataset [8] based on the writer of the digits/characters. 3500 writers generated 62 different classes, comprising 10 digits, 26 lowercase, and 26 uppercase letters. To generate heterogeneous data partitions, we regard each writer as a local client and select 100 clients with more than 250 data samples. 10% of the data for each client is used as the test set.
CIFAR10 & CIFAR100 The CIFAR10 and CIFAR100 datasets [27] are among the most widely used image datasets and consist of 10 and 100 classes, respectively. Each dataset contains 5000 and 500 training samples per class, respectively. For label-heterogeneous partitions (max class / client), we follow the method of [36]. Since the number of local clients is fixed at 100, each local client has 500 data samples.
A.3 Training setups

We use an initial learning rate of 0.1 (for CIFARCNN & VGG-11) and 0.3 (for ResNet-8). The learning rate is decayed by a factor of 0.99 every 5 communication rounds. The momentum is set to 0.9 in the experiments. The momentum is only used in local training, which implies that the momentum information is not communicated between the server and local clients. The weight decay is set to 0. The hyperparameters of the learning setups for the different datasets are in Table 5.

                                   Table 5: Training setup details
                      Datasets            FEMNIST         CIFAR-10       CIFAR-100
                    local epochs                5             5               5
                  local batch size             10            50              50
                data samples / client    300 (average)       500            500
                max class / client (δ)    50 (average)        2              20
                total dataset classes          62            10             100
                  µ (for FedProx)              0.5           0.1            0.05

A.4 Hyperparameters

We use the same hyperparameters for FedLSD and FedLS-NTD across the different datasets: FedLSD (β = 0.3, τ = 3) and FedLS-NTD (β = 0.3, τ = 1). Since the difference between logit scales is much smaller when the true-class logit is excluded, FedLS-NTD requires a smaller temperature for softening the global model's prediction.

A.5 GPU resources

We use one Titan RTX and one RTX 2080 Ti GPU card. Multi-GPU training is not conducted in the paper's experiments.

B    Proof of Proposition 1
In this section, we provide the proof of the following Proposition 1.
Proposition 1. Consider the FedAvg [36] setting with uniformly weighted $K = |K^{(t)}|$ local clients holding class distributions $\mathcal{P}_k^{(t)}$. We train the global and local clients by $\mathcal{L}_{\mathrm{FedLSD}}$ ($\tau = 1$) with convex class landscape $w \mapsto E_{D_c}[\mathcal{L}_{\mathrm{FedLSD}}(\cdot, w)]$, where $D_c$ is the data distribution of the $c$-th class. Assume that $w \mapsto \frac{\partial z}{\partial w}(\cdot, w)$ is $\lambda_{J_c}$-Lipschitz / $B$-bounded and $w \mapsto q(\cdot, w)$ is $\lambda_{q_c}$-Lipschitz with respect to $\|\cdot\|_{L^2(D_c)}$. Then, if $0 < \eta < 2 \min_c (\sqrt{2}\lambda_{J_c} + B\lambda_{q_c})^{-1}$, the weight divergence between the averaged local $w^{t+1}$ and global $w^{t+1}_{(GD)}$ after one round is bounded by the following inequality:

$$\|w^{t+1} - w^{t+1}_{(GD)}\| \le \frac{\eta B}{K} \sum_{k \in K^{(t)}} \chi(\mathcal{P} \,\|\, \mathcal{P}_k^{(t)}) \sum_{i=0}^{E-1} \Big\| q(\cdot, \tilde{w}_k^{t,i}) - \big(\beta q^g + (1 - \beta) p_y\big) \Big\|_{L^2(D_k^{(t)})},$$

where $D_k^{(t)}$ is the data on the $k$-th local and $q^g = q(x, w^t)$ is the global prediction at the last round.

Proof. First, we arrange our notations and assumptions:

• $w \in \mathcal{W}\ (\text{bounded}) \subset \mathbb{R}^m$, $D_k^{(t)} = \sum_{c=1}^{C} \mathcal{P}_k^{(t)}(c)\, D_c$, and $\mathcal{D} = \sum_{c=1}^{C} \mathcal{P}(c)\, D_c$, where $\mathcal{D}$ is the global data distribution and $\mathcal{P}(c)$ is the probability of the $c$-th class on $\mathcal{D}$.

• $w \mapsto \mathbb{E}_{D_c}[\mathcal{L}_{\text{FedLSD}}(\cdot, w)] = \int_{D_c} \mathcal{L}_{\text{FedLSD}}(x, w)\, d\mathcal{P}(x)$ is convex for each $c$. Here, we note that convexity on $\mathcal{D}$ (which is weaker) is in fact sufficient and more reasonable.

• For each class $c$, the mapping from the weight $w$ to the Jacobian $\frac{\partial z}{\partial w}(\cdot, w)$ satisfies
$$\sqrt{\int_{D_c} \Big\| \frac{\partial z}{\partial w}(x, w_2) - \frac{\partial z}{\partial w}(x, w_1) \Big\|_F^2\, d\mathcal{P}(x)} \le \lambda_c^J\, \|w_2 - w_1\|_2 \quad \forall w_1, w_2 \in \mathcal{W},$$
$$\sqrt{\int_{D_c} \Big\| \frac{\partial z}{\partial w}(x, w) \Big\|_F^2\, d\mathcal{P}(x)} \le B \quad \forall w \in \mathcal{W}.$$

• For each class $c$, the mapping from the weight $w$ to the prediction $q(\cdot, w)$ satisfies
$$\sqrt{\int_{D_c} \big\| q(x, w_2) - q(x, w_1) \big\|_2^2\, d\mathcal{P}(x)} \le \lambda_c^q\, \|w_2 - w_1\|_2 \quad \forall w_1, w_2 \in \mathcal{W}.$$

The differential of the softmax probability is
$$\frac{\partial q(c')}{\partial z_c} = \frac{\partial}{\partial z_c}\left( \frac{e^{z_{c'}}}{e^{z_1} + \cdots + e^{z_C}} \right) = \frac{e^{z_{c'}}}{e^{z_1} + \cdots + e^{z_C}} \left( \mathbb{1}(c = c') - \frac{e^{z_c}}{e^{z_1} + \cdots + e^{z_C}} \right) = q(c')\big(\mathbb{1}(c = c') - q(c)\big),$$
i.e.,
$$\frac{\partial q(c')}{\partial z} = q(c')\,(\mathbb{1}_{c'} - q).$$
Next, we provide the differentials of $\mathcal{L}_{\text{CE}}$, $\mathcal{L}_{\text{LSD}}$, and $\mathcal{L}_{\text{FedLSD}}$ with respect to the logit $z$:
$$\frac{\partial \mathcal{L}_{\text{CE}}}{\partial z} = \frac{\partial(-\log q(\text{true}))}{\partial z} = -(p_y - q),$$
$$\frac{\partial \mathcal{L}_{\text{LSD}}}{\partial z} = \frac{\partial}{\partial z} \sum_{i=1}^{C} -q^g(i) \log \frac{q(i)}{q^g(i)} = \sum_{i=1}^{C} q^g(i)\, \frac{\partial(-\log q(i))}{\partial z} = -(q^g - q),$$
$$\frac{\partial \mathcal{L}_{\text{FedLSD}}}{\partial z} = \beta\, \frac{\partial \mathcal{L}_{\text{LSD}}}{\partial z} + (1 - \beta)\, \frac{\partial \mathcal{L}_{\text{CE}}}{\partial z} = q - \big(\beta q^g + (1 - \beta) p_y\big).$$
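These identities are easy to verify numerically. The following is an illustrative autograd check (not part of the proof) of $\partial \mathcal{L}_{\text{FedLSD}} / \partial z = q - (\beta q^g + (1-\beta) p_y)$ at $\tau = 1$ on random logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, beta = 5, 0.3
z = torch.randn(C, requires_grad=True)    # local-model logits
z_g = torch.randn(C)                      # global-model logits (no grad)
y = torch.tensor(2)                       # true class index
q_g = F.softmax(z_g, dim=0)
p_y = F.one_hot(y, C).float()

log_q = F.log_softmax(z, dim=0)
loss = beta * torch.sum(q_g * (q_g.log() - log_q)) \
     + (1 - beta) * F.cross_entropy(z.unsqueeze(0), y.unsqueeze(0))
loss.backward()

expected = F.softmax(z, dim=0) - (beta * q_g + (1 - beta) * p_y)
assert torch.allclose(z.grad, expected, atol=1e-6)  # gradient identity holds
```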
Now, we prove the following lemma.

Lemma 1. Given $w_1, w_2 \in \mathcal{W}$ and $\eta \in \big(0,\, 2 \min_c (\sqrt{2}\,\lambda_c^J + B\lambda_c^q)^{-1}\big)$, the following inequality holds for all $\beta \in [0, 1]$ and $\tau = 1$. We simply denote $\mathcal{L}_{\text{FedLSD}}(\tau = 1, \beta)(w)$ as $\mathcal{L}_{\text{FedLSD}}(w)$ if there is no conflict:
$$\Big\| \big( w_2 - \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(w_2)] \big) - \big( w_1 - \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(w_1)] \big) \Big\|_2 \le \|w_2 - w_1\|_2.$$
                                                  
∵) First, we prove that $w \mapsto \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(w)]$ is Lipschitz. The gradient of the FedLSD loss has the following form:
$$\nabla_w \mathcal{L}_{\text{FedLSD}} = \frac{\partial \mathcal{L}_{\text{FedLSD}}}{\partial z}\, \frac{\partial z}{\partial w} = \big( q - (\beta q^g + (1 - \beta) p_y) \big)\, \frac{\partial z}{\partial w}.$$
Therefore, writing $u := \beta q^g + (1 - \beta) p_y$,
\begin{align*}
&\big\| \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)] \big\|_2
 = \Big\| \mathbb{E}_{\mathcal{D}}\big[ \nabla_w \mathcal{L}_{\text{FedLSD}}(w_2) - \nabla_w \mathcal{L}_{\text{FedLSD}}(w_1) \big] \Big\|_2 \\
&= \Big\| \mathbb{E}_{\mathcal{D}}\Big[ \big( q(w_2) - u \big) \tfrac{\partial z}{\partial w}(w_2) - \big( q(w_1) - u \big) \tfrac{\partial z}{\partial w}(w_1) \Big] \Big\|_2 \\
&= \Big\| \mathbb{E}_{\mathcal{D}}\Big[ \big( q(w_2) - u \big) \Big( \tfrac{\partial z}{\partial w}(w_2) - \tfrac{\partial z}{\partial w}(w_1) \Big) - \big( q(w_1) - q(w_2) \big) \tfrac{\partial z}{\partial w}(w_1) \Big] \Big\|_2 \\
&\le \mathbb{E}_{\mathcal{D}}\Big[ \big\| q(w_2) - u \big\|_2 \Big\| \tfrac{\partial z}{\partial w}(w_2) - \tfrac{\partial z}{\partial w}(w_1) \Big\|_F \Big]
 + \mathbb{E}_{\mathcal{D}}\Big[ \big\| q(w_1) - q(w_2) \big\|_2 \Big\| \tfrac{\partial z}{\partial w}(w_1) \Big\|_F \Big] \\
&= \sum_{c=1}^{C} \mathcal{P}(c)\, \mathbb{E}_{D_c}\Big[ \big\| q(w_2) - u \big\|_2 \Big\| \tfrac{\partial z}{\partial w}(w_2) - \tfrac{\partial z}{\partial w}(w_1) \Big\|_F \Big]
 + \sum_{c=1}^{C} \mathcal{P}(c)\, \mathbb{E}_{D_c}\Big[ \big\| q(w_1) - q(w_2) \big\|_2 \Big\| \tfrac{\partial z}{\partial w}(w_1) \Big\|_F \Big] \\
&\le \sum_{c=1}^{C} \mathcal{P}(c)\, \big\| q(\cdot, w_2) - u(\cdot) \big\|_{L^2(D_c)} \Big\| \tfrac{\partial z}{\partial w}(\cdot, w_2) - \tfrac{\partial z}{\partial w}(\cdot, w_1) \Big\|_{L^2(D_c)}
 + \sum_{c=1}^{C} \mathcal{P}(c)\, \big\| q(\cdot, w_1) - q(\cdot, w_2) \big\|_{L^2(D_c)} \Big\| \tfrac{\partial z}{\partial w}(\cdot, w_1) \Big\|_{L^2(D_c)} \quad \text{(Hölder's inequality)} \\
&\le \sum_{c=1}^{C} \mathcal{P}(c) \big( \sqrt{2}\, \lambda_c^J \|w_2 - w_1\|_2 + \lambda_c^q \|w_2 - w_1\|_2\, B \big)
 = \sum_{c=1}^{C} \mathcal{P}(c) \big( \sqrt{2}\, \lambda_c^J + B \lambda_c^q \big) \|w_2 - w_1\|_2 \\
&\le \Big( \max_{c \in [C]} \big( \sqrt{2}\, \lambda_c^J + B \lambda_c^q \big) \Big) \|w_2 - w_1\|_2 =: \Gamma\, \|w_2 - w_1\|_2,
\end{align*}
where the $\sqrt{2}$ factor uses $\| q(w_2) - u \|_2 \le \sqrt{2}$, which holds because $q(w_2)$ and $u$ are both probability vectors.

Since $w \mapsto \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w)]$ is convex, $\Gamma$-Lipschitzness of its gradient implies co-coercivity, i.e.,
$$\big\| \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)] \big\|_2^2 \le \Gamma \big\langle \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)],\; w_2 - w_1 \big\rangle.$$
Finally, if $0 < \eta < 2/\Gamma$,
\begin{align*}
&\Big\| \big( w_2 - \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] \big) - \big( w_1 - \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)] \big) \Big\|_2^2 \\
&= \|w_2 - w_1\|_2^2 - 2\eta \big\langle \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)],\; w_2 - w_1 \big\rangle \\
&\qquad + \eta^2 \big\| \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)] \big\|_2^2 \\
&= \|w_2 - w_1\|_2^2 - \eta^2 \big( (2/\eta) \langle \cdots \rangle - \|\cdots\|^2 \big) \\
&\le \|w_2 - w_1\|_2^2 - \eta^2 \big( \Gamma \langle \cdots \rangle - \|\cdots\|^2 \big) \le \|w_2 - w_1\|_2^2, \qquad ///
\end{align*}
where the first inequality uses $2/\eta > \Gamma$ together with $\langle \cdots \rangle \ge 0$ (monotonicity of the gradient of a convex function), and the last inequality is the co-coercivity bound above.
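As an illustrative aside (not part of the proof), the non-expansiveness conclusion of Lemma 1 can be observed on a toy convex quadratic $f(w) = \frac{1}{2} w^\top A w$, a stand-in for $\mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}]$ with $\Gamma = \lambda_{\max}(A)$:

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((10, 10))
A = M @ M.T                               # PSD Hessian of f(w) = 0.5 w'Aw
gamma = np.linalg.eigvalsh(A).max()       # smoothness constant Gamma
eta = 1.9 / gamma                         # any step size below 2 / Gamma

step = lambda w: w - eta * (A @ w)        # one gradient step on f
for _ in range(100):
    w1, w2 = rng.standard_normal(10), rng.standard_normal(10)
    # the gradient step never increases the distance between iterates
    assert (np.linalg.norm(step(w2) - step(w1))
            <= np.linalg.norm(w2 - w1) + 1e-12)
```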

Now, we prove the main proposition. We initialize the weights of the local clients as $w_k^{t,0} = w^t$, perform $E$ epochs on the global and local clients, and compare the averaged local weight $w^{t+1}$ with the weight $w^{t+1}_{(\mathrm{GD})}$ of the global model. The idea of the proof is developed from [cite]. First, note that
$$\big\| w^{t+1} - w^{t+1}_{(\mathrm{GD})} \big\|_2 = \Big\| \frac{1}{K} \sum_k \big( w_k^{t,E} - w_{(\mathrm{GD})}^{t,E} \big) \Big\|_2 \le \frac{1}{K} \sum_k \big\| w_k^{t,E} - w_{(\mathrm{GD})}^{t,E} \big\|_2.$$

Next, for each $i \in \{1, \ldots, E\}$,
\begin{align*}
\big\| w_k^{t,i} - w_{(\mathrm{GD})}^{t,i} \big\|_2
&= \Big\| \Big[ w_k^{t,i-1} - \eta \nabla_w \mathbb{E}_{D_k^{(t)}}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] \Big] - \Big[ w_{(\mathrm{GD})}^{t,i-1} - \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(w_{(\mathrm{GD})}^{t,i-1})] \Big] \Big\|_2 \\
&\le \Big\| \Big[ w_k^{t,i-1} - \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] \Big] - \Big[ w_{(\mathrm{GD})}^{t,i-1} - \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(w_{(\mathrm{GD})}^{t,i-1})] \Big] \Big\|_2 \\
&\qquad + \Big\| \eta \nabla_w \mathbb{E}_{\mathcal{D}}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] - \eta \nabla_w \mathbb{E}_{D_k^{(t)}}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] \Big\|_2 \\
&\le \big\| w_k^{t,i-1} - w_{(\mathrm{GD})}^{t,i-1} \big\|_2 + \eta \sum_{c=1}^{C} \big| \mathcal{P}_k^{(t)}(c) - \mathcal{P}(c) \big| \Big\| \mathbb{E}_{D_c}\Big[ \big( q(w_k^{t,i-1}) - (\beta q^g + (1-\beta) p_y) \big) \tfrac{\partial z}{\partial w}(w_k^{t,i-1}) \Big] \Big\|_2. \quad \text{(Lemma 1)}
\end{align*}

The second term of the last inequality can be bounded as follows:
\begin{align*}
&\eta \sum_{c=1}^{C} \big| \mathcal{P}_k^{(t)}(c) - \mathcal{P}(c) \big| \Big\| \mathbb{E}_{D_c}\Big[ \big( q(w_k^{t,i-1}) - (\beta q^g + (1-\beta) p_y) \big) \tfrac{\partial z}{\partial w}(w_k^{t,i-1}) \Big] \Big\|_2 \\
&\le \eta \sum_{c=1}^{C} \frac{\big| \mathcal{P}_k^{(t)}(c) - \mathcal{P}(c) \big|}{\sqrt{\mathcal{P}_k^{(t)}(c)}} \sqrt{\mathcal{P}_k^{(t)}(c)}\, \big\| \mathbb{E}_{D_c}[\cdots] \big\|_2 \\
&\le \eta \sqrt{ \sum_{c=1}^{C} \frac{\big| \mathcal{P}_k^{(t)}(c) - \mathcal{P}(c) \big|^2}{\mathcal{P}_k^{(t)}(c)} } \sqrt{ \sum_{c=1}^{C} \mathcal{P}_k^{(t)}(c)\, \big\| \mathbb{E}_{D_c}[\cdots] \big\|_2^2 }
 = \eta\, \chi\big(\mathcal{P} \,\big\|\, \mathcal{P}_k^{(t)}\big) \sqrt{ \sum_{c=1}^{C} \mathcal{P}_k^{(t)}(c)\, \big\| \mathbb{E}_{D_c}[\cdots] \big\|_2^2 } \quad \text{(Cauchy–Schwarz)} \\
&\le \eta\, \chi\big(\mathcal{P} \,\big\|\, \mathcal{P}_k^{(t)}\big) \sqrt{ \mathbb{E}_{D_k^{(t)}}\big[ \| \cdots \|_2^2 \big] } \qquad \text{(Jensen's inequality)} \\
&\le \eta B\, \chi\big(\mathcal{P} \,\big\|\, \mathcal{P}_k^{(t)}\big)\, \big\| q(\cdot, w_k^{t,i-1}) - \big( \beta q^g + (1-\beta) p_y \big) \big\|_{L^2(D_k^{(t)})}. \qquad \text{(Hölder's inequality)}
\end{align*}

Thus,
$$\big\| w_k^{t,i} - w_{(\mathrm{GD})}^{t,i} \big\|_2 \le \big\| w_k^{t,i-1} - w_{(\mathrm{GD})}^{t,i-1} \big\|_2 + \eta B\, \chi\big(\mathcal{P} \,\big\|\, \mathcal{P}_k^{(t)}\big)\, \big\| q(\cdot, w_k^{t,i-1}) - \big( \beta q^g + (1-\beta) p_y \big) \big\|_{L^2(D_k^{(t)})}.$$

Consequently, we get
\begin{align*}
\big\| w^{t+1} - w^{t+1}_{(\mathrm{GD})} \big\|_2 &\le \frac{1}{K} \sum_k \big\| w_k^{t,E} - w_{(\mathrm{GD})}^{t,E} \big\|_2 \\
&= \frac{1}{K} \sum_k \sum_{i=1}^{E} \Big( \big\| w_k^{t,i} - w_{(\mathrm{GD})}^{t,i} \big\|_2 - \big\| w_k^{t,i-1} - w_{(\mathrm{GD})}^{t,i-1} \big\|_2 \Big) \qquad \big( w_k^{t,0} = w_{(\mathrm{GD})}^{t,0} = w^t \big) \\
&\le \frac{\eta B}{K} \sum_{k \in \mathcal{R}^{(t)}} \chi\big(\mathcal{P} \,\big\|\, \mathcal{P}_k^{(t)}\big) \sum_{i=0}^{E-1} \big\| q(\cdot, w_k^{t,i}) - \big( \beta q^g + (1-\beta) p_y \big) \big\|_{L^2(D_k^{(t)})},
\end{align*}
which completes the proof of Proposition 1. $\blacksquare$
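The bound is driven by the divergence term $\chi(\mathcal{P} \,\|\, \mathcal{P}_k^{(t)}) = \big( \sum_c (\mathcal{P}(c) - \mathcal{P}_k^{(t)}(c))^2 / \mathcal{P}_k^{(t)}(c) \big)^{1/2}$, which vanishes for a perfectly balanced client and grows with label skew. A small numerical illustration, with hypothetical class distributions:

```python
import numpy as np

def chi_divergence(p, p_k):
    # chi-type divergence appearing in the bound; requires p_k(c) > 0
    return np.sqrt(np.sum((p - p_k) ** 2 / p_k))

p = np.full(10, 0.1)                                # global: uniform, 10 classes
balanced = np.full(10, 0.1)                         # homogeneous client
skewed = np.array([0.455, 0.455] + [0.01125] * 8)   # ~2 dominant classes

print(chi_divergence(p, balanced))  # 0.0: no heterogeneity penalty
print(chi_divergence(p, skewed))    # > 0: the bound inflates with skew
```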

C     Local features visualization
We further conduct an additional experiment on the features of the trained local models. The experimental setup is the same as in Section 2, except that this time we train the global model for 100 communication rounds on heterogeneous local clients and then distribute it over 10 homogeneous and 10 heterogeneous local clients. In the homogeneous case (Figure 10a), the features are clustered by class, regardless of which local client they were learned on. In the heterogeneous case (Figure 10b), by contrast, the features are clustered by the local distribution they were learned on. The same visualization, colored from a different view, is shown in Figure 11.

Figure 10: t-SNE visualization of features on CIFAR-10 test samples after local training on (a) homogeneous local distributions (colored by classes) and (b) heterogeneous local distributions (colored by locals). The t-SNE is computed jointly over the test-sample features of the global model and the 10 local models.

Figure 11: t-SNE visualization of feature-region shifting on CIFAR-10 test samples after local training on (a) homogeneous local distributions (colored by classes) and (b) heterogeneous local distributions (colored by locals).
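A hedged sketch of the joint embedding protocol used for these figures: features from the global model and the 10 local models are embedded together in a single t-SNE run, so their relative positions are comparable. Here `models` and `featurize` (which should return penultimate-layer features) are hypothetical placeholders, not interfaces from the paper's code.

```python
import numpy as np
from sklearn.manifold import TSNE

def joint_tsne(models, test_images, featurize, seed=0):
    feats, owner = [], []
    for i, m in enumerate(models):        # models[0] = global, rest = locals
        f = featurize(m, test_images)     # (N, d) penultimate-layer features
        feats.append(f)
        owner.extend([i] * len(f))
    # one shared 2-D embedding over all models' features
    emb = TSNE(n_components=2, random_state=seed).fit_transform(
        np.concatenate(feats, axis=0))
    return emb, np.array(owner)           # color emb by class or by owner
```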
