How do I recursively apply an initialization function to each submodule?

In PyTorch, nn.Module provides self.apply, which recursively applies a function to every submodule, so all layers can be initialized in one place. Is there a similar function in TensorFlow that lets me do the same thing in a class that inherits from tf.keras.Model? :thinking:
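
For reference, this is roughly the PyTorch pattern I mean (a minimal sketch; the Net module is just a made-up example):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.apply(self._init_weights)  # visits self and every submodule recursively

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

Below is the TensorFlow model in which I want to do the equivalent: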

class TokenPose_TB_base(tf.keras.Model):
    def __init__(self, *, feature_size, patch_size, num_keypoints, dim, depth, heads, mlp_dim, apply_init=False,
                 apply_multi=True, hidden_heatmap_dim=64 * 6, heatmap_dim=64 * 48, heatmap_size=[64, 48], channels=3,
                 dropout=0., emb_dropout=0., pos_embedding_type="learnable", feature_dtype):
        super(TokenPose_TB_base, self).__init__()

        patch_dim = channels * patch_size[0] * patch_size[1]  
        assert pos_embedding_type in ['sine', 'learnable', 'sine-full']

        self.patch_size = patch_size  # (4, 3)
        self.heatmap_size = heatmap_size  # (64, 48)
        self.num_keypoints = num_keypoints  # 17
        self.num_patches = (feature_size[0] // (patch_size[0])) * (feature_size[1] // (patch_size[1]))  # 16*16 = 256
        self.pos_embedding_type = pos_embedding_type  
        self.all_attn = (self.pos_embedding_type == "sine-full")

        self.feature_dtype = feature_dtype  

        self.keypoint_token = tf.Variable(
            initial_value=tf.zeros(shape=(1, self.num_keypoints, dim), dtype=self.feature_dtype), trainable=True,
            name='keypoints_token')  # (1, 17, 192)
        h, w = feature_size[0] // (self.patch_size[0]), feature_size[1] // (self.patch_size[1])  # (16, 16)
        self._make_position_embedding(w, h, dim, pos_embedding_type)

        self.patch_to_embedding = tf.keras.layers.Dense(dim)  # 192
        self.drop_out = tf.keras.layers.Dropout(rate=emb_dropout)

        #   transformer
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout, num_keypoints=num_keypoints,
                                       all_attn=self.all_attn, scale_with_head=True)
        self.to_keypoint_token = Identity()

        self.mlp_head = tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(hidden_heatmap_dim),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(heatmap_dim),
        ]) if (dim <= hidden_heatmap_dim * 0.5 and apply_multi) else tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(heatmap_dim),
        ])
        self.keypoint_token = trunc_normal_(self.keypoint_token, std=0.02)

    def _init_weights(self, m):
        # PyTorch-style per-layer init I am trying to port. Dense exposes its
        # variables as `kernel`/`bias` once built; I'm not sure that overwriting
        # the *_initializer attributes here actually re-initializes built variables.
        print("Initialization...")
        if isinstance(m, tf.keras.layers.Dense):
            trunc_normal_(m.kernel, std=.02)
            if m.bias is not None:
                m.bias_initializer = tf.keras.initializers.constant(0)
        elif isinstance(m, tf.keras.layers.LayerNormalization):
            m.beta_initializer = tf.keras.initializers.constant(0)
            m.gamma_initializer = tf.keras.initializers.constant(1)
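
What I am hoping for is something like the sketch below. tf.Module exposes a recursive submodules property, so maybe I can walk it after the model has been built and reassign the variables directly (apply_init is just a placeholder name I made up; I am not sure this is the idiomatic TensorFlow way):

def apply_init(model):
    # Walk every nested layer/module; tf.Module.submodules is recursive.
    for m in model.submodules:
        if isinstance(m, tf.keras.layers.Dense) and m.built:
            m.kernel.assign(
                tf.keras.initializers.TruncatedNormal(stddev=0.02)(
                    m.kernel.shape, dtype=m.kernel.dtype))
            if m.bias is not None:
                m.bias.assign(tf.zeros_like(m.bias))
        elif isinstance(m, tf.keras.layers.LayerNormalization) and m.built:
            m.beta.assign(tf.zeros_like(m.beta))
            m.gamma.assign(tf.ones_like(m.gamma))

Is this the right direction, or does Keras have a built-in equivalent of self.apply?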