Hi,
Is there a concise way to tie the embedding weights to the output (logits) projection in TensorFlow?
In Flax, we are able to do this:
class Parallel_Transformer(nn.Module):
    dim: int
    num_tokens: int
    depth: int
    dim_head: int = 64
    heads: int = 8
    ff_mult: int = 4

    @nn.compact
    def __call__(self, x):
        embed = nn.Embed(self.num_tokens, self.dim, embedding_init=nn.initializers.normal(stddev=0.02))
        x = embed(x)
        x = Transformer(dim=self.dim, depth=self.depth, heads=self.heads, dim_head=self.dim_head, ff_mult=self.ff_mult)(x)
        x = nn.LayerNorm(epsilon=1e-5, use_bias=False)(x)
        out = embed.attend(x)
        return out
In PyTorch, we are able to do the same with:
def Parallel_Transformer(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        *[
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            for _ in range(depth)
        ],
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False)
    )

    # tie the output projection to the token embedding
    net[-1].weight = net[0].weight
    nn.init.normal_(net[0].weight, std=0.02)
    return net
I am currently working on a TensorFlow implementation here:
class Parallel_Transformer(tf.keras.Model):
    def __init__(self, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.dim = dim
        self.num_tokens = num_tokens
        self.depth = depth
        self.dim_head = dim_head
        self.heads = heads
        self.ff_mult = ff_mult
        self.embedding = tf.keras.layers.Embedding(num_tokens, dim, embeddings_initializer='uniform')
        # build the transformer stack once here rather than on every call
        self.transformer = ParallelTransformer(dim, depth, heads, dim_head, ff_mult)
        self.norm = tf.keras.layers.LayerNormalization()
        self.to_out = tf.keras.layers.Dense(num_tokens, use_bias=False)

    def call(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        x = self.norm(x)
        out = self.to_out(x)
        return out
I am unsure how to properly tie the embedding weights to the to_out projection in this TensorFlow version.
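One idea I have been considering, though I am not sure it is the idiomatic Keras way, is to drop the Dense projection entirely and multiply the normalized hidden states by the embedding matrix, which the Embedding layer exposes as its embeddings attribute once it has been built. A rough sketch, reusing the ParallelTransformer block from my attempt above and a normal(0.02) initializer to match the Flax/PyTorch versions:

class Parallel_Transformer(tf.keras.Model):
    def __init__(self, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(
            num_tokens, dim,
            embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.02))
        self.transformer = ParallelTransformer(dim, depth, heads, dim_head, ff_mult)
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, center=False)

    def call(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        x = self.norm(x)
        # project back onto the vocabulary with the embedding matrix itself,
        # analogous to embed.attend(x) in the Flax version
        return tf.matmul(x, self.embedding.embeddings, transpose_b=True)

I am not sure whether this correctly shares a single variable between the input embedding and the output projection, or whether there is a cleaner way to express weight tying in Keras.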
Any advice would be greatly appreciated.
Thank you,
Enrico