I am trying to convert a PyTorch script to TensorFlow. How should I go about it? Is the conversion done line by line, or does the overall structure of the code change in TensorFlow?
Could someone please help me with this and point me to some useful links?
The code implements a graph convolutional network (GCN). I see that pytorch_geometric provides predefined modules such as MessagePassing, from which GCNConv inherits.
Is there a similar module in TensorFlow?
GCN Code:
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from inits import glorot, zeros
import pdb
class GCNConv(MessagePassing):
    """Graph convolution layer implemented on top of ``MessagePassing``.

    The layer multiplies the node features by a learnable weight matrix,
    rescales every edge by a degree-based coefficient (see :meth:`norm`),
    propagates the scaled features along the edges and finally adds an
    optional bias.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        improved: If true, self-loops get weight 2 instead of 1.
        cached: If true, the normalized edges computed on the first
            ``forward`` call are reused by every later call.
        bias: If true, a learnable bias of size ``out_channels`` is added.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True):
        # 'add' is handed to MessagePassing as its first positional argument
        # (presumably the aggregation scheme -- confirm against the installed
        # torch_geometric version).
        super(GCNConv, self).__init__('add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None

        # Storage is uninitialized here; reset_parameters() fills it below.
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if not bias:
            # Register the name anyway so that ``self.bias`` exists (as None).
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight and bias and drop any cached normalization."""
        glorot(self.weight)
        # NOTE(review): called even when ``self.bias`` is None; this relies on
        # ``inits.zeros`` accepting None -- confirm.
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Return self-loop-augmented edges and their normalized weights.

        Args:
            edge_index: Edge tensor that unpacks into two index rows; its
                second dimension is the number of edges.
            num_nodes: Number of nodes in the graph.
            edge_weight: Optional per-edge weights; defaults to all ones.
            improved: If true, self-loops get weight 2 instead of 1.
            dtype: dtype used for the default all-ones edge weights.

        Returns:
            A tuple ``(edge_index, weight)`` where ``edge_index`` went through
            ``remove_self_loops`` / ``add_self_loops`` and ``weight`` is each
            edge weight divided by the summed weight arriving at the edge's
            second-row node.
        """
        if edge_weight is None:
            num_edges = edge_index.size(1)
            edge_weight = torch.ones((num_edges, ),
                                     dtype=dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Replace whatever self-loops exist by exactly one loop per node.
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        self_loop_value = 2 if improved else 1
        self_loop_weight = torch.full((num_nodes, ),
                                      self_loop_value,
                                      dtype=edge_weight.dtype,
                                      device=edge_weight.device)
        # Assumes add_self_loops appends the num_nodes loops after the
        # existing edges, so the weights line up -- confirm.
        edge_weight = torch.cat([edge_weight, self_loop_weight], dim=0)

        _, target = edge_index
        degree = scatter_add(edge_weight, target, dim=0, dim_size=num_nodes)

        # NOTE: this is a plain inverse (1 / deg), not an inverse square root,
        # and it is applied on one side only.
        inv_degree = degree.pow(-1)
        # Nodes with zero degree would yield inf; give them coefficient 0.
        inv_degree[inv_degree == float('inf')] = 0

        return edge_index, inv_degree[target] * edge_weight

    def forward(self, x, edge_index, edge_weight=None):
        """Transform ``x`` linearly and propagate it over the normalized edges.

        With ``cached=True`` the normalized edges from the first call are
        reused, so later ``edge_index`` / ``edge_weight`` arguments are
        ignored until ``reset_parameters`` clears the cache.
        """
        x = torch.matmul(x, self.weight)

        must_compute = self.cached_result is None or not self.cached
        if must_compute:
            normed_index, coeff = self.norm(edge_index, x.size(0),
                                            edge_weight, self.improved,
                                            x.dtype)
            self.cached_result = (normed_index, coeff)

        stored_index, stored_coeff = self.cached_result
        return self.propagate(stored_index, x=x, norm=stored_coeff)

    def message(self, x_j, norm):
        """Scale each ``x_j`` row by its edge coefficient from ``norm``."""
        per_edge_scale = norm.view(-1, 1)
        return per_edge_scale * x_j

    def update(self, aggr_out):
        """Add the bias, if there is one, to the aggregated output."""
        if self.bias is None:
            return aggr_out
        return aggr_out + self.bias

    def __repr__(self):
        """Return ``ClassName(in_channels, out_channels)``."""
        class_name = self.__class__.__name__
        return '{}({}, {})'.format(class_name, self.in_channels,
                                   self.out_channels)
The script implements a graph convolutional network layer. (source: source code)