
Data Scientist: Custom Aggregation Strategies

Learn how to implement specialized aggregation logic to handle complex federated learning requirements.

In a syft-flwr environment, the ServerApp manages global coordination by applying aggregation strategies to combine local model updates. This guide covers how to extend existing algorithms like FedAvg to meet specific research or production needs.

Custom strategies allow you to go beyond simple averaging. You can use them to save checkpoints, decay learning rates dynamically, or manage data that is statistically different (heterogeneous) across clients.

1. Extending FedAvg

The simplest way to build a custom strategy is to inherit from an existing one, such as FedAvg, and change only the behavior you need.

  • Method Overriding: Override specific lifecycle methods such as configure_fit, aggregate_fit, or evaluate to insert custom code. In Flower's newer Message API, the training hooks are named configure_train and aggregate_train.
  • Initialization: In your subclass's __init__, call super().__init__(*args, **kwargs) so the base strategy keeps its standard parameters (client fractions, minimum client counts, callbacks).
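The inheritance pattern can be sketched without a Flower installation. In this minimal sketch, `BaseFedAvg`, its `aggregate` method, and `LoggingFedAvg` are hypothetical stand-ins, not Flower classes; the constructor parameter names mirror Flower's legacy FedAvg only for familiarity. The point is the shape: forward `*args, **kwargs` to the parent, add your own state, and wrap one inherited hook instead of rewriting it.

```python
from typing import Any


class BaseFedAvg:
    """Hypothetical stand-in for a library FedAvg strategy (not Flower's class)."""

    def __init__(self, fraction_fit: float = 1.0, min_fit_clients: int = 2) -> None:
        self.fraction_fit = fraction_fit
        self.min_fit_clients = min_fit_clients

    def aggregate(self, server_round: int, updates: list[tuple[list[float], int]]) -> list[float]:
        """Average model vectors, weighted by each client's example count."""
        total = sum(n for _, n in updates)
        dim = len(updates[0][0])
        return [sum(w[i] * n for w, n in updates) / total for i in range(dim)]


class LoggingFedAvg(BaseFedAvg):
    """Adds one custom argument and wraps one hook; everything else is inherited."""

    def __init__(self, *args: Any, log_prefix: str = "[round]", **kwargs: Any) -> None:
        # Forward the base strategy's parameters unchanged, then add our own state.
        super().__init__(*args, **kwargs)
        self.log_prefix = log_prefix
        self.history: list[str] = []

    def aggregate(self, server_round, updates):
        # Custom code runs around the inherited aggregation rather than replacing it.
        result = super().aggregate(server_round, updates)
        self.history.append(f"{self.log_prefix} {server_round}: {len(updates)} updates")
        return result
```

For example, `LoggingFedAvg(0.5, min_fit_clients=3).aggregate(1, [([1.0, 2.0], 1), ([3.0, 4.0], 3)])` returns `[2.5, 3.5]` and records one history entry, while `fraction_fit` and `min_fit_clients` are still set by the parent class.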

2. Custom Aggregation Logic

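By overriding aggregate_fit (named aggregate_train in Flower's newer Message API), you control exactly how client model updates are combined into the next global model.

  • Accessing Replies: In the Message API, this method receives the replies from the clients that trained in the current round, each containing an ArrayRecord (model weights) and a MetricRecord (local metrics, typically including the number of training examples). The legacy aggregate_fit instead receives a list of (client, FitRes) results and a list of failures.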
  • Mathematical Operations: You can perform weighted averages, exclude outliers, or apply specialized algorithms like FedNH to improve local model generalization.
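  • Return Values: Your method must return a tuple: the aggregated ArrayRecord, which becomes the global model for the next round, and an aggregated MetricRecord (either may be None). The legacy aggregate_fit returns Parameters and a metrics dictionary instead.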
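The core of a custom aggregation step can be sketched independently of the framework. This is a minimal sketch under stated assumptions: each reply is modeled as a plain dict with hypothetical keys `arrays`, `num_examples`, and `metrics` (standing in for the ArrayRecord and MetricRecord of a real reply), and outliers are excluded with a simple L2-distance threshold, which is just one possible rule.

```python
import math
from typing import Any, Optional

# Hypothetical reply shape, loosely mirroring one client reply:
# {"arrays": [model weights], "num_examples": int, "metrics": {"accuracy": float}}
Reply = dict[str, Any]


def l2_distance(a: list[float], b: list[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def aggregate_replies(
    replies: list[Reply],
    global_arrays: list[float],
    max_update_distance: Optional[float] = None,
) -> tuple[Optional[list[float]], Optional[dict[str, float]]]:
    """Example-weighted average that can drop updates far from the current global model."""
    # Outlier exclusion: skip clients whose weights moved further than the threshold.
    kept = [
        r for r in replies
        if max_update_distance is None
        or l2_distance(r["arrays"], global_arrays) <= max_update_distance
    ]
    total = sum(r["num_examples"] for r in kept)
    if total == 0:
        # Nothing usable this round: signal "no new model" to the caller.
        return None, None
    arrays = [
        sum(r["arrays"][i] * r["num_examples"] for r in kept) / total
        for i in range(len(global_arrays))
    ]
    # Weight metrics the same way; assumes every reply reports the same metric names.
    metrics = {
        name: sum(r["metrics"][name] * r["num_examples"] for r in kept) / total
        for name in kept[0]["metrics"]
    }
    metrics["num_clients_kept"] = float(len(kept))
    return arrays, metrics
```

With two honest clients at `[1, 1]` (10 examples) and `[3, 3]` (30 examples) plus one outlier at `[100, 100]`, a threshold of 10 keeps only the first two and yields `[2.5, 2.5]`; without the threshold, the outlier drags the average to `[22.0, 22.0]`. Returning a `(weights, metrics)` pair mirrors the tuple the strategy hook is expected to return.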

3. Model Checkpointing with a FedAvgWithModelSaving Strategy

Saving the global model periodically during training lets you resume after an interruption and keep the best-performing model rather than only the last one.

  • Round-Based Saving: Implement logic within the evaluate method to save the model to disk at specific intervals (e.g., every 5 rounds).
  • Metric-Based Saving: Trigger a checkpoint only when a new "best" global accuracy is achieved during evaluation.
  • Format Compatibility: Convert the ArrayRecord into standard formats like PyTorch .pt or TensorFlow .keras before saving.
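Round-based and metric-based saving can share one small helper that the strategy's evaluation hook calls once per round. This is a hypothetical sketch: `CheckpointManager` and `on_evaluate` are illustrative names, and the weights are written as JSON so the example stays dependency-free; a real strategy would first convert the ArrayRecord and write a framework format such as a PyTorch .pt file.

```python
import json
from pathlib import Path
from typing import Optional


class CheckpointManager:
    """Saves the global model every `save_every` rounds and whenever accuracy improves."""

    def __init__(self, save_dir: Path, save_every: int = 5) -> None:
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.save_every = save_every
        self.best_accuracy: Optional[float] = None

    def _write(self, name: str, server_round: int, arrays: list[float]) -> Path:
        # JSON stands in for a framework-specific checkpoint format.
        path = self.save_dir / name
        path.write_text(json.dumps({"round": server_round, "arrays": arrays}))
        return path

    def on_evaluate(self, server_round: int, arrays: list[float], accuracy: float) -> list[Path]:
        """Call once per round from the evaluation hook; returns the files written."""
        saved: list[Path] = []
        # Round-based: periodic snapshot, e.g. rounds 5, 10, 15 with save_every=5.
        if server_round % self.save_every == 0:
            saved.append(self._write(f"round_{server_round}.json", server_round, arrays))
        # Metric-based: overwrite the "best" checkpoint only on a new best accuracy.
        if self.best_accuracy is None or accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            saved.append(self._write("best.json", server_round, arrays))
        return saved
```

Keeping a single `best.json` that is overwritten on improvement bounds disk usage, while the periodic `round_N.json` files preserve a training history you can resume from.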

4. Strategy Configuration & Callbacks

You can parameterize existing strategies without writing a new class by using callbacks.

  • evaluate_fn: Pass a function to the strategy that evaluates the global model on a central server-side dataset after each round.
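  • on_fit_config_fn: Pass a function that takes the current round number and returns the configuration sent to clients for training; use it to adjust hyperparameters dynamically, such as decaying the learning rate as the experiment progresses.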
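Both callbacks are plain functions, so they can be written and tested on their own. In this sketch the decay schedule (0.1 × 0.9^(round − 1)), the config keys, and the tiny server-side dataset for a one-parameter linear model are all illustrative assumptions; `parameters` is a list of floats here rather than the list of NumPy arrays Flower would pass.

```python
from typing import Optional, Union

Scalar = Union[int, float]


def fit_config(server_round: int) -> dict[str, Scalar]:
    """Per-round training config sent to clients, with an exponentially decaying learning rate."""
    # Illustrative schedule: lr = 0.1 * 0.9 ** (round - 1).
    return {"learning_rate": 0.1 * 0.9 ** (server_round - 1), "local_epochs": 1}


# Hypothetical server-side hold-out set for a 1-D linear model y = w * x.
SERVER_X = [1.0, 2.0, 3.0]
SERVER_Y = [2.0, 4.0, 6.0]


def evaluate(
    server_round: int, parameters: list[float], config: dict[str, Scalar]
) -> Optional[tuple[float, dict[str, Scalar]]]:
    """Centralized evaluation of the current global model; returns (loss, metrics)."""
    w = parameters[0]
    squared_errors = [(w * x - y) ** 2 for x, y in zip(SERVER_X, SERVER_Y)]
    mse = sum(squared_errors) / len(squared_errors)
    return mse, {"mse": mse, "server_round": server_round}
```

With Flower's legacy `flwr.server.strategy.FedAvg`, functions of these shapes are passed as `FedAvg(on_fit_config_fn=fit_config, evaluate_fn=evaluate)`, which customizes the strategy without subclassing it.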

5. Handling Heterogeneous Data

Real-world federated data is often non-IID (not independent and identically distributed): label, feature, or quantity distributions differ from site to site.

  • Weighted Schemes: Adjust the "weight" of each client's contribution during aggregation based on their sample size or data quality.
  • Architectural Mitigations: Techniques like replacing Batch Normalization with Layer Normalization can help stabilize training across diverse client datasets.
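  • Proximal Terms: Use strategies like FedProx, which add a penalty (mu/2)·||w - w_global||² to each client's local objective so that local models cannot drift too far from the global model. In Flower, the FedProx strategy only sends the coefficient (proximal_mu) to clients; the client training code must add the term to its loss.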
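The client-side effect of a proximal term is easy to see on a toy problem. This sketch assumes a hypothetical local loss 0.5·||w − c||², whose minimum `c` stands for a client's own optimum under non-IID data; `proximal_term` and `train_locally` are illustrative names, not library functions.

```python
def proximal_term(w: list[float], w_global: list[float], mu: float) -> float:
    """FedProx penalty (mu / 2) * ||w - w_global||^2, added to the client's local loss."""
    return 0.5 * mu * sum((a - b) ** 2 for a, b in zip(w, w_global))


def train_locally(
    w_global: list[float], local_optimum: list[float], mu: float, steps: int = 200, lr: float = 0.1
) -> list[float]:
    """Gradient descent on an assumed local loss 0.5 * ||w - c||^2 plus the proximal term."""
    w = list(w_global)
    for _ in range(steps):
        # Local gradient (w - c) plus proximal gradient mu * (w - w_global).
        w = [
            wi - lr * ((wi - ci) + mu * (wi - gi))
            for wi, ci, gi in zip(w, local_optimum, w_global)
        ]
    return w
```

Starting from a global model at 0 with a local optimum at 4, `mu=0.0` lets the client drift all the way to 4.0, while `mu=1.0` settles at (c + mu·w_global) / (1 + mu) = 2.0, halfway back toward the global model. Larger `mu` values trade local fit for stability across heterogeneous clients.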

Next Step: Now that you've mastered custom logic, proceed to Monitoring FL Training to learn how to visualize these aggregated results in real-time.