Data Scientist: Custom Aggregation Strategies
Learn how to implement specialized aggregation logic to handle complex federated learning requirements.
In a syft-flwr environment, the ServerApp manages global coordination by applying aggregation strategies to combine local model updates. This guide covers how to extend existing algorithms like FedAvg to meet specific research or production needs.
Custom strategies allow you to go beyond simple averaging. You can use them to save checkpoints, decay learning rates dynamically, or manage data that is statistically different (heterogeneous) across clients.
1. Extending FedAvg
The simplest way to build a custom strategy is to inherit from an existing one, such as FedAvg, so you only override the behavior you want to change.
- Method Overriding: You can override specific lifecycle methods like `aggregate_fit`, `configure_train`, or `evaluate` to insert custom code.
- Initialization: When creating your custom class, call `super().__init__(*args, **kwargs)` to retain the standard parameters of the base strategy.
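The pattern above can be sketched in plain Python. Note that `FedAvg` here is a minimal stand-in class, not the real Flower strategy, so the inheritance mechanics run anywhere; the subclass name and its `log_every` parameter are illustrative inventions.

```python
# Sketch of the subclassing pattern. `FedAvg` is a stub stand-in for the
# real base strategy so this example is self-contained.
class FedAvg:
    def __init__(self, fraction_train=1.0, min_available_nodes=2):
        self.fraction_train = fraction_train
        self.min_available_nodes = min_available_nodes

    def aggregate_fit(self, server_round, replies):
        # Placeholder for the base strategy's weighted averaging.
        return {"round": server_round, "num_replies": len(replies)}


class LoggingFedAvg(FedAvg):
    """Extends FedAvg: keeps base behavior, layers custom logic on top."""

    def __init__(self, *args, log_every=1, **kwargs):
        # Forward standard parameters so the base strategy stays configured.
        super().__init__(*args, **kwargs)
        self.log_every = log_every
        self.history = []

    def aggregate_fit(self, server_round, replies):
        result = super().aggregate_fit(server_round, replies)  # reuse base logic
        if server_round % self.log_every == 0:
            self.history.append(result)  # custom behavior added here
        return result


strategy = LoggingFedAvg(fraction_train=0.5, log_every=2)
strategy.aggregate_fit(1, replies=["a", "b"])
strategy.aggregate_fit(2, replies=["a", "b", "c"])
```

The key design point is that the override calls `super().aggregate_fit(...)` first, so the base averaging is preserved and the custom code wraps around it.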
2. Custom Aggregation Logic
By overriding `aggregate_fit`, you control exactly how client weights are combined.
- Accessing Replies: This method receives a list of `replies` from clients, each containing an `ArrayRecord` (model weights) and a `MetricRecord` (local metrics).
- Mathematical Operations: You can perform weighted averages, exclude outliers, or apply specialized algorithms like FedNH to improve local model generalization.
- Return Values: Your method must return the new aggregated `ArrayRecord` and any combined `MetricRecord` to be used in the next round.
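The core of the standard aggregation is a sample-count-weighted mean of each parameter. A minimal sketch with plain Python lists standing in for model tensors (a real strategy would read the arrays out of each reply's `ArrayRecord`):

```python
# Sample-count-weighted averaging: each client's update is scaled by its
# share of the total training examples. Flat lists stand in for tensors.
def weighted_average(updates):
    """updates: list of (weights, num_examples) pairs."""
    total_examples = sum(n for _, n in updates)
    num_params = len(updates[0][0])
    aggregated = [0.0] * num_params
    for weights, n in updates:
        factor = n / total_examples
        for i, w in enumerate(weights):
            aggregated[i] += factor * w
    return aggregated


# Client A trained on 100 examples, client B on 300: B counts 3x as much.
result = weighted_average([([1.0, 2.0], 100), ([5.0, 6.0], 300)])
# result == [4.0, 5.0]
```

Outlier exclusion or schemes like FedNH would filter or re-weight the `updates` list before this averaging step.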
3. Model Checkpointing with a FedAvgWithModelSaving Strategy
Saving the global model throughout the process is critical for robustness.
- Round-Based Saving: Implement logic within the `evaluate` method to save the model to disk at specific intervals (e.g., every 5 rounds).
- Metric-Based Saving: Trigger a checkpoint only when a new "best" global accuracy is achieved during evaluation.
- Format Compatibility: Convert the `ArrayRecord` into standard formats such as PyTorch `.pt` or TensorFlow `.keras` files before saving.
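Metric-based saving reduces to tracking the best accuracy seen so far and writing to disk only on improvement. A self-contained sketch (the `BestModelSaver` class is an illustrative invention; it serializes a plain dict to JSON where a real strategy would convert its `ArrayRecord` to `.pt` or `.keras`):

```python
import json
import os
import tempfile

# Metric-based checkpointing sketch: persist the global model only when
# evaluation accuracy improves on the best value seen so far.
class BestModelSaver:
    def __init__(self, save_dir):
        self.save_dir = save_dir
        self.best_accuracy = float("-inf")

    def maybe_save(self, server_round, weights, accuracy):
        """Return True if this round produced a new best model and was saved."""
        if accuracy <= self.best_accuracy:
            return False  # no improvement: skip the disk write
        self.best_accuracy = accuracy
        path = os.path.join(self.save_dir, f"round_{server_round}.json")
        with open(path, "w") as f:
            json.dump({"round": server_round, "weights": weights}, f)
        return True


saver = BestModelSaver(tempfile.mkdtemp())
saved_first = saver.maybe_save(1, [0.1, 0.2], accuracy=0.61)   # new best: saved
saved_second = saver.maybe_save(2, [0.3, 0.4], accuracy=0.58)  # worse: skipped
```

Round-based saving is the same pattern with `server_round % interval == 0` as the trigger instead of the accuracy comparison.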
4. Strategy Configuration & Callbacks
You can parameterize existing strategies without writing a new class by using callbacks.
- `evaluate_fn`: Pass a function to the strategy that evaluates the global model on a central server-side dataset after each round.
- `on_fit_config_fn`: Use this to dynamically adjust hyperparameters, such as decaying the learning rate as the experiment progresses.
5. Handling Heterogeneous Data
Real-world federated data is often non-IID (not Independent and Identically Distributed), meaning feature and label distributions differ from site to site.
- Weighted Schemes: Adjust the "weight" of each client's contribution during aggregation based on their sample size or data quality.
- Architectural Mitigations: Techniques like replacing Batch Normalization with Layer Normalization can help stabilize training across diverse client datasets.
- Proximal Terms: Use strategies like FedProx to add a penalty term to local training, preventing clients from drifting too far from the global model.
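The FedProx idea above can be written down directly: each client minimizes its task loss plus a proximal penalty (mu / 2) * ||w - w_global||^2 that pulls it back toward the current global model. A sketch with plain lists standing in for weight tensors:

```python
# FedProx-style local objective: task loss plus a proximal penalty that
# grows with the squared distance from the global model's weights.
def proximal_term(local_weights, global_weights, mu):
    sq_dist = sum((w - g) ** 2 for w, g in zip(local_weights, global_weights))
    return 0.5 * mu * sq_dist


def fedprox_loss(task_loss, local_weights, global_weights, mu=0.01):
    return task_loss + proximal_term(local_weights, global_weights, mu)


# A client matching the global model pays no penalty; one that drifted does.
close = fedprox_loss(0.5, [1.0, 1.0], [1.0, 1.0], mu=0.1)    # penalty 0
drifted = fedprox_loss(0.5, [3.0, 1.0], [1.0, 1.0], mu=0.1)  # penalty 0.2
```

Larger `mu` values constrain clients more tightly, which helps under strong heterogeneity at the cost of slower local adaptation.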
Next Step: Now that you've mastered custom logic, proceed to Monitoring FL Training to learn how to visualize these aggregated results in real-time.