I am trying to translate some PyTorch code (a single module) to TensorFlow, but I have run into some obstacles. Below are both the PyTorch code and my attempts. How should I do this?
In PyTorch, I have:
```python
import torch
import torch.nn as nn


class OrthogonalFusion(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, local_feat, global_feat):
        # local_feat: (B, C, H, W), global_feat: (B, C)
        global_feat_norm = torch.norm(global_feat, p=2, dim=1)
        # Dot product of global_feat with every spatial position: (B, 1, H*W)
        projection = torch.bmm(global_feat.unsqueeze(1), torch.flatten(local_feat, start_dim=2))
        # Scale global_feat by those dot products: (B, C, H*W), viewed as (B, C, H, W)
        projection = torch.bmm(global_feat.unsqueeze(2), projection).view(local_feat.size())
        projection = projection / (global_feat_norm * global_feat_norm).view(-1, 1, 1, 1)
        # Component of local_feat orthogonal to global_feat
        orthogonal_comp = local_feat - projection
        # Broadcast global_feat over H and W, then concatenate along channels: (B, 2C, H, W)
        global_feat = global_feat.unsqueeze(-1).unsqueeze(-1)
        return torch.cat([global_feat.expand(orthogonal_comp.size()),
                          orthogonal_comp], dim=1)


local_ = torch.ones(1, 20, 32, 32)
global_ = torch.ones(1, 20)
x = OrthogonalFusion()
print(x(local_, global_).shape)
# torch.Size([1, 40, 32, 32])
```
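For comparison, each PyTorch op above has a fairly direct TensorFlow 2.x counterpart: `torch.bmm` → `tf.matmul` on rank-3 tensors, `flatten`/`view` → `tf.reshape`, `expand` → `tf.broadcast_to`, `torch.cat` → `tf.concat`. The following is a minimal sketch of that mapping, not a verified drop-in replacement. It assumes channels-first inputs exactly as in the PyTorch version, i.e. `local_feat` of shape (B, C, H, W) and `global_feat` of shape (B, C), and uses `tf.keras.layers.Layer` as the analogue of `nn.Module`.

```python
import tensorflow as tf


class OrthogonalFusion(tf.keras.layers.Layer):
    """Sketch of a TF 2.x port of the PyTorch module.

    Assumes channels-first inputs, as in the PyTorch code:
    local_feat: (B, C, H, W), global_feat: (B, C).
    """

    def call(self, local_feat, global_feat):
        # L2 norm of each global vector, shape (B,)
        global_feat_norm = tf.norm(global_feat, ord="euclidean", axis=1)
        shape = tf.shape(local_feat)
        b, c, h, w = shape[0], shape[1], shape[2], shape[3]
        # torch.flatten(local_feat, start_dim=2) -> (B, C, H*W)
        local_flat = tf.reshape(local_feat, [b, c, h * w])
        # torch.bmm(global.unsqueeze(1), local_flat) -> (B, 1, H*W)
        projection = tf.matmul(tf.expand_dims(global_feat, 1), local_flat)
        # torch.bmm(global.unsqueeze(2), projection) -> (B, C, H*W), reshaped to (B, C, H, W)
        projection = tf.matmul(tf.expand_dims(global_feat, 2), projection)
        projection = tf.reshape(projection, shape)
        projection = projection / tf.reshape(global_feat_norm * global_feat_norm, [-1, 1, 1, 1])
        orthogonal_comp = local_feat - projection
        # (B, C) -> (B, C, 1, 1), broadcast to (B, C, H, W) like .expand()
        global_map = tf.broadcast_to(global_feat[:, :, tf.newaxis, tf.newaxis], shape)
        # Concatenate along the channel axis: (B, 2C, H, W)
        return tf.concat([global_map, orthogonal_comp], axis=1)


local_ = tf.ones([1, 20, 32, 32])
global_ = tf.ones([1, 20])
out = OrthogonalFusion()(local_, global_)
print(out.shape)
```

If the rest of the TensorFlow model uses the default channels-last layout (B, H, W, C), the reshapes, the broadcast, and the concatenation axis would need to change accordingly (for example, `axis=-1` in `tf.concat`), so the layout assumption is worth checking first.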
The PyTorch code above works as expected. Now I have tried to translate it into TensorFlow, but it fails.
Major issue,