For my current research, I want to add the federated learning strategy SCAFFOLD (Karimireddy et al.) to my experiment. I am using TensorFlow for my research; however, I can only find PyTorch implementations online, and I ran into some issues implementing it myself.
Below I attach the algorithm from the paper:
Basically, it just adds some additional values (control variates) to the gradient before applying it to the model (line 10 of the algorithm: `- c_i + c`). There is a local control variate (`c_i`), which differs between clients, and a global control variate (`c`), which gets updated every round after model aggregation.
The model weights are aggregated the same way as in FedAvg.
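To make the intended math concrete, the client step on line 10 of the algorithm amounts to the following (a minimal NumPy sketch of the update rule only; `scaffold_local_step` is just an illustrative helper, not part of my optimizer):

```python
import numpy as np

def scaffold_local_step(weights, grads, local_controls, global_controls, lr):
    """One SCAFFOLD client step (line 10 of the algorithm):
    y_i <- y_i - lr * (g_i(y_i) - c_i + c).
    All arguments are lists of per-layer arrays."""
    return [w - lr * (g - c_i + c)
            for w, g, c_i, c in zip(weights, grads,
                                    local_controls, global_controls)]
```

Note that when `c_i == c` the correction cancels and this reduces to plain SGD, which is why comparing against `tf.keras.optimizers.SGD` is a useful sanity check.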
I tried to write this as an optimizer in TensorFlow; due to the length of the code, I put it on Pastebin: https://pastebin.com/ZtX4DvhG
Everything works and converges when I directly use the tf.keras SGD optimizer with the same learning rate. But when I use my implementation, it doesn't converge, and the model weights explode after a few rounds (approx. 5-7 aggregations); the weights and loss become NaN.
On the client side, I do the following (I simplified it to keep it short):
# The optimizer Scaffold implementation is in the pastebin link above.
optimizer = Scaffold(learning_rate=hyperparams.SGD_learning_rate)
# load global model weights and global/local control variates to the optimizer
optimizer.set_controls(global_model.weights, hyperparams.scaffold)
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
model.fit()
# compute new local control variates and store them locally for the next round
local_controls = optimizer.get_new_client_controls(
    global_model.get_weights(),
    model.get_weights(),
    option=option,
)
# compute the control variate difference and send it to the server
local_controls_diff = (
    [new - old for new, old in zip(local_controls, old_local_controls)]
    if old_local_controls
    else local_controls
)
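For reference, the quantity I intend `get_new_client_controls` to compute is Option II from the paper, c_i⁺ = c_i − c + (x − y_i) / (K · lr), where K is the number of local steps. A standalone NumPy sketch of that formula (names are illustrative, not my actual method signature):

```python
import numpy as np

def new_client_controls(local_controls, global_controls,
                        global_weights, local_weights,
                        num_local_steps, lr):
    """SCAFFOLD Option II: c_i+ = c_i - c + (x - y_i) / (K * lr).
    All list arguments hold per-layer arrays."""
    return [c_i - c + (x - y) / (num_local_steps * lr)
            for c_i, c, x, y in zip(local_controls, global_controls,
                                    global_weights, local_weights)]
```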
The aggregation looks like this:
total_client_count = self.total_client_count
selected_clients_count = len(client_parameters)
global_params = self.global_weights
global_controls = self.global_controls
global_lr = self.global_lr
delta_weights = [
    [c_layer - g_layer for g_layer, c_layer in zip(global_params, client_i)]
    for client_i in client_parameters
]
delta_avg_weights = [
    reduce(np.add, layer_updates) / selected_clients_count
    for layer_updates in zip(*delta_weights)
]
delta_controls = [
    [c_layer - g_layer for g_layer, c_layer in zip(global_controls, client_i)]
    for client_i in client_controls
]
delta_avg_controls = [
    reduce(np.add, layer_updates) / selected_clients_count
    for layer_updates in zip(*delta_controls)
]
# calc new global weights for next round
# x = x + lr_g * delta_x
new_global_weights = [
    global_layer + global_lr * delta_avg
    for global_layer, delta_avg in zip(global_params, delta_avg_weights)
]
# calc new global control variates for next round
# c = c + |S|/N * delta_ci
new_global_controls = [
    global_layer + (selected_clients_count / total_client_count) * delta_avg
    for global_layer, delta_avg in zip(global_controls, delta_avg_controls)
]
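For comparison, here is how I understand the server step from the paper as a self-contained NumPy sketch (assuming each client sends its trained weights y_i and its control delta Δc_i = c_i⁺ − c_i; `scaffold_aggregate` is just an illustrative name):

```python
import numpy as np

def scaffold_aggregate(global_weights, global_controls,
                       client_weights, client_control_diffs,
                       global_lr, total_clients):
    """SCAFFOLD server update:
    x <- x + lr_g * mean_i(y_i - x)
    c <- c + (|S|/N) * mean_i(delta_c_i)."""
    num_selected = len(client_weights)
    # average per-layer weight deltas over the selected clients
    delta_w = [
        np.mean([cw[k] - global_weights[k] for cw in client_weights], axis=0)
        for k in range(len(global_weights))
    ]
    new_w = [x + global_lr * d for x, d in zip(global_weights, delta_w)]
    # average per-layer control deltas over the selected clients
    delta_c = [
        np.mean([dc[k] for dc in client_control_diffs], axis=0)
        for k in range(len(global_controls))
    ]
    new_c = [c + (num_selected / total_clients) * d
             for c, d in zip(global_controls, delta_c)]
    return new_w, new_c
```

One thing I am unsure about: in my aggregation above I subtract `global_controls` from the received client controls, whereas the paper has the clients send the deltas Δc_i directly; this sketch follows the paper's convention.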
The new model weights and global control variates are then sent back to each client.