Federated Learning

I spent a big chunk of my bachelor thesis on federated learning, and it ended up being more interesting than I expected going in. The basic idea is simple: what if you could train a machine learning model without ever collecting everyone’s data in one place? It turns out you can. There are also some clever techniques for keeping the protocol robust against bad actors and for layering statistical and cryptographic privacy guarantees on top. I wrote up what I learned while building my own implementation and writing my thesis.

What federated learning is

In ordinary supervised learning, you collect data from everyone, dump it in one database, and train on the combined dataset. Federated learning inverts that. Instead of moving data to the model, you move the model to where the data lives. Each device trains locally and sends back only what it learned, never the underlying data.

The objective decomposes nicely. A standard centralized loss is:

\[\min_{w \in \mathbb{R}^d} F(w) = \frac{1}{N} \sum_{i=1}^N \ell(x_{i}, y_{i}; w)\]

In federated learning the same loss splits across clients:

\[F(w) = \sum_{k=1}^K \frac{n_{k}}{n_{\text{total}}} F_{k}(w)\]

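Here each client \(k\) holds \(n_k\) data points, indexed by the set \(\mathcal{P}_k\). Its local loss is \(F_k(w) = \frac{1}{n_k} \sum_{i \in \mathcal{P}_k} \ell(x_i, y_i; w)\), and \(n_{\text{total}} = \sum_k n_k = N\). Plugging \(F_k\) in, the weight \(n_k / n_{\text{total}}\) cancels the \(1/n_k\), and you get back exactly the centralized loss.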
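
The cancellation is easy to check numerically. Below is a small sketch with a squared loss standing in for \(\ell\). The three clients and their sizes are made up for illustration. Pooling the data and computing one loss gives the same number as weighting each client's local loss by \(n_k / n_{\text{total}}\).

```python
import numpy as np

def squared_loss(X, y, w):
    """Mean squared error over the given points (a stand-in for ell)."""
    return float(np.mean((X @ w - y) ** 2))

rng = np.random.default_rng(0)
w = rng.normal(size=3)

# Three hypothetical clients holding different amounts of data (n_k = 5, 20, 75).
clients = [(rng.normal(size=(n, 3)), rng.normal(size=n)) for n in (5, 20, 75)]
n_total = sum(len(y) for _, y in clients)

# Centralized loss F(w): pool everything and average over all N points.
X_all = np.vstack([X for X, _ in clients])
y_all = np.concatenate([y for _, y in clients])
F_central = squared_loss(X_all, y_all, w)

# Federated form: weight each local loss F_k(w) by n_k / n_total.
F_federated = sum(len(y) / n_total * squared_loss(X, y, w) for X, y in clients)

print(F_central, F_federated)  # equal up to floating-point rounding
```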

The high-level protocol is:

  1. The server sends the current model to all clients.
  2. Each client trains locally on its own data for a bit.
  3. Clients send their updates back, never their data.
  4. The server aggregates the updates into a new global model.

This keeps the data on-device while still letting you reap most of the benefits of pooling everyone’s signal.

Aggregation methods

FedSGD

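The simplest approach is FedSGD. Each client computes one gradient of its local loss at the current model \(w^t\) and sends it back. The server then takes a single step along the weighted average of those gradients: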

\[w^{t+1} = w^t - \eta \sum_{k=1}^K \frac{n_{k}}{n_{\text{total}}} \nabla F_{k}(w^t)\]

It works, but it requires a round of communication for every single optimizer step, which is expensive over real networks.
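
On the server side, the update above is just a weighted average of gradients followed by one step. Here is a minimal sketch. The function name and argument layout are my own, and the client gradients are assumed to arrive as equally shaped NumPy arrays.

```python
import numpy as np

def fedsgd_step(w, client_grads, client_sizes, lr):
    """One FedSGD round: step along the n_k-weighted average of client gradients."""
    # np.average normalizes the weights, so this uses n_k / n_total.
    avg_grad = np.average(client_grads, axis=0, weights=client_sizes)
    return w - lr * avg_grad

w = np.zeros(2)
grads = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
w_new = fedsgd_step(w, grads, client_sizes=[1, 3], lr=0.1)
print(w_new)  # [-0.025 -0.075]
```

The client with three times as much data gets three times the weight, which is why the second coordinate moves further.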

FedAvg

FedAvg is much more practical. Each client starts from \(w^t\), takes several local gradient steps, and sends back its resulting model \(w_k^t\). The server then averages those models:

\[w^{t+1} = \sum_{k=1}^K \frac{n_{k}}{n_{\text{total}}} w_{k}^t\]

Communication drops dramatically, and in practice the convergence is competitive with FedSGD as long as the data isn’t wildly non-IID across clients.
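
A useful sanity check on the relationship between the two methods: if every client takes exactly one full-batch local step with the same learning rate, FedAvg reduces to FedSGD. This holds because the weights sum to one, so \(\sum_k \frac{n_k}{n_{\text{total}}}(w^t - \eta g_k) = w^t - \eta \sum_k \frac{n_k}{n_{\text{total}}} g_k\).

The sketch below checks this on a linear least-squares model. The model, data and learning rate are all illustrative assumptions.

```python
import numpy as np

def grad_mse(w, X, y):
    # Gradient of (1/2n) * ||Xw - y||^2 for a linear model.
    return X.T @ (X @ w - y) / len(y)

def local_update(w, X, y, lr, steps):
    """Client side: `steps` full-batch gradient steps starting from the global model."""
    w = w.copy()
    for _ in range(steps):
        w -= lr * grad_mse(w, X, y)
    return w

def fedavg_round(w_global, clients, lr, local_steps):
    """Server side: n_k-weighted average of the clients' locally trained models."""
    models = [local_update(w_global, X, y, lr, local_steps) for X, y in clients]
    sizes = [len(y) for _, y in clients]
    return np.average(models, axis=0, weights=sizes)

rng = np.random.default_rng(1)
clients = [(rng.normal(size=(n, 2)), rng.normal(size=n)) for n in (10, 30)]
w0 = np.zeros(2)
sizes = [len(y) for _, y in clients]

# One FedSGD step versus FedAvg with a single local step per client.
fedsgd = w0 - 0.1 * np.average([grad_mse(w0, X, y) for X, y in clients], axis=0, weights=sizes)
fedavg_1 = fedavg_round(w0, clients, lr=0.1, local_steps=1)
print(np.allclose(fedsgd, fedavg_1))  # True
```

Raising `local_steps` is where FedAvg saves communication: more optimizer steps per round, at the cost of the equivalence no longer holding exactly.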

Robust aggregation

One issue with naive averaging is that some clients might send garbage updates, either because of buggy software or because someone is trying to poison the model. An arithmetic mean has no defense here: a single client can drag it arbitrarily far just by sending a large enough update.

Geometric median

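The geometric median is much more robust. As long as the corrupted updates carry less than half of the total weight, they cannot drag it arbitrarily far. It is the point \(z\) that minimizes the weighted sum of Euclidean distances to the \(m\) client updates \(w_k\), with nonnegative weights \(\alpha_k\) (for example \(n_k / n_{\text{total}}\)):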

\[\min_{z} \sum_{k=1}^m \alpha_{k} \|w_{k} - z\|\]

There is no closed form in general, so it is computed iteratively with a smoothed version of the Weiszfeld algorithm:

\[z^{(i+1)} = \frac{\sum_{k=1}^m \beta_{k}^{(i)} w_{k}}{\sum_{k=1}^m \beta_{k}^{(i)}}, \quad \beta_{k}^{(i)} = \frac{\alpha_{k}}{\max\{\nu, \|w_{k} - z^{(i)}\|\}}\]

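Here \(\nu > 0\) is a small smoothing constant that keeps \(\beta_k^{(i)}\) finite when \(z^{(i)}\) lands exactly on some \(w_k\). Each iteration is a weighted average in which an update's weight is inversely proportional to its distance from the current estimate, so updates far from the center count for less. That is exactly the behavior you want when malicious or buggy clients are sending things from the tails of the distribution.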
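
Here is a minimal NumPy sketch of that smoothed Weiszfeld iteration, with one update per row. The default `nu`, the iteration cap, the stopping tolerance and the choice to start from the weighted mean are my assumptions, not values from the thesis.

```python
import numpy as np

def geometric_median(updates, alphas=None, nu=1e-8, max_iter=100, tol=1e-7):
    """Smoothed Weiszfeld iteration for the weighted geometric median of the rows of `updates`."""
    W = np.asarray(updates, dtype=float)               # shape (m, d): one client update per row
    alphas = np.ones(len(W)) if alphas is None else np.asarray(alphas, dtype=float)
    z = np.average(W, axis=0, weights=alphas)          # start from the weighted mean
    for _ in range(max_iter):
        dists = np.linalg.norm(W - z, axis=1)
        betas = alphas / np.maximum(nu, dists)         # far-away updates get small weights
        z_new = betas @ W / betas.sum()
        if np.linalg.norm(z_new - z) <= tol:
            return z_new
        z = z_new
    return z

rng = np.random.default_rng(2)
honest = rng.normal(loc=1.0, scale=0.1, size=(9, 3))   # nine well-behaved clients near 1.0
poisoned = np.full((1, 3), 100.0)                       # one client sending garbage
updates = np.vstack([honest, poisoned])

print("mean:  ", updates.mean(axis=0))        # pulled to roughly 10.9 per coordinate
print("median:", geometric_median(updates))   # stays inside the honest cluster
```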

Differential privacy

Even if clients only send gradients, a clever server might still reverse-engineer something about individual training points by analyzing those gradients carefully. Differential privacy is the standard fix.

The definition

A mechanism \(\mathcal{M}\) is \((\epsilon, \delta)\)-differentially private if, for any two datasets \(D\) and \(D'\) that differ by one record and for any set of outputs \(S\):

\[\Pr(\mathcal{M}(D) \in S) \leq e^\epsilon \Pr(\mathcal{M}(D') \in S) + \delta\]

In words: changing any single person’s data shouldn’t change the distribution of outputs by more than a small, controllable amount. The smaller \(\epsilon\) is, the stronger the privacy guarantee, and the more noise you have to add to satisfy it.
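
To make "more noise for smaller \(\epsilon\)" concrete, the sketch below uses the classic textbook calibration of the Gaussian mechanism. For \(0 < \epsilon < 1\), adding Gaussian noise with standard deviation \(\sqrt{2 \ln(1.25/\delta)}\,\Delta / \epsilon\) gives \((\epsilon, \delta)\)-DP. Here \(\Delta\) is the L2 sensitivity, meaning the most one record can change the output. Strictly, the theorem needs the constant slightly above \(\sqrt{2 \ln(1.25/\delta)}\), so treat the result as a lower bound.

This bound covers a single release only. Over many training rounds the privacy loss composes, and in practice a tighter privacy accountant is used.

```python
import math

def gaussian_noise_std(epsilon, delta, sensitivity):
    """Classic single-release Gaussian-mechanism bound (only valid for 0 < epsilon < 1)."""
    if not 0 < epsilon < 1:
        raise ValueError("this bound only holds for 0 < epsilon < 1")
    return math.sqrt(2 * math.log(1.25 / delta)) * sensitivity / epsilon

# Halving epsilon doubles the required noise.
for eps in (0.9, 0.5, 0.1):
    print(eps, gaussian_noise_std(eps, delta=1e-5, sensitivity=1.0))
```

With clipping to norm \(C\), the sensitivity is \(C\). The ratio of this standard deviation to \(C\) then plays the role of the noise multiplier \(\sigma\) in the next section.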

Adding DP to FedAvg

In DP-FedAvg, each client does two things to its gradient before sending:

  1. Clip the gradient to an L2 norm of at most \(C\): \[\tilde{g}_{k} = \frac{g_{k}}{\max(1, \|g_{k}\|_{2} / C)}\]

  2. Add Gaussian noise: \[\hat{g}_k = \tilde{g}_k + \mathcal{N}(0, \sigma^2 C^2 \mathbf{I})\]

The implementation is short:

```python
import numpy as np

def dp_clip_and_noise(grad: np.ndarray, clip_norm: float = 1.0, noise_multiplier: float = 1.0) -> np.ndarray:
    # Step 1: rescale so the L2 norm is at most clip_norm (same as g / max(1, ||g|| / C)).
    norm = np.linalg.norm(grad)
    if norm > clip_norm:
        grad = (grad / norm) * clip_norm
    # Step 2: add isotropic Gaussian noise with standard deviation sigma * C.
    noise_std = noise_multiplier * clip_norm
    noise = np.random.normal(loc=0.0, scale=noise_std, size=grad.shape)
    return grad + noise
```

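Clipping bounds how far any single update can move the aggregate, and the Gaussian noise masks movements of that size. So even someone who intercepts the update learns only a limited amount about any individual training point. How limited (the actual \(\epsilon\)) depends on \(\sigma\), on \(\delta\), and on how many rounds you run.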

Homomorphic encryption

Differential privacy bounds what you can infer from gradients. Homomorphic encryption goes further: the server never sees the actual gradients at all, only ciphertexts.

How it works

Homomorphic encryption lets you do math directly on encrypted values. If you have ciphertexts \(\text{Enc}(a)\) and \(\text{Enc}(b)\), the scheme defines an operation \(\oplus\) such that:

\[\text{Enc}(a) \oplus \text{Enc}(b) = \text{Enc}(a + b)\]
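
Paillier is one concrete additively homomorphic scheme. In Paillier, \(\oplus\) is simply multiplication of ciphertexts modulo \(n^2\).

The sketch below is a textbook toy version with tiny hard-coded primes. It is insecure and for illustration only. It is not the backend I used, and it needs Python 3.9+ for `math.lcm`.

```python
import math
import random

class ToyPaillier:
    """Textbook Paillier with tiny hard-coded primes. Insecure, for illustration only."""

    def __init__(self, p=1789, q=1861):
        self.n = p * q
        self.n2 = self.n * self.n
        self.g = self.n + 1                      # common simple choice of generator
        self.lam = math.lcm(p - 1, q - 1)        # Carmichael function of n
        self.mu = pow(self.lam, -1, self.n)      # valid shortcut because g = n + 1

    def encrypt(self, m):
        # Plaintexts live in [0, n); r is fresh randomness coprime to n.
        while True:
            r = random.randrange(1, self.n)
            if math.gcd(r, self.n) == 1:
                break
        return (pow(self.g, m, self.n2) * pow(r, self.n, self.n2)) % self.n2

    def decrypt(self, c):
        x = pow(c, self.lam, self.n2)
        return ((x - 1) // self.n * self.mu) % self.n

    def add(self, c1, c2):
        # Multiplying ciphertexts adds plaintexts: Enc(a) * Enc(b) = Enc(a + b mod n).
        return (c1 * c2) % self.n2

he = ToyPaillier()
c_sum = he.add(he.encrypt(17), he.encrypt(25))
print(he.decrypt(c_sum))  # 42
```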

For federated learning the workflow becomes:

  1. Each client encrypts its update: \(c_k = \text{Enc}(w_k)\).
  2. The server combines the ciphertexts with the homomorphic operation: \(c_{\text{sum}} = c_1 \oplus c_2 \oplus \cdots \oplus c_K\).
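  3. Whoever holds the decryption key decrypts \(c_{\text{sum}}\) to recover \(\sum_k w_k\) and divides by \(K\) to get the average. For the \(n_k\)-weighted version, clients encrypt \(n_k w_k\) instead, and the key holder divides by \(n_{\text{total}}\).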
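
One practical wrinkle with this workflow: Paillier-style schemes operate on integers modulo \(n\), while model updates are floats. A common fix is fixed-point encoding, in which you scale, round, and let negative values wrap around modulo \(n\).

The sketch below simulates what the key holder gets back after decrypting the homomorphic sum. That result is the element-wise sum of the encoded updates mod \(n\). `MODULUS` and `SCALE` are hypothetical values. The scheme also assumes the true sums stay well below \(n/2\) in magnitude, so nothing overflows.

```python
import numpy as np

MODULUS = 1789 * 1861   # hypothetical plaintext modulus n; a real key would be far larger
SCALE = 1000            # hypothetical fixed-point scale: keeps three decimal places

def encode(update):
    """Map a float vector to integers mod n (negative values wrap around)."""
    return [int(round(v * SCALE)) % MODULUS for v in update]

def decode_average(encoded_sum, num_clients):
    """Undo the wrap-around, rescale, and divide by K to get the average update."""
    half = MODULUS // 2
    ints = [x - MODULUS if x > half else x for x in encoded_sum]
    return np.array(ints, dtype=float) / SCALE / num_clients

updates = [np.array([0.5, -1.25]), np.array([0.25, 0.75]), np.array([-0.125, 0.5])]
# What decrypting the homomorphic sum yields: element-wise sums mod n.
encoded_sum = [sum(col) % MODULUS for col in zip(*(encode(u) for u in updates))]
print(decode_average(encoded_sum, len(updates)))  # the plain average: [0.625 / 3, 0.0]
```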

In code, with a hypothetical HE backend:

```python
from typing import Any, Protocol

class HEContext(Protocol):
    """Interface of a hypothetical additively homomorphic backend."""
    def encrypt_zeros(self) -> Any: ...
    def homomorphic_add(self, a: Any, b: Any) -> Any: ...
    def decrypt(self, ciphertext: Any) -> Any: ...

def homomorphic_aggregate(encrypted_updates: list[Any], he_context: HEContext) -> Any:
    # Server side: fold the ciphertexts together without ever decrypting them.
    c_sum = he_context.encrypt_zeros()
    for c_u in encrypted_updates:
        c_sum = he_context.homomorphic_add(c_sum, c_u)
    return c_sum

def decrypt_sum(c_sum: Any, he_context: HEContext) -> Any:
    # Key-holder side: recover the plaintext sum of all client updates.
    return he_context.decrypt(c_sum)
```

Putting it all together

A complete federated round with differential privacy looks roughly like this:

```python
import numpy as np

# Reuses dp_clip_and_noise from the differential privacy snippet above.

def compute_gradient(w, data):
    # Gradient of (1/2) * mean squared error for a linear model X @ w.
    X, y = data
    preds = X @ w
    errs = preds - y
    return (X.T @ errs) / len(X)

def local_train(model_params, local_data, epochs: int = 1, lr: float = 0.01):
    # Client side: a few full-batch gradient steps starting from the global model.
    w = np.copy(model_params)
    for _ in range(epochs):
        grad = compute_gradient(w, local_data)
        w = w - lr * grad
    return w

def federated_round(global_params, clients_data, clip_norm: float, noise_multiplier: float):
    updated_params, sizes = [], []
    for local_data in clients_data:
        w_k = local_train(global_params, local_data)
        # The model delta plays the role of the "gradient" that gets clipped and noised.
        delta_k = w_k - global_params
        delta_k_dp = dp_clip_and_noise(delta_k, clip_norm, noise_multiplier)
        updated_params.append(global_params + delta_k_dp)
        sizes.append(len(local_data[1]))
    # FedAvg: weight each client's model by n_k / n_total.
    return np.average(updated_params, axis=0, weights=sizes)

if __name__ == "__main__":
    w_global = np.zeros(2)

    clients_data = [
        (np.array([[1, 2], [0, 1]]), np.array([1.0, 0.0])),
        (np.array([[3, 2], [4, 1]]), np.array([2.0, 3.0])),
        (np.array([[10, 2], [8, 2]]), np.array([4.0, 5.0])),
    ]

    for t in range(5):
        w_global = federated_round(w_global, clients_data, clip_norm=2.0, noise_multiplier=0.5)
        print(f"Round {t}, global params = {w_global}")
```

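You can layer robust aggregation, differential privacy, and homomorphic encryption on top of one another. Robust aggregation and HE don't compose for free, though. The Weiszfeld iteration needs the distances between individual updates, and an additively homomorphic server can't compute those from ciphertexts alone. Each technique defends against a different threat model, and in a real deployment you usually want at least the first two.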

What broke when I built this

A few things hit me harder than I expected during implementation.

Non-IID data was the biggest pain point. Real clients don’t sample from the same distribution. Your phone’s photos look nothing like mine, and that asymmetry breaks the assumptions that make FedAvg behave nicely on academic benchmarks. With several local steps, each client's model drifts toward its own data, so averaging the models no longer approximates training on the pooled dataset.

Client dropout happens constantly. Phones go offline, people close apps, batteries die. Your aggregation strategy has to handle partial participation gracefully, since “wait for everyone” is not a strategy that survives contact with reality.
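
Here is one way to structure a round that tolerates partial participation. Sample a subset of clients, aggregate whatever comes back, and skip the round if too few clients report. This is a minimal sketch, not my actual implementation.

A few details are assumptions:

- `train_fn` is a hypothetical client call that returns `(model, n_k)`, or `None` if the client dropped out.
- `min_reports` is an illustrative quorum.
- Weights are renormalized over the survivors only.

```python
import random
import numpy as np

def run_round(global_model, client_ids, sample_fraction, train_fn, min_reports=2, rng=random):
    """One round with client sampling and dropout tolerance (a sketch)."""
    k = max(1, int(sample_fraction * len(client_ids)))
    selected = rng.sample(client_ids, k)
    reports = []
    for cid in selected:
        result = train_fn(cid, global_model)   # hypothetical: (model, n_k) or None on dropout
        if result is not None:
            reports.append(result)
    if len(reports) < min_reports:
        return global_model                    # too few survivors: keep the old model
    models, sizes = zip(*reports)
    return np.average(models, axis=0, weights=sizes)  # FedAvg over survivors only

def fake_train(cid, model):
    # Hypothetical client: even ids "drop out", odd ids return (model + cid, n_k = 1).
    return None if cid % 2 == 0 else (model + cid, 1)

new_model = run_round(np.zeros(2), [0, 1, 2, 3, 4], sample_fraction=1.0, train_fn=fake_train)
print(new_model)  # average of the two survivors' models: [2. 2.]
```

If dropout correlates with the data (say, only phones on Wi-Fi and charging report back), the survivors are a biased sample. Renormalizing the weights does not fix that bias.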

Privacy and accuracy trade off against each other. Adding noise improves privacy and hurts model performance. Encrypting everything adds computational and bandwidth overhead. You spend a lot of time tuning these knobs.
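
One reason the noise side of this tradeoff gets easier at scale: suppose each client adds independent noise with standard deviation \(\sigma C\), as in `dp_clip_and_noise`. Then the averaged update carries noise with per-coordinate standard deviation \(\sigma C / \sqrt{K}\), because the variance of a mean of \(K\) independent draws is \(\sigma^2 C^2 / K\).

The sketch below checks that empirically. The dimension and seed are arbitrary choices.

```python
import numpy as np

def noise_in_average(num_clients, dim, clip_norm, noise_multiplier, seed=0):
    """Empirical per-coordinate std of the DP noise after averaging over num_clients."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=(num_clients, dim))
    return noise.mean(axis=0).std()

for K in (10, 100, 1000):
    empirical = noise_in_average(K, dim=2000, clip_norm=1.0, noise_multiplier=1.0)
    predicted = 1.0 * 1.0 / np.sqrt(K)   # sigma * C / sqrt(K)
    print(K, round(empirical, 4), round(predicted, 4))
```

This is why DP federated training tends to work much better with many participating clients than with a handful.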

What stuck with me

The theory of federated learning is straightforward. The engineering is where it gets messy. Differential privacy and homomorphic encryption sound intimidating with all the Greek letters, but once you strip away the notation, the core ideas are not that complicated.

What I found most interesting was watching the techniques compose. Robust aggregation handles malicious clients. Differential privacy bounds what a curious server can infer. Homomorphic encryption keeps the server from seeing raw updates at all. Stack the three and you have a system that’s defensible against most realistic threat models, at the cost of a lot more engineering work than the centralized version would have needed.

If you want more detail, my thesis goes deeper, and the code shows what I actually built.