Improvement of API: Measures to avoid ValueError when pruning with prune_low_magnitude

Dear developers,

I have been trying out the guide on pruning models (Pruning comprehensive guide | TensorFlow Model Optimization). My experiment is to prune a model with different schedules successively, so I wrote code like this:

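```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Schedule 1: sparsity grows from 50% to 90% over the first 10,000 steps.
params1 = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
                                                             final_sparsity=0.9,
                                                             begin_step=0,
                                                             end_step=10000)
}

# Schedule 2: hold sparsity at 75% indefinitely (end_step=-1).
params2 = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.75,
                                                              begin_step=0,
                                                              end_step=-1)
}

# base_model is an existing Keras model; wrap its prunable layers with schedule 1.
model1 = prune_low_magnitude(base_model, **params1)

model1.compile(optimizer='adam',
               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])

# Try to wrap every layer again, this time with schedule 2.
def apply_pruning(layer):
    return prune_low_magnitude(layer, **params2)

# This is the call that raises the ValueError below.
model_2 = tf.keras.models.clone_model(model1, clone_function=apply_pruning)
```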

Then I got a ValueError:

Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a `PrunableLayer` instance, or should has a customer defined `get_prunable_weights` method. You passed: <class 'tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper.PruneLowMagnitude'>

I understand this happens because a layer that cannot be pruned is passed to prune_low_magnitude. I worked around it by adding an `if isinstance(layer, tf.keras.layers.Dense)` check. But I think the API could handle this more gracefully than raising a hard error, for example by skipping such layers or emitting a warning.
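A minimal sketch of the "skip with a warning" behaviour suggested above, written as a plain-Python decorator around whatever pruning function you pass in. The helper name `skip_unsupported` is hypothetical and not part of tfmot; it assumes, as the error above shows, that an unsupported layer makes the pruning function raise `ValueError`.

```python
import warnings
from typing import Any, Callable


def skip_unsupported(prune_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper: return the layer unchanged, with a warning, if prune_fn rejects it."""

    def wrapped(layer: Any, **kwargs: Any) -> Any:
        try:
            return prune_fn(layer, **kwargs)
        except ValueError as err:
            # Assumption: ValueError means "layer not supported"; other exception types still propagate.
            warnings.warn(
                f"Not pruning {type(layer).__name__}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
            return layer

    return wrapped
```

With tfmot this could be used as `apply_pruning = skip_unsupported(tfmot.sparsity.keras.prune_low_magnitude)` and `clone_function=lambda layer: apply_pruning(layer, **params2)`. Two caveats: catching every `ValueError` may also hide unrelated errors raised inside the pruning call, which is why an explicit type check is more precise; and an already-wrapped layer is simply left on its original schedule, so this silences the error without applying `params2`.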

BTW, the `isinstance(layer, tf.keras.layers.Dense)` check seems to skip the layers that were already pruned. Is there any way to prune the same layer first with PolynomialDecay and then with ConstantSparsity? Thank you.

Adding @Rino_Lee for visibility

Hi Hao,

I didn’t get what you mean by “prune the same layer using PolynomialDecay and ConstantSparsity”. The two are mutually exclusive, so you can choose only one pruning schedule for a layer. Please explain how you want to prune your model, and then we can help you.
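To make the exclusivity concrete: a pruning schedule maps each training step to one target sparsity, so "PolynomialDecay, then ConstantSparsity" on one layer would have to be a single schedule with two phases. The plain-Python sketch below only computes target values. `polynomial_decay` assumes the formula tfmot's `PolynomialDecay` uses (progress clamped to [0, 1], default `power=3`); `two_phase_target` and its switch at step 10000 are hypothetical, not a tfmot API.

```python
def polynomial_decay(step, initial_sparsity=0.5, final_sparsity=0.9,
                     begin_step=0, end_step=10000, power=3):
    """Target sparsity at `step`, assuming tfmot's PolynomialDecay formula."""
    # Fraction of the decay window completed, clamped to [0, 1].
    p = (step - begin_step) / (end_step - begin_step)
    p = min(1.0, max(0.0, p))
    return (initial_sparsity - final_sparsity) * (1 - p) ** power + final_sparsity


def two_phase_target(step, switch_step=10000, constant_sparsity=0.75):
    """Hypothetical combined schedule: params1's decay, then params2's constant target."""
    if step <= switch_step:
        return polynomial_decay(step, end_step=switch_step)
    return constant_sparsity


for step in (0, 5000, 10000, 20000):
    print(step, round(two_phase_target(step), 4))
# 0 0.5
# 5000 0.85
# 10000 0.9
# 20000 0.75
```

If the goal is rather to finish training under one schedule and then continue under another, one option (not verified here) would be to remove the wrappers with `tfmot.sparsity.keras.strip_pruning` after the first phase and wrap the stripped model again with `params2`.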

The error is raised because you tried to prune a layer that is already pruned. And you are right: the if statement bypasses all the already-pruned layers, which is why the error went away.