diff --git a/books/fl/src/SUMMARY.md b/books/fl/src/SUMMARY.md index 0fdd43d..f1baa3f 100644 --- a/books/fl/src/SUMMARY.md +++ b/books/fl/src/SUMMARY.md @@ -20,10 +20,10 @@ - [Vanilla FL](horizontal/vanilla_fl/README.md) - [FedSGD](horizontal/vanilla_fl/fedsgd.md) - [FedAvg](horizontal/vanilla_fl/fedavg.md) - - [Robust Global FL]() <-- (horizontal/robust_global_fl/README.md) --> - - [FedAdam]() <-- (horizontal/robust_global_fl/fedadam.md) --> - - [FedProx]() <-- (horizontal/robust_global_fl/fedprox.md) --> - - [MOON]() <-- (horizontal/robust_global_fl/moon.md) --> + - [Robust Global FL](horizontal/robust_global_fl/README.md) + - [FedOpt](horizontal/robust_global_fl/fedopt.md) + - [FedProx](horizontal/robust_global_fl/fedprox.md) + - [MOON](horizontal/robust_global_fl/moon.md) - [Personalized FL]() <-- (horizontal/personalized/README.md) --> - [FedPer]() <-- (horizontal/personalized/fedper.md) --> - [FENDA-FL]() <-- (horizontal/personalized/fenda.md) --> diff --git a/books/fl/src/assets/FedProxAdaptation_bottom.png b/books/fl/src/assets/FedProxAdaptation_bottom.png new file mode 100644 index 0000000..b4e1ec6 Binary files /dev/null and b/books/fl/src/assets/FedProxAdaptation_bottom.png differ diff --git a/books/fl/src/assets/FedProxAdaptation_top.png b/books/fl/src/assets/FedProxAdaptation_top.png new file mode 100644 index 0000000..59d6e6e Binary files /dev/null and b/books/fl/src/assets/FedProxAdaptation_top.png differ diff --git a/books/fl/src/assets/SplitModels.svg b/books/fl/src/assets/SplitModels.svg new file mode 100644 index 0000000..d0d4588 --- /dev/null +++ b/books/fl/src/assets/SplitModels.svg @@ -0,0 +1,4 @@ + + + +
Feature Map
Input Features
Classifier
Label
Global Model
Feature Map
Input Features
Classifier
Label
Previous Local Model
Feature Map
Input Features
Classifier
Label
Current Local Model
diff --git a/books/fl/src/assets/algorithm-fedopt.svg b/books/fl/src/assets/algorithm-fedopt.svg new file mode 100644 index 0000000..938fc65 --- /dev/null +++ b/books/fl/src/assets/algorithm-fedopt.svg @@ -0,0 +1,1010 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/books/fl/src/assets/algorithm-fedprox.svg b/books/fl/src/assets/algorithm-fedprox.svg new file mode 100644 index 0000000..e1a8956 --- /dev/null +++ b/books/fl/src/assets/algorithm-fedprox.svg @@ -0,0 +1,661 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/books/fl/src/assets/algorithm-fedsgd.svg b/books/fl/src/assets/algorithm-fedsgd.svg new file mode 100644 index 0000000..6dbc801 --- /dev/null +++ b/books/fl/src/assets/algorithm-fedsgd.svg @@ -0,0 +1,447 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/books/fl/src/assets/algorithm-moon.svg b/books/fl/src/assets/algorithm-moon.svg new file mode 100644 index 0000000..5a8be95 --- /dev/null +++ b/books/fl/src/assets/algorithm-moon.svg @@ -0,0 +1,647 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/books/fl/src/assets/combined_loss_objective.svg b/books/fl/src/assets/combined_loss_objective.svg new file mode 100644 index 0000000..c427a50 --- /dev/null +++ b/books/fl/src/assets/combined_loss_objective.svg @@ -0,0 +1,223 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/books/fl/src/assets/fed_df_model.png b/books/fl/src/assets/fed_df_model.png new file mode 100644 index 0000000..5a397fc Binary files /dev/null and b/books/fl/src/assets/fed_df_model.png differ diff --git a/books/fl/src/assets/fedavg_drift.svg b/books/fl/src/assets/fedavg_drift.svg new file mode 100644 index 0000000..a275ebc --- /dev/null +++ b/books/fl/src/assets/fedavg_drift.svg @@ -0,0 +1,881 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/books/fl/src/assets/fedavg_model.png b/books/fl/src/assets/fedavg_model.png new file mode 100644 index 0000000..561c3c3 Binary files /dev/null and b/books/fl/src/assets/fedavg_model.png differ diff --git a/books/fl/src/assets/fedsgd_steps.svg b/books/fl/src/assets/fedsgd_steps.svg new file mode 100644 index 0000000..b36a495 --- /dev/null +++ b/books/fl/src/assets/fedsgd_steps.svg @@ -0,0 +1,994 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + diff --git a/books/fl/src/assets/heterogeneity_two_routes_alt.svg b/books/fl/src/assets/heterogeneity_two_routes_alt.svg new file mode 100644 index 0000000..ac5ccb4 --- /dev/null +++ b/books/fl/src/assets/heterogeneity_two_routes_alt.svg @@ -0,0 +1,4 @@ + + + +
Global FL with Robust Optimization
Personalized FL
Two Common
Routes
diff --git a/books/fl/src/assets/local_model_1.png b/books/fl/src/assets/local_model_1.png new file mode 100644 index 0000000..92b8bdc Binary files /dev/null and b/books/fl/src/assets/local_model_1.png differ diff --git a/books/fl/src/assets/local_model_2.png b/books/fl/src/assets/local_model_2.png new file mode 100644 index 0000000..5ae2f17 Binary files /dev/null and b/books/fl/src/assets/local_model_2.png differ diff --git a/books/fl/src/horizontal/README.md b/books/fl/src/horizontal/README.md index ada2d5f..e18172c 100644 --- a/books/fl/src/horizontal/README.md +++ b/books/fl/src/horizontal/README.md @@ -54,7 +54,7 @@ This section of the book is organized as follows: - [FedSGD](vanilla_fl/fedsgd.md) - [FedAvg](vanilla_fl/fedavg.md) - [Robust Global FL](robust_global_fl/index.md) - - [FedAdam](robust_global_fl/fedadam.md) + - [FedOpt](robust_global_fl/fedopt.md) - [FedProx](robust_global_fl/fedprox.md) - [MOON](robust_global_fl/moon.md) - [Personalized FL](personalized/index.md) diff --git a/books/fl/src/horizontal/robust_global_fl/README.md b/books/fl/src/horizontal/robust_global_fl/README.md index 64427d4..5e05feb 100644 --- a/books/fl/src/horizontal/robust_global_fl/README.md +++ b/books/fl/src/horizontal/robust_global_fl/README.md @@ -1 +1,176 @@ + + # Robust Global FL Approaches + +{{ #aipr_header }} + +## Data heterogeneity in standard ML + +In standard ML, when training and deploying a model, a standard underlying +assumption is that the training data is distributionally +similar to new data to which the model will be applied. There are methods +that specialize in out-of-domain generalization, but in most cases +models are assumed to be applied on data that is drawn from the same +statistical distributions that describe the data on which it was trained. The +validity of this assumption can degrade, for example, over time or due to +the model being used to make predictions in entirely new domains. 
+ +While data shifts present a significant challenge in centralized ML training, +the characteristics that describe data shifts in this domain also exist in FL +when comparing disparate, distributed datasets. Data shift between such +datasets is typically referred to as "data heterogeneity" between clients. Such +heterogeneity introduces new obstacles in FL and is quite prevalent. Before +discussing its impact on federated training and how it is addressed, let's +define some types of data divergence. Three common ways to describe +disparities or shifts between training and inference data are:[^1] + +1. [Label Shift](#label-shift) +2. [Covariate Shift](#covariate-shift) +3. [Concept Drift](#concept-drift) + +Let \\(X\\) and \\(Y\\) represent the feature (input) and label (output) +spaces, respectively, for a model. Shifts are present, regardless of whether +model performance degrades, when the joint distributions differ, + +$$ +\begin{align} +\\mathbb{P}\_{\\text{train}}(X, Y) \\neq \\mathbb{P}\_{\\text{test}}(X, Y). \tag{1} +\end{align} +$$ + +### Label Shift + +Label shifts occur when there is a change in the label distribution \\(\\mathbb{P}(Y)\\) +with a fixed class-conditional distribution \\(\\mathbb{P}(X \\vert Y)\\). That is, the +probability of seeing different label values shifts, but the distribution of +features conditioned on the labels does not change. A pertinent example of +this might be data meant to train a model to diagnose COVID-19 in the early +days of spread versus the later stages when the virus was widely circulating. +Generally, the symptoms, given that someone had the virus, did not markedly +change. However, the prevalence of the virus, \\(\\mathbb{P}(Y)\\), did. + +### Covariate Shift + +Covariate shifts between data distributions represent a change in the feature +distribution, \\(\\mathbb{P}(X)\\), while the statistical relationship of labels to +features, \\(\\mathbb{P}(Y \\vert X)\\), remains fixed.
Consider the setting of training +a readmission risk model on data drawn from the patient population of a +general hospital. If, for instance, that model were transferred for use at a +nearby pediatric hospital, assuming all else equal, predictions from that model +would be influenced by covariate shift due to the change in patient +demographics. Namely, though features associated with younger patients are +likely part of the general hospital population, they will, of course, be +statistically over-represented in the data points seen by the model at the +pediatric hospital. + +### Concept Drift + +Concept drift is characterized by a change in \\(\\mathbb{P}(Y \vert X)\\) provided a +fixed \\(\\mathbb{P}(X)\\). Essentially, this drift encapsulates a shift in the +predictive relationship between the features, \\(X\\), and the labels, \\(Y\\). +As an illustrative example, consider training a purchase conversion model +for airline ticket purchases where two possible incentives are features. The +first offers a ticket discount to encourage purchase, whereas the second offers +free add-ons. In good economic periods, the second incentive may produce higher +conversion rates. On the other hand, in periods of economic uncertainty, +perhaps the first offer would do so. + +Note that each of the shifts discussed above may exist in isolation or be +present together to varying degrees. + +## How does data heterogeneity manifest in FL? + +In FL, differences in training data distributions are not strictly temporal or +marked by a change in the joint probability distributions of the training and +test datasets, as expressed in Equation (1). Each client participating in +federated training might naturally exhibit distribution disparities compared +to one another. Consider the example given in the section on +[Covariate Shift](#covariate-shift).
If the general and pediatric hospitals +would like to collaboratively train a model using FL, the demographics of +their patient populations mean that there will be substantial statistical +heterogeneity between their respective training datasets. + +Each distributed training dataset in an FL system may naturally exhibit, +relative to the others, any of the disparities discussed above. As a further +example, consider two financial institutions working together to train a fraud +detection model. Because of their different clientele, one bank may experience +fraud at a rate of 2% per transaction, while the other may see only 0.1%, an +example of label shift, potentially among other disparities. + +## How does it impact FL models and their training? + +Data heterogeneity, in its various forms, has been linked to a number of +challenges in training FL models using methods like +[FedAvg](../vanilla_fl/fedavg.md), including slower convergence, performance +degradation, and unevenly distributed training dynamics among clients. A clear +illustration of the impact of data heterogeneity is provided in the FedDF +paper.[^2] In the +figures below, two clients have locally trained a model on their respective +datasets. + +
+
+Local Model 1 +Local Model 2 +
Two clients with different datasets. Note that each holds a +slightly different view of the feature space. Notably, Client 1 (left) has a +distinct cluster of data points in the bottom right and fewer points labeled in +green within the red cluster.
+
+
+ +The decision boundaries of the locally trained models are largely similar but +differ in important ways. If the two models are averaged via FedAvg (see figure +below), the result is a blurred decision boundary which has diverged from the +sharp boundary one would expect to compute were the data agglomerated and a +central model trained. Alternatively, using an approach that is more robust +to data heterogeneity, FedDF,[^2] the resulting model exhibits the kinds of +classification boundaries one would expect when considering the data +distributions from a global perspective. + +
+
+FedAvg Model +FedDF Model +
Model resulting from FedAvg (left) compared with the model +trained using FedDF (right).
+
+
+ +There are two common routes, among many others, for addressing +heterogeneity in FL. The first is to maintain a single global model +to be trained by all participants. Modifications to components like the aggregation +strategy, local learning objectives, or corrections to model updates are +applied to better align FL training with the dynamics of centralized training +without sacrificing most of the benefits associated with the original FedAvg +algorithm. The second route is to abandon, to one degree or another, the idea +of a global model that performs well across all clients and instead allow +each client to train a unique model. This is known as Personalized +FL (pFL). Such models still benefit from global information through aspects of +FL, but more strongly emphasize local distributions. + +
+
+Two FL Routes +
Two possible routes for addressing data heterogeneity in FL.
+
+
+ +In the subsequent sections of this chapter, we'll cover a few of the many +methods aimed at robust global model optimization in FL. Such models are often +more generalizable and are more easily distributed to new domains than their +pFL equivalents. However, model performance on each client may not be +as high as that produced by pFL approaches. + +#### References & Useful Links + +[^1]: + J. Quinonero-Candela, M. Sugiyama, A. Schwaighofer, and N. D. Lawrence. Dataset shift + in machine learning. MIT Press, 2008. + +[^2]: + [Lin, Tao et al. “Ensemble distillation for robust model fusion in federated + learning”. In: Proceedings of the 34th International Conference on Neural + Information Processing Systems. NIPS ’20. Vancouver, BC, Canada: Curran + Associates Inc., 2020.](https://proceedings.neurips.cc/paper/2020/file/18df51b97ccd68128e994804f3eccc87-Paper.pdf) + +{{#author emersodb}} diff --git a/books/fl/src/horizontal/robust_global_fl/fedadam.md b/books/fl/src/horizontal/robust_global_fl/fedadam.md deleted file mode 100644 index 166cea9..0000000 --- a/books/fl/src/horizontal/robust_global_fl/fedadam.md +++ /dev/null @@ -1 +0,0 @@ -# FedAdam diff --git a/books/fl/src/horizontal/robust_global_fl/fedopt.md b/books/fl/src/horizontal/robust_global_fl/fedopt.md new file mode 100644 index 0000000..109f77e --- /dev/null +++ b/books/fl/src/horizontal/robust_global_fl/fedopt.md @@ -0,0 +1,130 @@ + + +# The FedOpt Family of Aggregation Strategies + +{{ #aipr_header }} + +Recall that modern deep learning optimizers like AdamW[^1] or AdaGrad[^2] use +first- and second-moment estimates of the stochastic gradients computed +during iterative optimization to adaptively modify the model updates. +At a high level, each algorithm aims to reinforce common update directions +(i.e. those with momentum) and damp update elements corresponding to noisy +directions (i.e. those with high batch-to-batch variance).
The FedOpt +family[^3] of algorithms considers modifying the traditional +[FedAvg](../vanilla_fl/fedavg.md) aggregation algorithm to incorporate +similar adaptations into server-side model updates in FL. + +## Mathematical motivation + +In FedAvg, recall that, after a round of local training on each client, +client model weights are combined into a single model representation via + +$$ +\begin{align*} +\\mathbf{w}\_{t+1} = \\sum\_{k \\in C_t} \\frac{n_k}{n_s} \\mathbf{w}^k_{t+1}, +\end{align*} +$$ + +where \\(\\mathbf{w}^k\_{t+1}\\) is simply the model weights after local +training on client \\(k\\). For round \\(t\\), each client starts local +training from the same set of weights, \\(\\mathbf{w}\_t\\). Assume that each +client has the same number of data points such that \\(n_k = m\\). With a bit +of algebra, the update is rewritten + +$$ +\begin{align} +\\mathbf{w}\_{t+1} = \\sum\_{k \\in C_t} \\frac{n_k}{n_s} \\mathbf{w}^k\_{t+1} &= \\mathbf{w}_t - \\frac{1}{\\vert C_t \\vert} \\sum\_{k \\in C_t} +\\left( \\mathbf{w}_t - \\mathbf{w}^k\_{t+1} \\right), \\\\ +&= \\mathbf{w}_t + \\frac{1}{\\vert C\_t \\vert} \\sum\_{k \\in C_t} \\Delta^k\_{t+1}, \\\\ +&= \\mathbf{w}_t + \\Delta\_{t+1}. \tag{1} +\end{align} +$$ + +Here, \\(\\Delta^k\_{t+1} = \\mathbf{w}^k\_{t+1} - \\mathbf{w}\_t\\) is just +the vector pointing from the initial model weights to those after local +training and \\(\\Delta\_{t+1}\\) is simply the uniform average of these +update vectors. + +Recall that, if each client uses a fixed learning rate, \\(\eta\\), and +performs a single, full gradient update, FedAvg is equivalent to centralized +large-batch SGD. Similarly, in this case, if each client performs one step of +batch SGD with a learning rate of 1.0, then the update in Equation (1) is +equivalent to a batch-SGD update with a learning rate of 1.0 for the +**server**. The "server-side" batch is the union of the batches used on each +client.
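The algebra above is easy to check numerically. The following sketch (with three hypothetical, equally weighted clients) confirms that the data-weighted FedAvg aggregation matches the pseudo-gradient form of Equation (1):

```python
import numpy as np

rng = np.random.default_rng(0)

# Server weights at the start of round t and each client's locally trained
# weights. Three clients with equal data counts, so n_k / n_s = 1/3.
w_t = rng.normal(size=5)
client_weights = [w_t + rng.normal(scale=0.1, size=5) for _ in range(3)]

# Standard FedAvg aggregation: a data-weighted (here uniform) average.
fedavg_result = sum(w_k / len(client_weights) for w_k in client_weights)

# Pseudo-gradient form: w_{t+1} = w_t + mean_k(Delta_k), Delta_k = w_k - w_t.
deltas = [w_k - w_t for w_k in client_weights]
pseudo_grad_result = w_t + sum(deltas) / len(deltas)

assert np.allclose(fedavg_result, pseudo_grad_result)
```

The two forms coincide exactly whenever the client weights are uniform, which is the equal-data assumption made above.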
The observation that \\(-\Delta\_{t+1}\\) is simply a stochastic gradient +motivates treating these update directions like the stochastic gradients +in standard adaptive optimizers. It's important to note that if the clients, +for instance, apply multiple steps of local SGD or use different learning +rates, the exact equivalence of \\(-\Delta\_{t+1}\\) to a stochastic gradient +is broken. However, it shares similarities with such a gradient and is, +therefore, called a "pseudo-gradient."[^3] + +## The algorithms: FedAdagrad, FedAdam, FedYogi + +Drawing inspiration from three successful, traditional adaptive optimizers, +the adaptive server-side aggregation strategies of FedAdagrad, FedAdam, and +FedYogi have been proposed. See the algorithm below for details. + +
+
+FedOpt Algorithms +
+
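As a concrete sketch of the structure these strategies share, a simplified FedAdam-style server step might look like the following (the function name, default values, and state handling are illustrative, not the reference implementation):

```python
import numpy as np

def fedadam_server_step(w, delta, m, v, lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
    """One illustrative FedAdam-style server update.

    `delta` is the average client update for the round (Delta_{t+1});
    `m` and `v` are the running first- and second-moment estimates.
    """
    m = beta1 * m + (1 - beta1) * delta        # momentum on update directions
    v = beta2 * v + (1 - beta2) * delta**2     # variance (second-moment) estimate
    w = w + lr * m / (np.sqrt(v) + tau)        # damp high-variance directions
    return w, m, v

# Toy usage: a consistent update direction accumulates momentum over rounds.
w, m, v = np.zeros(3), np.zeros(3), np.zeros(3)
for _ in range(5):
    w, m, v = fedadam_server_step(w, delta=np.ones(3), m=m, v=v)
```

Swapping the `v` update for the AdaGrad- or Yogi-style accumulation yields the other two members of the family.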
+ +Those familiar with the mathematical formulations of Adagrad, Adam,[^4] and +Yogi[^5] will recognize the general structure of these equations. Computation +of \\(m_t\\), based on the average of the update directions suggested by each +client through local training (\\(\Delta\_{t+1}\\)), serves to accumulate +momentum associated with directions that are consistently and frequently part +of these updates. On the other hand, \\(\nu_t\\) estimates the variance +associated with update directions throughout the server rounds. Directions +with higher variance values are damped in favor of those with more consistency +round over round. + +As with the usual forms of these algorithms, there are a number of +hyper-parameters that can be tuned, including \\(\tau, \beta_1,\\) and +\\(\beta_2\\). However, sensible defaults are suggested in the paper such that +\\(\beta_1=0.9\\) and \\(\beta_2=0.99\\). The authors also show that +performance is generally robust to \\(\tau\\). + +A number of experiments show that the proposed FedOpt family of algorithms +can outperform FedAvg, especially in heterogeneous settings. Moreover, these +algorithms, in the experiments of the paper, outperform SCAFFOLD,[^6] a +variance reduction method aimed at improving convergence in the presence of +heterogeneity. A final advantage of the FedOpt family of algorithms is that +they are accompanied by several convergence results showing that, as long as +the variance of the local gradients is not too large, the algorithms converge +properly. + +#### References & Useful Links + +[^1]: + [I. Loshchilov and F. Hutter. Fixing weight decay regularization in ADAM, + 2018.](https://arxiv.org/pdf/1711.05101) + +[^2]: + [Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive subgradient methods + for online learning and stochastic optimization. Journal of machine learning research, 12(7).](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) + +[^3]: + [S. J. Reddi, Z. Charles, M. Zaheer, Z. Garrett, K.
Rush, J. Konečný, S. Kumar, and H. B. + McMahan. Adaptive federated optimization. In ICLR 2021, 2021.](https://arxiv.org/abs/2003.00295) + +[^4]: + [Kingma, D. P. & Ba, J. (2015). Adam: A Method for Stochastic Optimization. + In Y. Bengio & Y. LeCun (eds.), ICLR (Poster).](https://arxiv.org/pdf/1412.6980) + +[^5]: + [Manzil Zaheer, Sashank J. Reddi, Devendra Sachan, Satyen Kale, and + Sanjiv Kumar. 2018. Adaptive methods for nonconvex optimization. In Proceedings of the 32nd International Conference on Neural Information Processing Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 9815–9825.](https://proceedings.neurips.cc/paper_files/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf) + +[^6]: + [S. P. Karimireddy, S. Kale, M. Mohri, S. Reddi, S. Stich, and A. T. Suresh. + SCAFFOLD: Stochastic controlled averaging for federated learning. In + Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th + International Conference on Machine Learning, volume 119 of Proceedings of + Machine Learning Research, pages 5132–5143. PMLR, 13–18 Jul 2020.](https://proceedings.mlr.press/v119/karimireddy20a.html) + +{{#author emersodb}} diff --git a/books/fl/src/horizontal/robust_global_fl/fedprox.md b/books/fl/src/horizontal/robust_global_fl/fedprox.md index 6f4e40e..bc783d8 100644 --- a/books/fl/src/horizontal/robust_global_fl/fedprox.md +++ b/books/fl/src/horizontal/robust_global_fl/fedprox.md @@ -1 +1,161 @@ + + # FedProx + +{{ #aipr_header }} + +The FedProx algorithm[^1] is one of the earliest approaches specifically aimed +at addressing the optimization challenges associated with data heterogeneity in +FL. At its core, the FedProx algorithm is quite straightforward. However, +prior to diving into the modifications proposed in the FedProx approach, we'll +first consider the kind of phenomenon that FedProx, along with other methods, +attempts to counteract.
+ +To help illustrate the issue, we'll use some helpful visualizations +from researchers who proposed the SCAFFOLD method.[^2],[^3] +Consider a two-client FL setting. Each client has its own loss landscape +based on its privately held data, denoted \\(f_1\\) and \\(f_2\\). If each +client has an equal amount of data, the global loss surface, which is the loss +function when constructed from all data available on both clients, is equivalent +to \\((f_1 + f_2)/2\\). When performing standard federated training, the +objective is to find model weights corresponding to the minimum of this global +loss function. See the figure below. Note that the minima associated with the +client loss functions are distinct from the global minimum. + +
+
+Combined loss objective +
Comparison of local loss landscapes for two clients with the combined global loss.
+
+
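This picture can be reproduced with a toy one-dimensional stand-in for \\(f_1\\) and \\(f_2\\) (the quadratic losses below are hypothetical, chosen only for illustration):

```python
import numpy as np

# Hypothetical client losses with distinct minima at -2 and 3.
f1 = lambda w: (w + 2.0) ** 2
f2 = lambda w: (w - 3.0) ** 2
f_global = lambda w: 0.5 * (f1(w) + f2(w))  # equal data on both clients

# Locate the global minimizer on a fine grid.
grid = np.linspace(-5.0, 5.0, 100001)
w_star = grid[np.argmin(f_global(grid))]

# The global minimizer lies between the client minima and matches neither.
assert abs(w_star - 0.5) < 1e-3  # analytic minimum of the averaged quadratic
assert abs(w_star - (-2.0)) > 1.0 and abs(w_star - 3.0) > 1.0
```

For these quadratics the averaged loss is minimized exactly halfway between the client minima, a simple case of the general phenomenon in the figure.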
+ +Recall that optimization with [FedSGD](../vanilla_fl/fedsgd.md) is equivalent +to centralized large-batch SGD. That is, rounds of FedSGD are equivalent to +optimizing the global loss function, expressed here as \\((f_1 + f_2)/2\\). As +such, with a properly tuned learning rate, FedSGD will converge to the +global optimum, as illustrated in the figure below. Each averaged gradient +step makes steady progress towards the global minimum. + +
+
+FedSGD and global convergence +
FedSGD rounds result in averaged models making steady progress towards +the global minimum.
+
+
+ +As detailed in the chapter on [FedAvg](../vanilla_fl/fedavg.md),[^4] there is a +substantial reduction in communication overhead if each client applies +multiple steps of batch SGD, optimizing the local model based on the local +loss. It was noted therein, however, that this breaks the equivalence enjoyed, +for example, by FedSGD with centralized large-batch SGD. In settings such as +the one illustrated in the figures thus far, with data heterogeneity and +markedly different loss landscapes, this can engender various issues. One such +issue is often referred to as "client drift" and is illustrated in the figure +below. + +
+
+FedAvg and the influence of local drift +
Illustration of "client drift" in FedAvg updates caused by +differences in the shape of the local loss functions of each client.
+
+
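An exaggerated one-dimensional sketch of the drift in the figure above, using hypothetical quadratic client losses with different curvatures: when clients take many local steps per round, the averaged model settles near the mean of the client minima rather than at the global minimum.

```python
# Toy client losses: f1(w) = (w + 2)^2 and f2(w) = 10 (w - 3)^2, with
# minima at -2 and 3 and very different curvatures.
grad1 = lambda w: 2.0 * (w + 2.0)
grad2 = lambda w: 20.0 * (w - 3.0)

# Global minimum of (f1 + f2)/2 solves 2(w + 2) + 20(w - 3) = 0.
w_global = 56.0 / 22.0  # roughly 2.545

def fedavg_round(w, local_steps=200, lr=0.02):
    w1, w2 = w, w
    for _ in range(local_steps):
        w1 -= lr * grad1(w1)  # client 1 local SGD steps
        w2 -= lr * grad2(w2)  # client 2 local SGD steps
    return 0.5 * (w1 + w2)    # server-side averaging

w = 0.0
for _ in range(50):
    w = fedavg_round(w)

# With many local steps, each client (nearly) reaches its own minimum, so the
# average drifts to (-2 + 3)/2 = 0.5, far from the global minimum.
assert abs(w - 0.5) < 0.05
assert abs(w - w_global) > 1.5
```

With a single local step per round the same loop tracks the FedSGD trajectory instead; the drift appears only once local epochs are added.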
+ +In the figure, each client is applying three local steps of batch SGD before +sending the resulting weights to the server for aggregation. The grey dots +represent the updates using FedSGD for three rounds from the previous figure. +The update using FedAvg deviates quite a bit from this path with a distinct +drift towards the minimum of Client 2. Drifts of this kind can be induced by the +shape of the local loss surface and cause issues with FedAvg, such as slowed +convergence. + +## The Math + +The general idea of FedProx is to keep local models from drifting too far +from the global model during local training. For a server round \\(t\\), consider the aggregated weights +\\(\\mathbf{w}\_t\\). For a given client \\(k\\), let +\\(\\ell_k(b; \\mathbf{w})\\) denote the local loss function for a batch, +\\(b\\), of data, parameterized by model weights \\(\\mathbf{w}\\). The primary +modification of FedAvg in the FedProx algorithm is to augment +\\(\\ell_k(b; \\mathbf{w})\\) with a penalty term such that, for \\(\mu > 0\\), +the local loss becomes + +$$ +\begin{align*} +\\ell_k(b; \\mathbf{w}) + \\mu \\Vert \\mathbf{w} - \\mathbf{w_t} \\Vert^2. +\end{align*} +$$ + +The penalty term is referred to as the proximal loss. It penalizes significant +deviation from the global model during local training such that loss +optimization must trade off improvements in the standard loss with potential +divergence from the original model weights. Revisiting the loss surfaces above, +the FedProx penalty term alters the loss surface to make client drift less +attractive, unless it leads to significant performance gains. + +## The Algorithm + +The FedProx algorithm is very similar to that of FedAvg, with the only +modification coming in the local update calculations. + +
+
+FedProx Algorithms +
+
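In code, the FedProx modification amounts to one extra term in the client's loss computation. A plain NumPy sketch (the function name and toy check are illustrative):

```python
import numpy as np

def fedprox_local_loss(base_loss, w, w_global, mu):
    """Local FedProx objective: base loss plus mu * ||w - w_global||^2."""
    return base_loss + mu * np.sum((w - w_global) ** 2)

# Toy check: with mu > 0, straying further from the global weights costs more.
w_global = np.zeros(4)
near = w_global + 0.1
far = w_global + 1.0
loss_near = fedprox_local_loss(1.0, near, w_global, mu=0.1)
loss_far = fedprox_local_loss(1.0, far, w_global, mu=0.1)
assert loss_near < loss_far
assert fedprox_local_loss(1.0, far, w_global, mu=0.0) == 1.0  # mu = 0 recovers FedAvg
```

In a deep learning framework the same term would be computed over the model's parameter tensors and added to the batch loss before backpropagation.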
+ +## Adapting \\(\\mu\\) + +For a well-tuned \\(\\mu\\), FedProx has been shown to outperform FedAvg under +heterogeneous data conditions. It is widely applied, both because it is a +simple modification to the FedAvg framework and because it works fairly well +across a number of tasks. However, in settings where the data is homogeneous, +FedProx has been shown to under-perform compared to FedAvg when \\(\\mu>0\\). +See the top left of the figure below. + +Because of this, the authors of FedProx offer an alternative to extensive +hyper-parameter tuning. Heuristically, the proximal weight may be adapted +across server rounds. If the aggregated server-side training loss +(average final loss on each client) fails to decrease for a round, \\(\\mu\\) +is increased. If the loss improves for some number of rounds, \\(\\mu\\) is +decreased. This procedure results in the fuchsia-colored line in the figures +below. + +
+
+FedProx vs. FedAvg +FedProx vs. FedAvg +
Comparison of FedProx to FedAvg in various settings. On the top left, +data is homogeneous across clients. Without adaptation, FedProx struggles to +outperform FedAvg. Data is heterogeneous in the other settings and FedProx performs +well with and without adaptation. +
+
+
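The adaptation heuristic described above can be sketched as follows (the `patience`, `step`, and bounding choices here are illustrative; the paper's exact schedule may differ):

```python
def adapt_mu(mu, loss_history, patience=3, step=0.1, mu_max=1.0):
    """Adapt the proximal weight from the aggregated training losses.

    Increase mu if the loss failed to decrease this round; decrease it after
    `patience` consecutive improvements; otherwise leave it unchanged.
    """
    if len(loss_history) < 2:
        return mu
    if loss_history[-1] >= loss_history[-2]:  # loss did not improve
        return min(mu + step, mu_max)
    recent = loss_history[-(patience + 1):]
    if len(recent) == patience + 1 and all(b < a for a, b in zip(recent, recent[1:])):
        return max(mu - step, 0.0)            # sustained improvement
    return mu

assert adapt_mu(0.1, [1.00, 1.05]) == 0.2          # loss went up -> increase
assert adapt_mu(0.5, [1.0, 0.9, 0.8, 0.7]) == 0.4  # sustained decrease -> decrease
assert adapt_mu(0.5, [1.0, 0.9]) == 0.5            # too little history to decrease
```

The server would call such a function once per round, after aggregating the clients' final training losses.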
+ +#### References & Useful Links + +[^1]: + [T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith. + Federated optimization in heterogeneous networks. In I. Dhillon, + D. Papailiopoulos, and V. Sze, editors, Proceedings of Machine Learning and + Systems, volume 2, pages 429–450, 2020.](https://arxiv.org/pdf/1812.06127) + +[^2]: + [S. P. Karimireddy, S. Kale, M. Mohri, S. Reddi, S. Stich, and A. T. Suresh. + SCAFFOLD: Stochastic controlled averaging for federated learning. In + Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th + International Conference on Machine Learning, volume 119 of Proceedings of + Machine Learning Research, pages 5132–5143. PMLR, 13–18 Jul 2020.](https://proceedings.mlr.press/v119/karimireddy20a.html) + +[^3]: + [Images adapted from Talk on "Stochastic Controlled Averaging for + Federated Learning"](https://docs.google.com/presentation/d/1SYhRC6NMEMJJL2FTGDu5DJeINJBkbwFVGd2Emy43fk0/edit) + +[^4]: + [H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas. + Communication-efficient learning of deep networks from decentralized data. + Proceedings of the 20th AISTATS, 2017.](https://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf) + +{{#author emersodb}} diff --git a/books/fl/src/horizontal/robust_global_fl/moon.md b/books/fl/src/horizontal/robust_global_fl/moon.md index e3ceccf..9399740 100644 --- a/books/fl/src/horizontal/robust_global_fl/moon.md +++ b/books/fl/src/horizontal/robust_global_fl/moon.md @@ -1 +1,151 @@ -# MOON + + # MOON: Model-Contrastive Federated Learning + +{{ #aipr_header }} + +The MOON algorithm[^1] is built on the same principles as the +[FedProx](./fedprox.md)[^2] approach. That is, it targets limiting +client-specific drift during local training by constraining how far local +model updates stray from global models. The fundamental difference is the +way that drift is measured to construct the penalty function.
+ +## Contrastive Loss: A Brief Interlude + +Before defining the MOON penalty, we need to review contrastive loss and what +it aims to do in general. Say we have three vectors, \\(\\mathbf{z}\\), +\\(\\mathbf{z}\_s\\), \\(\\mathbf{z}\_d \in \mathbb{R}^n\\) and we want to +optimize a model, parameterized by \\(\\mathbf{w}\\), to map from an input, +\\(\\mathbf{x}\\), to a representation, \\(\\mathbf{z}\\), which is **closer** +to \\(\\mathbf{z}\_s\\) and **further** from \\(\\mathbf{z}\_d\\). We define + +$$ +\begin{align*} +\\ell\_{\\text{con}}(\mathbf{x};\\mathbf{w}) = - \\log \\frac{\exp +\\left(\\text{sim}(\\mathbf{z}, \\mathbf{z}\_s) \\tau^{-1} \\right)}{\\exp +\\left(\\text{sim}(\\mathbf{z}, \\mathbf{z}\_s) \\tau^{-1} \\right) + \\exp +\\left(\\text{sim}(\\mathbf{z}, \\mathbf{z}\_d) \\tau^{-1} \\right)}, +\end{align*} +$$ + +where \\(\\text{sim}(\\cdot, \\cdot)\\) is cosine-similarity and +\\(\tau > 0\\) is a temperature. Increasing the similarity of +\\(\\mathbf{z}\\) to \\(\\mathbf{z}\_s\\) pushes the ratio inside the logarithm +towards one, while pushing \\(\\mathbf{z}\\) away from \\(\\mathbf{z}\_d\\) +shrinks the second term of the denominator. Both changes make +\\(\\ell\_{\\text{con}}(\mathbf{x};\\mathbf{w})\\) **smaller**. + +Contrastive loss objectives have been widely used to bring latent +representations of similar inputs closer together and push dissimilar inputs +further apart. For example, contrastive learning is used extensively in +CLIP[^3] as a means of pushing image-caption text pairs closer together, while +pushing unrelated image-text pairs further apart in the CLIP model's +representation space. + +## MOON models and an alternative to weight-drift penalties + +As in FedProx, the major contribution of the MOON algorithm is to modify the +local learning objective for each client. As foreshadowed in the previous +section, the modification involves a contrastive loss function.
To define this
+loss function, MOON first splits the model being federated into two stages: a
+feature map followed by a classification module. Furthermore, at each server
+round, \\(t\\), it distinguishes between three models. These models are
+illustrated in the figure below.
+
+
+Moon models and their latent representation +
The three models important in the computation of MOON's +contrastive loss functions and their latent representations.
+
+
+
+For server round \\(t\\), the model on the left is the global model produced
+by server-side weight aggregation. In the middle is the model produced by the
+previous round of local training on Client \\(i\\); its weights,
+\\(\\mathbf{w}^{t-1}\_i\\), were aggregated across participating clients to
+form \\(\\mathbf{w}^t\\). Finally, the model on the right is the one currently
+being trained on Client \\(i\\). The feature-map outputs of these three models
+are used to form the local contrastive loss for Client \\(i\\).
+
+Let \\(\\ell_i((\\mathbf{x}, y); \\mathbf{w})\\) denote the local loss function
+of Client \\(i\\) for an input, \\(\mathbf{x}\\), and label, \\(y\\),
+parameterized by model weights \\(\\mathbf{w}\\). MOON augments this local loss
+with the contrastive loss function
+
+$$
+\begin{align*}
+\\ell\_{i, \\text{con}}((\\mathbf{x}, y);\\mathbf{w}^t\_i) = - \\log \\frac{\exp
+\\left(\\text{sim}(\\mathbf{z}, \\mathbf{z}\_{\\text{glob}}) \\tau^{-1}
+\\right)}{\\exp \\left(\\text{sim}(\\mathbf{z}, \\mathbf{z}\_{\\text{glob}})
+\\tau^{-1} \\right) + \\exp \\left(\\text{sim}(\\mathbf{z},
+\\mathbf{z}\_{\\text{prev}}) \\tau^{-1} \\right)},
+\end{align*}
+$$
+
+where \\(\\mathbf{z}\\), \\(\\mathbf{z}\_{\\text{glob}}\\), and
+\\(\\mathbf{z}\_{\\text{prev}}\\) are the feature-map outputs for
+\\(\\mathbf{x}\\) from the current local, global, and previous local models,
+respectively. That is, the local objective for Client \\(i\\) is written
+
+$$
+\begin{align*}
+\\ell_i((\\mathbf{x}, y); \\mathbf{w}^t\_i) +
+\mu \\ell\_{i, \\text{con}}((\\mathbf{x}, y);\\mathbf{w}^t\_i),
+\end{align*}
+$$
+
+for \\(\mu > 0\\). Note that, when training on a batch of data, this loss is
+averaged over all data points in the batch.
+
+The idea is similar to FedProx: during training, we still want to keep the
+local model from drifting too far from the global model. The difference is
+that the constraint is applied in the feature representation space rather
+than directly on the model weights. In the original work, MOON showed notable
+improvements over methods like FedProx in heterogeneous
+settings.
However, it does not **always** outperform FedProx or even FedAvg.[^4]
+As such, there are likely scenarios where MOON is the right approach for FL in
+heterogeneous settings, while others are better served by an alternative
+technique.
+
+## The Algorithm
+
+The MOON algorithm is quite similar to that of FedProx, and most server-side
+aggregation strategies may be combined with MOON. However, the algorithm
+incurs additional compute and memory overhead, as forward passes through
+three separate models are required to extract the latent representations of
+the data points in each training batch. In the algorithm below, FedAvg is
+used as the server-side strategy.
+
+Moon Algorithms +
+
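As a rough sketch of how a client might evaluate this combined objective over a batch, the snippet below assumes the feature-map outputs of the three models have already been computed for each data point (all names are our own illustration, not the paper's notation; a real implementation would run the three forward passes inside the training loop and backpropagate only through the current local model):

```python
import numpy as np

def cos_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two feature vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def moon_penalty(z: np.ndarray, z_glob: np.ndarray, z_prev: np.ndarray,
                 tau: float = 0.5) -> float:
    """Model-contrastive term: (z, z_glob) is the positive pair and
    (z, z_prev) the negative pair, so the penalty shrinks as the current
    representation moves toward the global model's representation."""
    pos = np.exp(cos_sim(z, z_glob) / tau)
    neg = np.exp(cos_sim(z, z_prev) / tau)
    return float(-np.log(pos / (pos + neg)))

def moon_batch_loss(task_losses, zs, zs_glob, zs_prev,
                    mu: float = 1.0, tau: float = 0.5) -> float:
    """Batch objective: mean over samples of task loss + mu * penalty."""
    per_sample = [
        l + mu * moon_penalty(z, zg, zp, tau)
        for l, z, zg, zp in zip(task_losses, zs, zs_glob, zs_prev)
    ]
    return float(np.mean(per_sample))
```

Because the penalty is always positive, a larger \\(\mu\\) trades task accuracy for tighter agreement with the global model's representations, analogous to the proximal weight in FedProx.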
+
+#### References & Useful Links
+
+[^1]:
+  [Q. Li, B. He, and D. Song. Model-contrastive federated learning. In Proceedings of the
+  IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2021.](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Model-Contrastive_Federated_Learning_CVPR_2021_paper.pdf)
+
+[^2]:
+  [T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith.
+  Federated optimization in heterogeneous networks. In I. Dhillon,
+  D. Papailiopoulos, and V. Sze, editors, Proceedings of Machine Learning and
+  Systems, volume 2, pages 429–450, 2020.](https://arxiv.org/pdf/1812.06127)
+
+[^3]:
+  [A. Radford, J. W. Kim, C. Hallacy, A. Ramesh, G. Goh, S. Agarwal,
+  G. Sastry, A. Askell, P. Mishkin, J. Clark, et al. Learning transferable visual
+  models from natural language supervision. In International Conference on Machine Learning,
+  pages 8748–8763. PMLR, 2021.](https://proceedings.mlr.press/v139/radford21a/radford21a.pdf)
+
+[^4]:
+  [F. Tavakoli, D. B. Emerson, S. Ayromlou, J. T. Jewell,
+  A. Krishnan, Y. Zhang, A. Verma, and F. Razak.
+  A comprehensive view of personalized federated learning on heterogeneous
+  clinical datasets. In K. Deshpande, M. Fiterau, S. Joshi,
+  Z. Lipton, R. Ranganath, and I. Urteaga, editors, Proceedings
+  of the 9th Machine Learning for Healthcare Conference, volume 252 of
+  Proceedings of Machine Learning Research. PMLR, 16–17 Aug 2024.](https://proceedings.mlr.press/v252/tavakoli24a.html)
+
+{{#author emersodb}}
diff --git a/books/fl/src/horizontal/vanilla_fl/fedavg.md b/books/fl/src/horizontal/vanilla_fl/fedavg.md
index 5b6fa56..440fb0d 100644
--- a/books/fl/src/horizontal/vanilla_fl/fedavg.md
+++ b/books/fl/src/horizontal/vanilla_fl/fedavg.md
@@ -92,7 +92,7 @@ model as described by the weights \\(\mathbf{w}\_T\\).
-FedAvg Algorithm +FedAvg Algorithm
diff --git a/books/fl/src/horizontal/vanilla_fl/fedsgd.md b/books/fl/src/horizontal/vanilla_fl/fedsgd.md index d740092..c33bb55 100644 --- a/books/fl/src/horizontal/vanilla_fl/fedsgd.md +++ b/books/fl/src/horizontal/vanilla_fl/fedsgd.md @@ -118,7 +118,7 @@ client receives the final model as described by the weights \\(\mathbf{w}\_T\\).
-FedSGD Algorithm +FedSGD Algorithm