I am wondering if I just need a single line like this to load the weights:
model.load_weights("/Users/anu/PycharmProjects/Siglip/gemma-keras-gemma_1.1_instruct_2b_en-v3/model.weights.h5")
But it fails because I first need to build the model:
ValueError: You are loading weights into a model that has not yet been built. Try building the model first by calling it on some data or by using build()
Here is the relevant part of the ported code:
from SiglipVisionConfig import SiglipVisionConfig
from SiglipVisionModel import SiglipVisionModel
from PaliGemmaMultiModalProjector import PaliGemmaMultiModalProjector
from GemmaForCausalLM import GemmaForCausalLM
import tensorflow as tf
class PaliGemmaForConditionalGeneration(tf.keras.Model):
    """PaliGemma vision-language model (TensorFlow port).

    Combines a SigLIP vision tower, a multi-modal projector and a Gemma
    causal language model.  Projected image-patch embeddings are spliced
    into the text embedding sequence at the positions of the image
    placeholder tokens before the language model is run.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.vocab_size
        # BUG FIX: was assigned to a bare local `language_model`, so
        # self.language_model (used by call() and tie_weights()) never existed.
        self.language_model = GemmaForCausalLM(config.text_config)
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def tie_weights(self):
        """Tie the language model's input and output embedding weights."""
        return self.language_model.tie_weights()

    def _merge_input_ids_with_image_features(
        self,
        image_features,
        input_embeds,
        input_ids,
        attention_mask,
        kv_cache,
    ):
        """Splice projected image features into the text embedding sequence.

        Args:
            image_features: (batch, num_image_tokens, embed_dim) projected
                image embeddings — TODO confirm shape against the projector.
            input_embeds: (batch, seq_len, embed_dim) text token embeddings.
            input_ids: (batch, seq_len) token ids; image placeholder slots
                hold config.image_token_index.
            attention_mask: (batch, seq_len) mask; call() asserts all ones.
            kv_cache: cache exposing num_items() (0 means prefill phase).

        Returns:
            Tuple (final_embedding, causal_mask, position_ids).
        """
        embed_dim = tf.shape(image_features).numpy().tolist()[2]
        batch_size, sequence_length = tf.shape(input_ids).numpy().tolist()
        dtype = input_embeds.dtype

        # Scale image features by 1/sqrt(hidden_size), matching the reference
        # implementation.  BUG FIX: tf.sqrt of a Python int raises; use float
        # exponentiation instead.
        scaled_image_features = image_features / (self.config.hidden_size ** 0.5)

        # Combine image tokens and text tokens; padding slots stay zero.
        final_embedding = tf.zeros(
            (batch_size, sequence_length, embed_dim), dtype=dtype
        )
        text_mask = (input_ids != self.config.image_token_index) & (
            input_ids != self.pad_token_id
        )
        # BUG FIX: was self.image_token_index, an attribute that doesn't exist.
        image_mask = input_ids == self.config.image_token_index
        pad_mask = input_ids == self.pad_token_id

        # Expand each (batch, seq) mask to (batch, seq, embed_dim).
        # BUG FIX: tf.tile was applied to the rank-2 mask (rank/multiples
        # mismatch) and the expand_dims result was discarded.
        text_mask_expanded = tf.tile(
            tf.expand_dims(text_mask, axis=-1), [1, 1, embed_dim]
        )
        image_mask_expanded = tf.tile(
            tf.expand_dims(image_mask, axis=-1), [1, 1, embed_dim]
        )
        pad_mask_expanded = tf.tile(
            tf.expand_dims(pad_mask, axis=-1), [1, 1, embed_dim]
        )

        # Copy text embeddings into place.
        final_embedding = tf.where(text_mask_expanded, input_embeds, final_embedding)

        # Scatter the scaled image features into the image-token slots.
        # NOTE(review): assumes the number of True entries in
        # image_mask_expanded equals tf.size(scaled_image_features) — confirm.
        indices = tf.where(image_mask_expanded)
        updates = tf.reshape(scaled_image_features, [-1])
        final_embedding = tf.tensor_scatter_nd_update(final_embedding, indices, updates)

        # Zero out padding positions.  BUG FIX: the two-argument form of
        # tf.where returns coordinates, not a blended tensor; the
        # three-argument form is required here.
        final_embedding = tf.where(
            pad_mask_expanded, tf.zeros_like(final_embedding), final_embedding
        )

        q_len = tf.shape(input_embeds)[1]
        if kv_cache.num_items() == 0:
            # Prefill: all-zero additive mask (nothing is masked out).
            causal_mask = tf.zeros((batch_size, q_len, q_len), dtype=dtype)
        else:
            # Decode: a single new query attends to all cached keys.
            assert q_len == 1, f"expected q_len == 1 during decode, got {q_len}"
            kv_len = kv_cache.num_items() + q_len
            causal_mask = tf.zeros((batch_size, q_len, kv_len), dtype=dtype)
        # Add the head dimension: (batch, 1, q_len, kv_len).
        causal_mask = tf.expand_dims(causal_mask, axis=1)

        # BUG FIX: the decode branch must trigger whenever the cache is
        # non-empty (> 0), not only after the second step (> 1); `.toList()`
        # was also a typo for `.tolist()`.
        if kv_cache.num_items() > 0:
            # Position of the new token = number of attended tokens - 1.
            position_ids = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
            if len(tf.shape(position_ids).numpy().tolist()) == 1:
                position_ids = tf.expand_dims(position_ids, axis=0)
        else:
            # Prefill: cumulative positions, with masked slots forced to 1.
            cumsum = tf.math.cumsum(attention_mask, axis=-1)
            position_ids = tf.where(
                tf.equal(attention_mask, 0), tf.ones_like(cumsum), cumsum
            )
        return final_embedding, causal_mask, position_ids

    def call(self, input_ids, attention_mask, pixel_values, kv_cache):
        """Embed text, merge in image features, and run the language model.

        Raises:
            tf.errors.InvalidArgumentError: if attention_mask contains any
                padding (zeros) — padded input is not supported.
        """
        tf.debugging.assert_equal(
            tf.reduce_all(tf.equal(attention_mask, 1)),
            True,
            message="Input cannot be padded",
        )
        input_embeds = self.language_model.get_input_embeddings(input_ids)
        selected_image_features = self.vision_tower(
            tf.cast(pixel_values, input_embeds.dtype)
        )
        image_features = self.multi_modal_projector(selected_image_features)
        # BUG FIX: `self` was passed explicitly to a bound method, shifting
        # every argument by one position.
        input_embeds, attention_mask, position_ids = (
            self._merge_input_ids_with_image_features(
                image_features, input_embeds, input_ids, attention_mask, kv_cache
            )
        )
        outputs = self.language_model(
            attention_mask, position_ids, input_embeds, kv_cache
        )
        return outputs
Thanks