I am implementing a client selection mechanism in TensorFlow Federated (TFF), where I need to dynamically choose the top_k clients with the largest dataset sizes before aggregating the model updates (deltas).
Problem Statement
In TFF, tff.federated_map() executes a function independently on each client, so each client returns only its own dataset size (sizes), and these values remain placed at the clients.
I am unable to retrieve all dataset sizes on the server in order to sort them and select the top_k clients before aggregation.
What I Have Tried
- Using tff.federated_map() to retrieve sizes.
  Issue: Each client only returns its own value, making it impossible to construct a global list of dataset sizes at the server.
- Using tff.federated_collect().
  Issue: This function has been deprecated in recent TFF versions.
- Using tff.federated_aggregate() to gather all dataset sizes at the server.
  Issue: Requires a reduction function compatible with FederatedType, causing placement errors.
- Using tff.federated_zip(sizes).
  Issue: tff.federated_zip() expects a StructType, not a FederatedType, so it does not work with sizes.
- Attempting to collect sizes as a tf.Tensor and apply sorting.
  Issue: TensorFlow EagerTensors cannot be mixed with federated values, causing a TypeError: Expected a Python type that is convertible to a tff.Value.
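One idea I have considered but not verified in TFF: if every client has a fixed integer index, each client could place its dataset size into a zero vector of length num_clients, and an element-wise sum across clients would then leave the full list of sizes at the server. The plain-Python simulation below only illustrates that arithmetic. The helper names and sample sizes are hypothetical, and it does not call any TFF API.

```python
def one_hot_size(client_index, dataset_size, num_clients):
    """Client-side step: a zero vector holding this client's size at its own index."""
    vec = [0] * num_clients
    vec[client_index] = dataset_size
    return vec


def elementwise_sum(vectors):
    """Stand-in for a sum-style aggregation across clients."""
    return [sum(column) for column in zip(*vectors)]


# Hypothetical dataset sizes for four clients with indices 0..3.
client_sizes = [120, 40, 300, 75]
contributions = [
    one_hot_size(i, size, len(client_sizes))
    for i, size in enumerate(client_sizes)
]
sizes_at_server = elementwise_sum(contributions)
```

This assumes the number of participating clients is fixed and known when the computation is built, which may not hold in every setup.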
Here are my key questions:
- How can I retrieve all of the federated sizes values at the server before selection?
- What is the correct TFF mechanism to sort and filter clients before aggregation?
- How can I ensure that only the selected top_k clients contribute to the FedAvgAggregator?
Current Code
Here is a summary of my implementation:
Local Client Update (client_update)
@tf.function
def client_update(model, input_tuple, server_weights, client_optimizer):
    client_id, dataset_size, dataset = input_tuple
    client_weights = model.trainable_variables
    tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)
    optimizer_state = client_optimizer.initialize(
        tf.nest.map_structure(lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
    )
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, client_weights)
        optimizer_state, updated_weights = client_optimizer.next(optimizer_state, client_weights, grads)
        tf.nest.map_structure(lambda a, b: a.assign(b), client_weights, updated_weights)
    return (tf.nest.map_structure(tf.subtract, client_weights, server_weights),
            dataset_size, client_id)
Aggregator (FedAvgAggregator)
class FedAvgAggregator(tff.aggregators.WeightedAggregationFactory):
    def create(self, value_type):
        @tff.federated_computation()
        def initialize_fn():
            return tff.federated_value((), tff.SERVER)

        @tff.tensorflow.computation(value_type)
        def extract_delta(x):
            return x[0]

        @tff.tensorflow.computation(value_type)
        def extract_weight(x):
            return x[1]

        @tff.federated_computation(
            initialize_fn.type_signature.result,
            tff.FederatedType(value_type, tff.CLIENTS)
        )
        def next_fn(state, value):
            deltas = tff.federated_map(extract_delta, value)
            weights = tff.federated_map(extract_weight, value)
            # Selection of `top_k` clients (current issue)
            filtered_deltas = tff.federated_map(select_first_two_deltas, deltas)
            averaged_deltas = tff.federated_mean(deltas, weight=weights)
            return tff.templates.MeasuredProcessOutput(
                state,
                averaged_deltas,
                tff.federated_value((), tff.SERVER)
            )

        return tff.templates.AggregationProcess(initialize_fn, next_fn)
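To make the intended behaviour of the selection step concrete, here is a plain-Python simulation (no TFF calls) of one possible approach: clients whose dataset size falls below the k-th largest size get weight 0, so a weighted mean ignores them. The function names and sample values are hypothetical, scalar deltas stand in for model updates, and I have not confirmed how best to express this inside next_fn.

```python
def kth_largest(sizes, k):
    """Selection threshold: the k-th largest dataset size."""
    return sorted(sizes, reverse=True)[k - 1]


def masked_weights(sizes, k):
    """Keep the size as weight for selected clients, use 0 for all others."""
    threshold = kth_largest(sizes, k)
    return [float(s) if s >= threshold else 0.0 for s in sizes]


def weighted_mean(values, weights):
    """Weighted average; clients with weight 0 do not affect the result."""
    total = sum(weights)
    return sum(v * w for v, w in zip(values, weights)) / total


# Hypothetical per-client dataset sizes and scalar "deltas".
sizes = [120, 40, 300, 75]
deltas = [1.0, 10.0, 2.0, 20.0]

weights = masked_weights(sizes, k=2)
averaged = weighted_mean(deltas, weights)
```

With these inputs only the clients with sizes 300 and 120 contribute to the average. Note that if several clients tie at the threshold size, this sketch keeps all of them, so more than k clients can end up with non-zero weight.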
Federated Training Loop (run_one_round)
@tff.federated_computation(federated_server_type, federated_dataset_new_type_for_run_one_round)
def run_one_round(server_state, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(server_state.trainable_weights)
    clients_input = tff.federated_zip((federated_dataset, server_weights_at_client))
    model_deltas_weights_ids = tff.federated_map(client_update_fn, clients_input)
    deltas, sizes, ids = model_deltas_weights_ids
    # Attempted collection of dataset sizes at the server (not working)
    dataset_sizes = tff.federated_map(lambda x: x[1], model_deltas_weights_ids)
    aggregator = FedAvgAggregator()
    aggregation_process = aggregator.create(model_deltas_weights_ids.type_signature.member)
    agg_result = aggregation_process.next(
        aggregation_process.initialize(),
        model_deltas_weights_ids
    )
    server_state = tff.federated_map(server_update_fn, (server_state, agg_result.result))
    return server_state, dataset_sizes
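As a fallback, the selection could presumably happen in the Python driver before run_one_round is called, assuming the dataset sizes are known outside the federated computation. The sketch below filters a list of (client_id, dataset_size, dataset) tuples, the same layout client_update unpacks, using only the standard library. The helper name keep_top_k and the sample inputs are hypothetical, and the strings stand in for real datasets.

```python
import heapq


def keep_top_k(client_inputs, k):
    """Keep the k inputs with the largest dataset_size (second tuple field)."""
    return heapq.nlargest(k, client_inputs, key=lambda item: item[1])


# Hypothetical inputs: (client_id, dataset_size, dataset placeholder).
all_inputs = [
    ("a", 120, "dataset_a"),
    ("b", 40, "dataset_b"),
    ("c", 300, "dataset_c"),
    ("d", 75, "dataset_d"),
]

round_inputs = keep_top_k(all_inputs, k=2)
selected_ids = [client_id for client_id, _, _ in round_inputs]
```

Only round_inputs would then be passed to the round, so unselected clients never train. This does not answer how to do the selection inside TFF, which is still what I would like to know.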