How can I ensure that only the selected top_k clients contribute to the FedAvgAggregator?

I am implementing a client selection mechanism in TensorFlow Federated (TFF). Before aggregating the model updates (deltas), I need to dynamically choose the top_k clients with the largest dataset sizes, so that only those clients contribute to the aggregate.
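
To make the goal concrete, here is a minimal plain-Python sketch of the selection I have in mind. It is not TFF code: `select_top_k` and the example client ids are hypothetical names used only for illustration, and the sketch assumes all dataset sizes are already available in one place.

```python
import heapq
from typing import Dict, List


def select_top_k(sizes: Dict[str, int], k: int) -> List[str]:
    """Return the ids of the k clients with the largest dataset sizes.

    Hypothetical helper for illustration only. It assumes every size is
    visible in a single Python dict, which is the part I cannot get in TFF.
    """
    if k < 0:
        raise ValueError("k must be non-negative")
    # nlargest iterates over the dict keys (client ids) and ranks them by
    # their size, returning ids ordered from largest to smallest dataset.
    return heapq.nlargest(k, sizes, key=sizes.get)


# Made-up sizes, purely to show the intended behaviour.
example_sizes = {"client_a": 120, "client_b": 30, "client_c": 300, "client_d": 75}
selected = select_top_k(example_sizes, k=2)
print(selected)
```

In this toy setting the ranking is trivial because the sizes live in one dict; my difficulty is getting the equivalent of that dict out of client-placed values.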

Problem Statement

In TFF, tff.federated_map() applies a function independently on each client, so each client returns only its own dataset size (sizes), and the result stays placed at the clients.
As a result, I am unable to retrieve all dataset sizes on the server to sort them and select the top_k clients before aggregation.

What I Have Tried

  1. Using tff.federated_map() to retrieve sizes
  • Issue: Each client only returns its own value, making it impossible to construct a global list of dataset sizes at the server.
  2. Using tff.federated_collect()
  • Issue: This function has been deprecated in recent TFF versions.
  3. Using tff.federated_aggregate() to gather all dataset sizes at the server
  • Issue: Requires a reduction function compatible with FederatedType, causing placement errors.
  4. Using tff.federated_zip(sizes)
  • Issue: tff.federated_zip() expects a StructType, not a FederatedType, so it does not work with sizes.
  5. Attempting to collect sizes as a tf.Tensor and apply sorting
  • Issue: TensorFlow EagerTensors cannot be mixed with federated values, causing a TypeError: Expected a Python type that is convertible to a tff.Value.

Here are my key questions:

  1. How can I retrieve all federated values sizes at the server before selection?
  2. What is the correct TFF mechanism to sort and filter clients before aggregation?
  3. How can I ensure that only the selected top_k clients contribute to the FedAvgAggregator?
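
For the third question, the arithmetic I am after is a weighted mean in which non-selected clients get weight zero. The sketch below is plain Python rather than TFF; `masked_weighted_mean` and the 0/1 `mask` are hypothetical, and it assumes the weighted mean is computed as sum(w_i * delta_i) / sum(w_i), with each delta flattened to a list of floats.

```python
from typing import List, Sequence


def masked_weighted_mean(
    deltas: Sequence[Sequence[float]],
    sizes: Sequence[float],
    mask: Sequence[int],
) -> List[float]:
    """Weighted mean of per-client deltas; mask[i] is 1 (selected) or 0."""
    # Effective weight: the dataset size for selected clients, 0 otherwise.
    weights = [s * m for s, m in zip(sizes, mask)]
    total = sum(weights)
    if total == 0:
        raise ValueError("at least one selected client needs a non-zero weight")
    dim = len(deltas[0])
    # Coordinate-wise weighted average over all clients; zero-weight
    # clients drop out of both the numerator and the denominator.
    return [
        sum(w * d[j] for w, d in zip(weights, deltas)) / total
        for j in range(dim)
    ]


# Made-up values: three clients, two-dimensional deltas.
deltas = [[1.0, 2.0], [10.0, 20.0], [3.0, 4.0]]
sizes = [100.0, 50.0, 300.0]
mask = [1, 0, 1]  # the top-2 clients by size are clients 0 and 2

result = masked_weighted_mean(deltas, sizes, mask)
# Same mean computed over the two selected clients only, for comparison.
selected_only = masked_weighted_mean([deltas[0], deltas[2]], [100.0, 300.0], [1, 1])
print(result, selected_only)
```

If each client could be given its own 0/1 mask, multiplying it into the weight passed to tff.federated_mean might produce the same effect, but I have not been able to verify that in TFF, and it still leaves open how to compute the mask without seeing all sizes.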

Current Code

Here is a summary of my implementation:

Local Client Update (client_update)

@tf.function
def client_update(model, input_tuple, server_weights, client_optimizer):
    client_id, dataset_size, dataset = input_tuple
    client_weights = model.trainable_variables
    tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)

    optimizer_state = client_optimizer.initialize(
        tf.nest.map_structure(lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
    )

    for batch in iter(dataset):
        with tf.GradientTape() as tape:            
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, client_weights)
        optimizer_state, updated_weights = client_optimizer.next(optimizer_state, client_weights, grads)
        
        tf.nest.map_structure(lambda a, b: a.assign(b), client_weights, updated_weights)

    return (tf.nest.map_structure(tf.subtract, client_weights, server_weights),
            dataset_size, client_id)

Aggregator FedAvgAggregator

class FedAvgAggregator(tff.aggregators.WeightedAggregationFactory):
    def create(self, value_type):
        @tff.federated_computation()
        def initialize_fn():
            return tff.federated_value((), tff.SERVER)

        @tff.tensorflow.computation(value_type)
        def extract_delta(x):
            return x[0]

        @tff.tensorflow.computation(value_type)
        def extract_weight(x):
            return x[1]

        @tff.federated_computation(
            initialize_fn.type_signature.result,
            tff.FederatedType(value_type, tff.CLIENTS)
        )
        def next_fn(state, value):
            deltas = tff.federated_map(extract_delta, value)
            weights = tff.federated_map(extract_weight, value)

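            # Selection of `top_k` clients (current issue).
            # NOTE: `select_first_two_deltas` is a placeholder that is not defined
            # in this snippet, and `filtered_deltas` is not used below, so every
            # client still contributes to the mean.
            filtered_deltas = tff.federated_map(select_first_two_deltas, deltas)
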
            averaged_deltas = tff.federated_mean(deltas, weight=weights)

            return tff.templates.MeasuredProcessOutput(
                state,
                averaged_deltas,
                tff.federated_value((), tff.SERVER)
            )

        return tff.templates.AggregationProcess(initialize_fn, next_fn)

Federated Training Loop (run_one_round)

@tff.federated_computation(federated_server_type, federated_dataset_new_type_for_run_one_round)
def run_one_round(server_state, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(server_state.trainable_weights)
    clients_input = tff.federated_zip((federated_dataset, server_weights_at_client))

    model_deltas_weights_ids = tff.federated_map(client_update_fn, clients_input)
    deltas, sizes, ids = model_deltas_weights_ids

    # Attempted collection of dataset sizes at the server (not working)
    dataset_sizes = tff.federated_map(lambda x: x[1], model_deltas_weights_ids)

    aggregator = FedAvgAggregator()
    aggregation_process = aggregator.create(model_deltas_weights_ids.type_signature.member)

    agg_result = aggregation_process.next(
        aggregation_process.initialize(),
        model_deltas_weights_ids
    )

    server_state = tff.federated_map(server_update_fn, (server_state, agg_result.result))
    return server_state, dataset_sizes