How do I recursively apply an initialization function to each submodule?

In PyTorch, nn.Module provides self.apply, which recursively applies a function to every submodule, so all layers can be initialized in one call. Is there a similar function in TensorFlow, so that I can do the same thing in a class inheriting from tf.keras.Model? :thinking:
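For reference, this is the PyTorch pattern I mean (a minimal sketch, not my real model):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def _init_weights(self, m):
        # apply() calls this once for every submodule, recursively
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)

net = Net()
net.apply(net._init_weights)

And here is my TensorFlow port: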

import tensorflow as tf

# Transformer, Identity and trunc_normal_ are helpers defined elsewhere in my code.

class TokenPose_TB_base(tf.keras.Model):  # tf.keras.Module does not exist; Model is the subclassing base
    def __init__(self, *, feature_size, patch_size, num_keypoints, dim, depth, heads, mlp_dim, apply_init=False,
                 apply_multi=True, hidden_heatmap_dim=64 * 6, heatmap_dim=64 * 48, heatmap_size=[64, 48], channels=3,
                 dropout=0., emb_dropout=0., pos_embedding_type="learnable", feature_dtype):
        super(TokenPose_TB_base, self).__init__()

        patch_dim = channels * patch_size[0] * patch_size[1]  
        assert pos_embedding_type in ['sine', 'learnable', 'sine-full']

        self.patch_size = patch_size  # (4, 3)
        self.heatmap_size = heatmap_size  # (64, 48)
        self.num_keypoints = num_keypoints  # 17
        self.num_patches = (feature_size[0] // (patch_size[0])) * (feature_size[1] // (patch_size[1]))  # 16*16 = 256
        self.pos_embedding_type = pos_embedding_type  
        self.all_attn = (self.pos_embedding_type == "sine-full")

        self.feature_dtype = feature_dtype  

        self.keypoint_token = tf.Variable(
            initial_value=tf.zeros(shape=(1, self.num_keypoints, dim), dtype=self.feature_dtype), trainable=True,
            name='keypoints_token')  # (1, 17, 192)
        h, w = feature_size[0] // (self.patch_size[0]), feature_size[1] // (self.patch_size[1])  # (16, 16)
        self._make_position_embedding(w, h, dim, pos_embedding_type)

        self.patch_to_embedding = tf.keras.layers.Dense(dim)  # 192
        self.drop_out = tf.keras.layers.Dropout(rate=emb_dropout)

        #   transformer
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout, num_keypoints=num_keypoints,
                                       all_attn=self.all_attn, scale_with_head=True)
        self.to_keypoint_token = Identity()

        self.mlp_head = tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(hidden_heatmap_dim),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(heatmap_dim),
        ]) if (dim <= hidden_heatmap_dim * 0.5 and apply_multi) else tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(),  # Keras has no LayerNorm; it is LayerNormalization
            tf.keras.layers.Dense(heatmap_dim),
        ])
        self.keypoint_token = trunc_normal_(self.keypoint_token, std=0.02)

    def _init_weights(self, m):
        print("Initialization...")
        if isinstance(m, tf.keras.layers.Dense):
            # Keras layers are built lazily, so swapping the initializers only
            # takes effect if this runs before the model's first call.
            m.kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
            if m.use_bias:
                m.bias_initializer = tf.keras.initializers.Constant(0.)
        elif isinstance(m, tf.keras.layers.LayerNormalization):
            m.beta_initializer = tf.keras.initializers.Constant(0.)
            m.gamma_initializer = tf.keras.initializers.Constant(1.)

Hi @SuSei, thanks for posting this issue.

TensorFlow models have a submodules property (tf.Module.submodules) that lets you iterate recursively over all layers and nested submodules; see the tf.Module API documentation for details.
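For example, a rough equivalent of PyTorch's apply could look like this (a sketch; apply_to_submodules is just an illustrative helper name, and note that swapping initializer attributes in your _init_weights only takes effect if it runs before the model's first call builds the layers):

import tensorflow as tf

def apply_to_submodules(model, fn):
    # Rough equivalent of nn.Module.apply: visit the model itself and
    # then every nested layer/submodule, recursively.
    fn(model)
    for m in model.submodules:
        fn(m)

# model = TokenPose_TB_base(...)  # construct with your usual arguments
# apply_to_submodules(model, model._init_weights)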

Let me know if that works for you.