
I spent way too much time during my bachelor thesis diving into federated learning, and it turned out to be way cooler than I expected. The basic idea is pretty simple: what if we could train machine learning models without actually collecting everyone’s data in one place? Turns out you can, and there are some clever ways to make it secure too. Here’s what I learned while building my own implementation and writing my thesis.
What’s Federated Learning?
Usually when you train a model, you grab data from everywhere, dump it in one database, and train on that. Federated learning flips this around. Instead of moving data to the model, you move the model to where the data lives. Each device keeps its own data and just sends back what it learned.
The math looks like this: normally you’d minimize some loss function across all your data:
$$\min_{w\in\mathbb{R}^d} F(w) = \frac{1}{N}\sum_{i=1}^N \ell(x_i, y_i; w)$$
But with federated learning, you split this across different clients:
$$F(w) = \sum_{k=1}^K \frac{n_k}{n_{total}} F_k(w)$$
where each client $k$ has its own local loss $F_k$ with $n_k$ data points.
The process works like this:
- Send the current model to all clients
- Each client trains on their local data for a bit
- Clients send back their updates (not their data)
- Server combines all the updates into a new global model
Pretty neat way to keep data private while still getting the benefits of training on lots of data.
Basic Aggregation Methods
FedSGD
The simplest approach is FedSGD. Everyone does one gradient step and sends their gradient back:
$$w^{t+1} = w^t - \eta \sum_{k=1}^K \frac{n_k}{n_{total}} \nabla F_k(w^t)$$
This works but requires a lot of communication since you’re sending gradients after every single step.
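On the server side this is just a weighted sum of the client gradients followed by a single descent step. A minimal sketch (function and parameter names are mine):

```python
import numpy as np

def fedsgd_step(global_weights, client_grads, client_sizes, lr=0.1):
    """One FedSGD step: weighted sum of client gradients, then a single SGD update."""
    coeffs = np.array(client_sizes) / sum(client_sizes)  # the n_k / n_total weights
    avg_grad = sum(c * g for c, g in zip(coeffs, client_grads))
    return global_weights - lr * avg_grad
```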
FedAvg
FedAvg is way more practical. Let each client train for several rounds locally, then just average their models:
$$w^{t+1} = \sum_{k=1}^K \frac{n_k}{n_{total}} w_k^t$$
This cuts down communication dramatically and usually works just as well, though it can struggle when different clients have very different data.
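The aggregation step itself is just one weighted average of the client models. A minimal sketch, again with illustrative names:

```python
import numpy as np

def fedavg(client_weights, client_sizes):
    """FedAvg aggregation: average the locally trained models, weighted by n_k / n_total."""
    coeffs = np.array(client_sizes) / sum(client_sizes)
    return sum(c * w for c, w in zip(coeffs, client_weights))
```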
Dealing with Bad Actors
One problem with federated learning is that some clients might send garbage updates, either by accident or on purpose. The simple averaging approach breaks down when you have outliers.
Geometric Median Aggregation
Instead of taking the arithmetic mean, you can use the geometric median, which is much more robust to outliers. You’re trying to find the point that minimizes the sum of distances to all client updates:
$$\min_z \sum_{k=1}^m \alpha_k ||w_k - z||$$
You solve this iteratively using something like the Weiszfeld algorithm:
$$z^{(i+1)} = \frac{\sum_{k=1}^m \beta_k^{(i)} w_k}{\sum_{k=1}^m \beta_k^{(i)}}, \text{ where } \beta_k^{(i)} = \frac{\alpha_k}{\max\{\nu, ||w_k - z^{(i)}||\}}$$
The math automatically gives less weight to updates that are far from the center, which helps filter out malicious or buggy clients.
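The whole thing is only a few lines of NumPy. Here’s a sketch of the Weiszfeld iteration, where each row of `updates` is one client’s flattened update and `eps` plays the role of $\nu$ (function and parameter names are mine):

```python
import numpy as np

def geometric_median(updates, weights=None, eps=1e-6, max_iters=100):
    """Weiszfeld iteration: robust aggregate of the client updates (rows of `updates`)."""
    updates = np.asarray(updates)
    m = len(updates)
    alphas = np.ones(m) / m if weights is None else np.asarray(weights)
    z = np.average(updates, axis=0, weights=alphas)  # start from the weighted mean
    for _ in range(max_iters):
        # Distances from the current estimate, floored at eps to avoid division by zero
        dists = np.maximum(np.linalg.norm(updates - z, axis=1), eps)
        betas = alphas / dists
        z_new = (betas[:, None] * updates).sum(axis=0) / betas.sum()
        if np.linalg.norm(z_new - z) < eps:
            break
        z = z_new
    return z
```

Outlier updates end up far from the current estimate, get a small $\beta_k$, and barely move the result.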
Adding Differential Privacy
Even if clients only send gradients, a sneaky server might still be able to figure out things about the training data by analyzing those gradients carefully. Differential privacy fixes this by adding carefully calibrated noise.
The Core Idea
Differential privacy says that if you change one person’s data in the dataset, the output shouldn’t change much. Formally, a mechanism is $(\epsilon,\delta)$-differentially private if:
$$\Pr(\mathcal{M}(D) \in S) \leq e^\epsilon \Pr(\mathcal{M}(D') \in S) + \delta$$
for any two datasets $D$ and $D'$ that differ by one record.
Making Federated Learning Private
In DP FedAvg, each client does two things before sending their gradient:
Clip the gradient to a maximum norm:
$$\tilde{g}_k = \frac{g_k}{\max(1, \frac{||g_k||_2}{C})}$$
Add Gaussian noise:
$$\hat{g}_k = \tilde{g}_k + \mathcal{N}(0, \sigma^2C^2\mathbf{I})$$
Here’s a rough sketch of how you might implement the client-side clipping and noising step (plain NumPy; the function name and parameters are just illustrative):
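```python
import numpy as np

def privatize_gradient(grad, clip_norm=1.0, noise_multiplier=1.0, rng=None):
    """Clip a gradient to L2 norm C, then add Gaussian noise with std sigma * C."""
    rng = rng or np.random.default_rng()
    # Clipping: shrink the gradient if its norm exceeds clip_norm (the C above)
    clipped = grad / max(1.0, np.linalg.norm(grad) / clip_norm)
    # Gaussian mechanism: noise scale is noise_multiplier (sigma) times clip_norm (C)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=grad.shape)
    return clipped + noise
```

The `noise_multiplier` is where the privacy/utility tradeoff lives: a larger $\sigma$ gives stronger guarantees but noisier updates.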
Now even if someone intercepts the gradients, they’re seeing a noisy version that doesn’t reveal much about individual data points.
Homomorphic Encryption
Differential privacy limits what you can infer from gradients, but homomorphic encryption goes further: the server never sees the actual gradients at all, only encrypted versions.
How It Works
Homomorphic encryption lets you do math on encrypted data. If you have encrypted values $Enc(a)$ and $Enc(b)$, you can compute:
$$Enc(a) \oplus Enc(b) = Enc(a + b)$$
without ever decrypting them.
For federated learning:
- Each client encrypts their update: $c_k = Enc(w_k)$
- Server adds the encrypted updates: $c_{sum} = \sum_{k=1}^K c_k$
- Someone with the private key decrypts the sum: $Dec(c_{sum}) = \sum_{k=1}^K w_k$
Here’s the basic idea in code, sketched against a made-up `AdditiveEncryptor` interface rather than any real crypto library:
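```python
from typing import Any, Protocol

class AdditiveEncryptor(Protocol):
    """Interface for an additively homomorphic scheme."""

    def encrypt(self, value: float) -> Any: ...
    def decrypt(self, ciphertext: Any) -> float: ...
    # add must satisfy decrypt(add(encrypt(a), encrypt(b))) == a + b
    def add(self, c1: Any, c2: Any) -> Any: ...

def aggregate_encrypted(client_updates: list[list[Any]], scheme: AdditiveEncryptor) -> list[Any]:
    """Server side: sum encrypted client updates coordinate-wise without ever decrypting them."""
    total = client_updates[0]
    for update in client_updates[1:]:
        total = [scheme.add(a, b) for a, b in zip(total, update)]
    return total
```

A real system would plug in an actual additively homomorphic cryptosystem such as Paillier, which supports exactly this kind of addition on ciphertexts.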
Putting It All Together
Here’s a sketch of how a complete federated learning round might look with differential privacy. The local training step is just a toy least-squares update standing in for a real optimizer, and each client is represented as an `(X, y)` pair of NumPy arrays:
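```python
import numpy as np

def local_train(weights, X, y, lr=0.1):
    """Stand-in for local training: one gradient step on a squared-error loss."""
    grad = 2 * X.T @ (X @ weights - y) / len(y)
    return weights - lr * grad

def dp_federated_round(global_weights, clients, clip_norm=1.0, noise_multiplier=1.0):
    """One FedAvg round where every client clips and noises its update before sending it."""
    rng = np.random.default_rng()
    updates, sizes = [], []
    for X, y in clients:
        # Client trains locally, starting from the current global model
        local_weights = local_train(global_weights, X, y)
        # Client clips its update and adds Gaussian noise (the DP step from above)
        delta = local_weights - global_weights
        delta = delta / max(1.0, np.linalg.norm(delta) / clip_norm)
        delta = delta + rng.normal(0.0, noise_multiplier * clip_norm, size=delta.shape)
        updates.append(delta)
        sizes.append(len(y))
    # Server: weighted average of the noisy updates (weights n_k / n_total)
    coeffs = np.array(sizes) / sum(sizes)
    return global_weights + sum(c * u for c, u in zip(coeffs, updates))
```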
Real World Challenges
A few things I learned while implementing this stuff:
Non-IID data is the biggest pain point. In the real world, different clients have totally different data distributions. Your phone’s photos look nothing like mine, which breaks a lot of the mathematical assumptions.
Client dropout happens constantly. Phones go offline, people close apps, connections fail. Your aggregation strategy needs to handle partial participation gracefully.
Privacy vs accuracy tradeoffs are real. Adding noise helps privacy but hurts model performance. Encrypting everything adds computational overhead. You’re constantly balancing security against practicality.
What I Learned
Building this federated learning system taught me that the theory is actually pretty straightforward, but the engineering challenges are where things get tricky. The math for differential privacy and homomorphic encryption looks intimidating, but the core ideas are simple once you get past the notation.
The most interesting part was seeing how all these techniques can work together. You can combine robust aggregation with differential privacy and homomorphic encryption to create systems that are resilient to both malicious attacks and curious servers.
If you want to dig deeper, check out my thesis and the code. The thesis goes into more detail about the performance tradeoffs and system design choices that matter in practice.