I’ tring to use jit_complile to accelerate training of my model, and after that the loop in my code didn’t work anymore. Here is the code of training:
@tf.function(autograph=True, jit_compile=True)
def train_step(self, label, fea_ids, fea_vals, model):
with tf.GradientTape() as tape:
pred = model([fea_ids, fea_vals])
loss = model.loss(label, pred)
loss = loss - 0.5 * pred
gradients = tape.gradient(loss, model.trainable_weights)
model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
return tf.reduce_mean(loss)
And here is the loop:
single_mask = tf.where(feat_index > 0, True, False)
# before_multihot_single_mask=single_mask
def single_mask_while_loop(single_mask):
def cond(i, single_mask):
return i < self.len_multihot_fea
def body(i, single_mask):
single_mask = single_mask & (tf.where(feat_index < self.multi_hot_fea_tf[i, 0], True, False) | tf.where(feat_index>=self.multi_hot_fea_tf[i, 1], True, False))
return i + 1, single_mask
i = tf.constant(0, dtype=tf.int64)
i, single_mask = tf.while_loop(cond, body, [i, single_mask])
return single_mask
single_mask = single_mask_while_loop(single_mask)
After check, single_mask didn’t change after the loop, how can I solve this? Thanks for help.