I have implemented knn accuracy in tensorflow following this code which is written in pytorch. The aim is to determine how well my model is performing on image embeddings. My implementation is done using a callback
, however I get the same score of 1.17% on every epoch. Could someone kindly point me in the right direction or what I am doing wrong. I have included the code for your reference. Thank you in advance.
class KNNMonitor(tf.keras.callbacks.Callback):
def __init__(self,index,query,k=200,t=0.1):
super(KNNMonitor,self).__init__()
self.index = index
self.query = query
self.k = k
self.t = t
self.top_1 = 0.0
self.total_num = 0
def on_epoch_end(self,epoch,logs=None):
# self.top_1 = 0.0
# self.total_num = 0
self.feature_bank = []
self.index_labels = []
for x,y in self.index:
x = self.model.encoder(x)
self.feature_bank.append(x)
self.index_labels.append(y)
self.feature_bank = tf.concat(self.feature_bank,axis=0) # output[N,2048]
self.index_labels = tf.concat(self.index_labels,axis=0) #output[N]
for qx,qy in self.query:
qx = self.model.encoder(qx) # output[B,2048]
self.total_num += qx.shape[0]
cos_distance = tf.linalg.matmul(qx,self.feature_bank,transpose_b=True) # output[B,N]
k_values,k_indices = tf.raw_ops.TopKV2(input=cos_distance,k=self.k) # output (B,K)
sim_labels = tf.gather(self.index_labels,k_indices,axis=-1) # output(B,K)
sim_weight = tf.math.exp(k_values/self.t) # output[B,K]
# one-hot-encode topk labels
one_hot_label = tf.zeros([qx.shape[0]*self.k,C.CLASSES]) # output[B*K,C]
sl = tf.reshape(sim_labels,[1,-1]) # output[1,B*K]
r = tf.cast(
tf.reshape(tf.range(0,qx.shape[0]*self.k),[1,-1]),
dtype=tf.int64
) # output[1,B*K]
indices = tf.transpose(tf.concat([r,sl],axis=0)) # output[B*K,2]
one_hot_label = tf.tensor_scatter_nd_update(
one_hot_label,
indices=indices,
updates= tf.ones(qx.shape[0]*self.k)
) # output[B*K,C]
pred_scores = tf.math.reduce_sum(
tf.reshape(one_hot_label,[qx.shape[0],-1,C.CLASSES]) *
tf.expand_dims(sim_weight,axis=-1),
axis=1
) # output[B,C]
pred_scores_indices = tf.argsort(pred_scores,axis=-1,direction='DESCENDING') # output[B,C]
total = (pred_scores_indices[:, 0] == tf.cast(qy,dtype=tf.int32))
self.top_1 += tf.math.reduce_sum(tf.cast(total,dtype=tf.float32)).numpy()
print(f"knn accuracy:{self.top_1 / self.total_num *100:.2f}%")