Embedding weights tied to projection out logits

Hi,

Is there a concise way in TensorFlow to tie the embedding weights to the output logit projection?

In Flax, we are able to do this:

import flax.linen as nn

class Parallel_Transformer(nn.Module):
  dim: int
  num_tokens: int
  depth: int
  dim_head: int = 64
  heads: int = 8
  ff_mult: int = 4

  @nn.compact
  def __call__(self, x):
    embed = nn.Embed(self.num_tokens, self.dim, embedding_init=nn.initializers.normal(stddev=0.02))
    x = embed(x)
    x = Transformer(dim=self.dim, depth=self.depth, heads=self.heads, dim_head=self.dim_head, ff_mult=self.ff_mult)(x)
    x = nn.LayerNorm(epsilon=1e-5, use_bias=False)(x)
    # embed.attend multiplies by the transposed embedding matrix,
    # so the output projection is tied to the embedding weights
    out = embed.attend(x)
    return out

In PyTorch, we are able to do the same with:

import torch.nn as nn

def Parallel_Transformer(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):

    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        *[
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            for _ in range(depth)
        ],
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False)
    )

    # tie the output projection weight to the embedding weight
    net[-1].weight = net[0].weight

    nn.init.normal_(net[0].weight, std=0.02)
    return net

I am currently working on a TensorFlow implementation here:

import tensorflow as tf

class Parallel_Transformer(tf.keras.Model):
    def __init__(self, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
        super().__init__()

        self.dim = dim
        self.num_tokens = num_tokens
        self.depth = depth
        self.dim_head = dim_head
        self.heads = heads
        self.ff_mult = ff_mult

        self.embedding = tf.keras.layers.Embedding(num_tokens, dim, embeddings_initializer='uniform')

        # built once here so its weights are tracked and reused across calls
        self.transformer = ParallelTransformer(dim, depth, heads, dim_head, ff_mult)

        self.norm = tf.keras.layers.LayerNormalization()

        # separate Dense layer -- its weights are not yet tied to the embedding
        self.to_out = tf.keras.layers.Dense(num_tokens, use_bias=False)

    def call(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        x = self.norm(x)
        out = self.to_out(x)
        return out

I am unsure how to properly do the above in TensorFlow.

Any advice would be greatly appreciated.

Thank you,

Enrico

Hi @Enrico_Shippole,

Sorry for the delay in response.
I suggest you use add_weight and matrix multiplication to avoid redundant layers and to ensure that the embedding and the output logits share the same weight matrix, similar to the PyTorch approach. I believe this is more efficient than relying on a kernel_initializer for a Dense layer in TF. I've added a sample code gist for your reference.
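
The core of it looks roughly like the sketch below (a minimal illustration, not the full gist: ParallelTransformer stands in for your existing transformer stack, and the weight name and initializer are just examples):

import tensorflow as tf

class Parallel_Transformer(tf.keras.Model):
    def __init__(self, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.dim = dim
        self.num_tokens = num_tokens

        # one shared matrix created with add_weight: it acts as both the
        # token embedding and the output projection (tied weights)
        self.token_emb = self.add_weight(
            name="token_embedding",
            shape=(num_tokens, dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
        )

        # assumed to be your existing transformer stack
        self.transformer = ParallelTransformer(dim, depth, heads, dim_head, ff_mult)
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        # embedding lookup against the shared matrix
        x = tf.gather(self.token_emb, x)
        x = self.transformer(x)
        x = self.norm(x)
        # project back to vocabulary logits with the transposed shared matrix
        logits = tf.einsum('bnd,vd->bnv', x, self.token_emb)
        return logits

The einsum at the end is simply a batched matrix multiplication against the transposed embedding matrix (the analogue of Flax's embed.attend and the tied nn.Linear in PyTorch), so there is no separate Dense layer whose weights could drift out of sync.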

Thank You.