How to encode multiple categorical features in TensorFlow 2 efficiently? (The embedding layers take too much storage.)

I am looking for help writing an embedding block for multiple categorical features in TensorFlow 2.

The shape of my inputs is batch_size × num_features; each column is a categorical feature encoded as integers.
(For example, if the first column is $X_0$ and $X_0$ has 10 classes, then $X_0 \in \{0, 1, 2, \ldots, 9\}$.)

Now I want to get the embeddings of the inputs.
Currently, I create a list of layers.Embedding() layers, one per feature (the code is below).

But this way, the saved model is insanely large.
For example, with num_features = 3000 and 10 classes per feature, the total number of parameters is 90,000 (3000 × 10 × 3), which should take well under 1 MB of storage ($90000 \times 4 / 1024 / 1024 \approx 0.34$ MB as float32).
However, when I save the model with model.save(), the saved_model.pb file alone is 59 MB. I do not know how to structure the model so that the saved file has a normal size.
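For reference, the back-of-the-envelope storage arithmetic above can be checked directly (same numbers as in my example):

```python
# Sanity-check the parameter count and its raw float32 storage.
num_features = 3000
num_classes = 10       # classes per feature
dim_embeddings = 3     # embedding dimension per feature

num_params = num_features * num_classes * dim_embeddings  # 3000 * 10 * 3
size_mb = num_params * 4 / 1024 / 1024  # 4 bytes per float32 parameter

print(num_params)            # 90000
print(round(size_mb, 2))     # 0.34
```

So the weights themselves are tiny; the 59 MB must be coming from something other than the parameters.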

I have been stuck on this for days. Thank you so much for any comments!

My code is as follows. I am using TensorFlow 2.2.0 and Python 3.6.10.


from typing import List

import tensorflow as tf
from tensorflow.keras import layers


class Module_Embedding(tf.keras.Model):

    def __init__(self, num_class_list: List[int], dim_embeddings: int = 4):
        super().__init__()
        self.num_class_list = num_class_list
        self.dim_embeddings = dim_embeddings
        self.dim_ft_categorical = len(num_class_list)
        # One Embedding layer per categorical feature
        self.embeddings = [layers.Embedding(nc, self.dim_embeddings,
                                            embeddings_regularizer=tf.keras.regularizers.l2(0.2))
                           for nc in self.num_class_list]

    def call(self, inputs):
        # inputs: (batch_size, num_features) integer class ids
        ft_categorical_embed_list = [self.embeddings[i](inputs[:, i]) for i in range(self.dim_ft_categorical)]
        ft_categorical_embed = tf.concat(ft_categorical_embed_list, axis=-1)
        return ft_categorical_embed


num_features = 3000
num_class_list = [10 for _ in range(num_features)]
batch_size = 32
inputs = tf.random.uniform((batch_size, num_features), maxval=min(num_class_list), dtype=tf.int32)
model = Module_Embedding(num_class_list=num_class_list, dim_embeddings=3)
y = model(inputs)
model.summary()
model.save('./2_model/temp/model_check_0730')
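One idea I have sketched (not sure it solves the file-size issue, so treat it as an untested alternative) is to replace the 3000 separate layers with a single shared embedding table plus per-feature index offsets, so each feature reads from its own block of rows. The class and variable names here (SharedEmbedding, offsets) are my own:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

num_features = 3000
num_class_list = [10] * num_features
dim_embeddings = 3

# offsets[i] = number of table rows used by features 0..i-1,
# so feature i owns rows [offsets[i], offsets[i] + num_class_list[i]).
offsets = np.concatenate([[0], np.cumsum(num_class_list)[:-1]]).astype(np.int32)
total_rows = int(np.sum(num_class_list))  # 30,000 rows in one table


class SharedEmbedding(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.offsets = tf.constant(offsets)  # shape (num_features,)
        self.table = layers.Embedding(total_rows, dim_embeddings)

    def call(self, inputs):
        # inputs: (batch_size, num_features) integer class ids
        shifted = inputs + self.offsets[tf.newaxis, :]       # route each column to its block
        embedded = self.table(shifted)                       # (batch, num_features, dim)
        return tf.reshape(embedded, (tf.shape(inputs)[0], num_features * dim_embeddings))


model = SharedEmbedding()
x = tf.random.uniform((8, num_features), maxval=10, dtype=tf.int32)
y = model(x)
print(y.shape)  # (8, 9000)
```

The parameter count is the same 90,000 as before, but it lives in one variable and one lookup instead of 3000 layers, so there are far fewer graph nodes to serialize.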