# Quantized linear autograd function (QLinearFxn)
class QLinearFxn(Function):
    """Linear transform whose forward pass uses quantized weight and bias.

    Forward computes ``input.mm(expquantize(weight).t())`` and, when a bias
    is supplied, adds ``expquantize(bias)`` to every row of the result.
    Backward ignores the quantization step and propagates gradients through
    the saved full-precision ``weight``, as if no quantization had happened.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        """Compute the quantized linear map.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: 2-D tensor (``mm`` requires matrices); rows are samples.
            weight: 2-D tensor laid out as (out_features, in_features).
            bias: 1-D tensor with one entry per output column, or ``None``.

        Returns:
            The 2-D output tensor ``input @ quantized_weight.T (+ bias)``.
        """
        # The *unquantized* tensors are saved: backward deliberately works
        # on the full-precision values.
        ctx.save_for_backward(input, weight, bias)
        wq = expquantize(weight)
        output = input.mm(wq.t())
        if bias is not None:
            bq = expquantize(bias)
            output += bq.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``input``, ``weight`` and ``bias``.

        A gradient is only computed when ``ctx.needs_input_grad`` asks for
        it; otherwise ``None`` is returned in that position.
        """
        saved_input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Propagate gradient as if no quantization
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(saved_input)
        if bias is not None and ctx.needs_input_grad[2]:
            # sum(0) already drops the batch dimension. An extra squeeze(0)
            # would collapse a single-output bias gradient to a 0-d tensor
            # whose shape no longer matches the bias.
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias
class QLinear(nn.Module):
    """Linear layer that applies ``QLinearFxn`` to full-precision parameters.

    The layer stores an ordinary ``weight`` (and optional ``bias``)
    parameter; quantization is applied inside ``QLinearFxn`` on each
    forward call.
    """

    def __init__(self, input_features, output_features, bias=True):
        """Create the layer.

        Args:
            input_features: number of columns of the weight matrix.
            output_features: number of rows of the weight matrix (and
                length of the bias vector).
            bias: when false, no bias parameter is created and ``None`` is
                registered in its place.
        """
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        self.weight = nn.Parameter(th.empty(output_features, input_features))
        nn.init.xavier_normal_(self.weight.data)  # Glorot (Xavier) normal init
        if bias:
            self.bias = nn.Parameter(th.empty(output_features))
            nn.init.constant_(self.bias.data, 0)  # bias starts at zero
        else:
            # Keeps `self.bias` defined (as None) so forward can pass it on.
            self.register_parameter('bias', None)

    def forward(self, input):
        """Apply the quantized linear function to ``input``."""
        return QLinearFxn.apply(input, self.weight, self.bias)

    def getQweights(self):
        """Return ``expquantize`` applied to the raw weight data."""
        return expquantize(self.weight.data)
class mGRUCell(nn.Module):
    """Recurrent cell with two softsign-activated linear maps.

    Both maps act on the concatenation of the current input and the
    previous state; one produces the gate ``z`` and the other the candidate
    ``g``. The new state is ``(1 - z) * state + z * g``.
    """

    def __init__(self, input_size, hidden_size, use_quant=True):
        """Create the cell.

        Args:
            input_size: feature size of the input ``x``.
            hidden_size: feature size of the state.
            use_quant: when true the two maps are ``QLinear`` layers,
                otherwise plain ``nn.Linear`` layers.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.u_size = input_size + hidden_size  # concat x & h size

        # mGRU weights (quantized or not)
        if use_quant:
            self.weight_zx = QLinear(self.u_size, hidden_size)
            self.weight_hx = QLinear(self.u_size, hidden_size)
        else:
            self.weight_zx = nn.Linear(self.u_size, hidden_size)
            self.weight_hx = nn.Linear(self.u_size, hidden_size)

    def forward(self, x, state):
        """Compute the next state from input ``x`` and previous ``state``.

        ``x`` and ``state`` are concatenated along dimension 1, so they
        must agree on every other dimension.
        """
        u = th.cat((x, state), 1)  # Concatenation of input & previous state
        # NOTE(review): softsign maps into (-1, 1), so `z` is not restricted
        # to [0, 1] like a sigmoid gate would be — confirm this is intended.
        z = F.softsign(self.weight_zx(u))
        g = F.softsign(self.weight_hx(u))
        h = (1 - z) * state + z * g
        return h