In PyTorch, an nn.Module can use self.apply to recursively apply a function to every submodule, which makes it easy to initialize all the layers in one place. Is there a similar function in TensorFlow, so that I can do the same thing when subclassing tf.keras.Model?
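For reference, the PyTorch idiom I want to reproduce is roughly this (a minimal sketch, not my real model):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(192, 64 * 48)
        # apply() recursively visits self and every submodule and calls the given function on each
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

My TensorFlow version of the model looks like this: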
import tensorflow as tf
# Transformer, Identity and trunc_normal_ are my own helpers defined elsewhere.

class TokenPose_TB_base(tf.keras.Model):
    def __init__(self, *, feature_size, patch_size, num_keypoints, dim, depth, heads, mlp_dim, apply_init=False,
                 apply_multi=True, hidden_heatmap_dim=64 * 6, heatmap_dim=64 * 48, heatmap_size=[64, 48], channels=3,
                 dropout=0., emb_dropout=0., pos_embedding_type="learnable", feature_dtype=tf.float32):
        super(TokenPose_TB_base, self).__init__()
        patch_dim = channels * patch_size[0] * patch_size[1]
        assert pos_embedding_type in ['sine', 'learnable', 'sine-full']

        self.patch_size = patch_size        # (4, 3)
        self.heatmap_size = heatmap_size    # (64, 48)
        self.num_keypoints = num_keypoints  # 17
        self.num_patches = (feature_size[0] // patch_size[0]) * (feature_size[1] // patch_size[1])  # 16*16 = 256
        self.pos_embedding_type = pos_embedding_type
        self.all_attn = (self.pos_embedding_type == "sine-full")
        self.feature_dtype = feature_dtype

        self.keypoint_token = tf.Variable(
            initial_value=tf.zeros(shape=(1, self.num_keypoints, dim), dtype=self.feature_dtype),
            trainable=True, name='keypoints_token')  # (1, 17, 192)
        h, w = feature_size[0] // self.patch_size[0], feature_size[1] // self.patch_size[1]  # (16, 16)
        self._make_position_embedding(w, h, dim, pos_embedding_type)

        self.patch_to_embedding = tf.keras.layers.Dense(dim)  # 192
        self.drop_out = tf.keras.layers.Dropout(rate=emb_dropout)

        # transformer
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout, num_keypoints=num_keypoints,
                                       all_attn=self.all_attn, scale_with_head=True)

        self.to_keypoint_token = Identity()
        self.mlp_head = tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(hidden_heatmap_dim),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(heatmap_dim),
        ]) if (dim <= hidden_heatmap_dim * 0.5 and apply_multi) else tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(heatmap_dim),
        ])
        self.keypoint_token = trunc_normal_(self.keypoint_token, std=0.02)
    def _init_weights(self, m):
        print("Initialization...")
        if isinstance(m, tf.keras.layers.Dense):
            # Keras Dense has no .weight attribute, so set the initializers instead
            # (they only take effect if the layer has not been built yet).
            m.kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
            if m.use_bias:
                m.bias_initializer = tf.keras.initializers.constant(0)
        elif isinstance(m, tf.keras.layers.LayerNormalization):
            m.beta_initializer = tf.keras.initializers.constant(0)
            m.gamma_initializer = tf.keras.initializers.constant(1)
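The closest thing I have found so far is the submodules property that tf.keras.Model inherits from tf.Module, so my current idea is something like the sketch below (the _apply_init name is just a placeholder of mine, and I am not sure it is correct, since Keras initializers normally only take effect when a layer's build() runs):

    # Hypothetical helper inside TokenPose_TB_base, called at the end of __init__ when apply_init is True:
    def _apply_init(self):
        # tf.Module.submodules yields all nested layers/modules recursively,
        # similar to what nn.Module.apply visits in PyTorch.
        for m in self.submodules:
            self._init_weights(m)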