I am porting some of my TF1 code to TF2 and am wondering how I can define a model for training and then reuse that same model at intermediate steps for predictions. Example:
import tensorflow as tf

x = tf.keras.Input(shape=(5,))
s = tf.keras.Sequential([...])(x)
# now do several things with s independently
a = tf.keras.layers.Dense(...)(s)
b = tf.keras.Sequential(...)(s)
# do some more things with a and b and calculate the output for the model
output = tf.keras.Sequential([tf.keras.layers.Dense(...), ...])(tf.keras.layers.Concatenate()([a, b]))
model = tf.keras.Model(inputs=x, outputs=output)
Now I want to be able to compute just a and b (without computing output), but at other times compute output as well.
I could write separate models for a, for b, and for a and b together, where every model takes x as input. But what if I want to calculate a, b and output? If I did that with all the models I have previously defined, I would run a forward pass through the graph multiple times, once for each output I need. I feel like I am missing something here.
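To make this first option concrete, here is a minimal runnable sketch of what I mean. The layer sizes and activations are made up purely for illustration and are not from my real code; the point is only that each model needs its own forward pass.

```python
import numpy as np
import tensorflow as tf

# All layer sizes below are hypothetical, chosen only so the sketch runs.
x = tf.keras.Input(shape=(5,))
s = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="relu")])(x)
a = tf.keras.layers.Dense(4)(s)
b = tf.keras.Sequential([tf.keras.layers.Dense(3)])(s)
merged = tf.keras.layers.Concatenate()([a, b])
output = tf.keras.Sequential([tf.keras.layers.Dense(1)])(merged)

# One model per quantity I might want; they all reuse the same layer objects.
model = tf.keras.Model(inputs=x, outputs=output)
model_a = tf.keras.Model(inputs=x, outputs=a)
model_b = tf.keras.Model(inputs=x, outputs=b)

batch = np.random.rand(2, 5).astype("float32")
# Getting a, b and output this way evaluates the shared part `s` three times.
a_val = model_a(batch)
b_val = model_b(batch)
out_val = model(batch)
```

The weights are shared between the three models, so training `model` also updates what `model_a` and `model_b` compute; what bothers me is only the repeated computation.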
Alternatively, I could define consecutive models that don't take the original input as input; instead, I would add new input nodes in my code every time I start working on the next part of the graph, like so:
...
a = tf.keras.layers.Dense(...)(s)
b = tf.keras.Sequential(...)(s)
model_a = tf.keras.Model(inputs=x, outputs=a)
model_b = tf.keras.Model(inputs=x, outputs=b)
input_a = tf.keras.Input(shape=shape_of_a)
input_b = tf.keras.Input(shape=shape_of_b)
output = tf.keras.Sequential([tf.keras.layers.Dense(...), ...])(tf.keras.layers.Concatenate()([input_a, input_b]))
model = tf.keras.Model(inputs=[input_a, input_b], outputs=output)
This is all so complicated and ugly… Now during training I first have to run a forward pass through each model along the way, and so on.
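Spelled out as a runnable sketch, the chained variant would look roughly like this. The layer sizes are invented for illustration, and I use plain Dense layers where my real code has larger sub-networks.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes; only the wiring matters here.
x = tf.keras.Input(shape=(5,))
s = tf.keras.layers.Dense(8, activation="relu")(x)
a = tf.keras.layers.Dense(4)(s)
b = tf.keras.layers.Dense(3)(s)
model_a = tf.keras.Model(inputs=x, outputs=a)
model_b = tf.keras.Model(inputs=x, outputs=b)

# The last stage gets fresh input nodes instead of the original x.
input_a = tf.keras.Input(shape=(4,))
input_b = tf.keras.Input(shape=(3,))
merged = tf.keras.layers.Concatenate()([input_a, input_b])
output = tf.keras.layers.Dense(1)(merged)
model = tf.keras.Model(inputs=[input_a, input_b], outputs=output)

batch = np.random.rand(2, 5).astype("float32")
# Every stage has to be called by hand, and `s` is still evaluated twice.
a_val = model_a(batch)
b_val = model_b(batch)
out_val = model([a_val, b_val])
```

Here a and b are available without computing output, but the manual plumbing between the stages is exactly what I would like to avoid.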
It seems that I am in a bind. Either I define multiple models using the same input, each time rewriting at least some of my code (which is ugly in itself), and accept the inefficiency of recomputing many nodes outside of training;
Or I have tons of input nodes and complicated feed-forward structures.
Coming from TF1, where all of this was trivial, I feel like I am misunderstanding something very fundamental about how TF2 is supposed to work. Please feel free to answer in general terms what I am missing; my code was just for illustration.