I’ve been trying to convert a TensorFlow v1 project (NeuroSAT) that I found to TensorFlow v2. I’ve gone the route of subclassing Model and Layer. I’m also subclassing PyDataset to generate the data.
For each training sample, I have the following data:
- Number of clauses
- Number of literals
- SparseTensor representing literal to clause mapping
My initial go at the data generator produced one tensor containing the number of clauses, another tensor containing the number of literals, and finally a ragged stack of the literal-to-clause sparse tensors — in other words, a model with those 3 inputs. For each sample in the batch, I was going to use the number of clauses to create the clause hidden/cell LSTM state tensors and the number of literals to create the literal hidden/cell LSTM state tensors. Where I got stuck is that all of these tensors are symbolic, and I’m not really sure how to read their actual values. I’m sure it’s possible, since tf.print will print them.
For the next iteration of my data generator, I decided to create the LSTM state tensors in the generator itself. Basically, each batch consists of a ragged tensor for each of the LSTM state tensors plus another ragged tensor for the literal-to-clause mapping. Unfortunately, I’m ending up with this error: Shapes used to initialize variables must be fully-defined (no None dimensions). Received: shape=(None, 128) for variable path='neurosat_model/message_passing_layer/multilayer_perceptron_layer/dense/kernel'
I guess I’m wondering whether what I’m trying to do is even possible. TensorFlow is pretty new to me, so I don’t have a clear understanding of its limitations. I suspect I could try to mimic the structure of the original project, but I think doing it my way would be much cleaner.
For reference, here is my current code. No doubt there are more things wrong with it, but I haven’t been able to get past this current hurdle to find out what else is wrong:
import tensorflow as tf
import random
import numpy as np
from pysat.solvers import Minisat22
from pysat.formula import CNF
from tensorflow.keras.layers import *
class NeurosatOptions:
    """Hyperparameters shared by the network and the problem generator.

    Every setting is a keyword-only argument whose default is the value
    that used to be hard-coded, so ``NeurosatOptions()`` behaves exactly
    as before while individual values can now be overridden.

    Args:
        dimensions: Width of the literal/clause state vectors and of the
            hidden ``Dense`` layers.
        num_message_layers: Number of hidden ``Dense`` layers in each MLP.
        num_rounds: Number of message-passing iterations.
        batch_size: Number of ``generate_problem`` calls per batch; each
            call contributes two problems.
        num_problems: Total number of problems used to derive the number
            of batches per epoch.
        min_variables: Smallest number of variables in a generated problem.
        max_variables: Largest number of variables in a generated problem.

    Raises:
        ValueError: If ``1 <= min_variables <= max_variables`` does not hold.
    """

    def __init__(
        self,
        *,
        dimensions=128,
        num_message_layers=3,
        num_rounds=30,
        batch_size=6,
        num_problems=12,
        min_variables=2,
        max_variables=5,
    ):
        # Fail early: the generator draws the variable count with
        # random.randint(min_variables, max_variables) and then indexes the
        # first literal of a clause, which needs at least one variable.
        if not 1 <= min_variables <= max_variables:
            raise ValueError(
                "expected 1 <= min_variables <= max_variables, got "
                f"min_variables={min_variables}, max_variables={max_variables}"
            )

        # Network properties
        self.dimensions = dimensions
        self.num_message_layers = num_message_layers
        self.num_rounds = num_rounds

        # Data generator properties
        self.batch_size = batch_size
        self.num_problems = num_problems
        self.min_variables = min_variables
        self.max_variables = max_variables
class ProblemDataGenerator(tf.keras.utils.PyDataset):
    """Keras dataset that generates random SAT problems in sat/unsat pairs.

    Each problem is built by adding random clauses to a Minisat22 solver
    until the formula becomes unsatisfiable. The clause that caused the
    conflict gives the unsatisfiable problem (label 0); the same clause with
    its first literal negated gives the paired problem (label 1).

    Args:
        options: Object providing ``dimensions``, ``batch_size``,
            ``num_problems``, ``min_variables`` and ``max_variables``.
        **kwargs: Forwarded to ``PyDataset.__init__`` (for example
            ``workers`` or ``use_multiprocessing``).
    """

    def __init__(self, options, **kwargs):
        # PyDataset expects subclasses to run its constructor; the original
        # skipped it.
        super().__init__(**kwargs)
        self.options = options
        denom = tf.sqrt(tf.cast(options.dimensions, tf.float32))
        # One random row per state kind, scaled by 1/sqrt(dimensions); the
        # rows are tiled to the problem size in build_*_hidden_state.
        self.literal_state_init_values = tf.math.divide(
            tf.random.normal(shape=[1, options.dimensions]), denom
        )
        self.clause_state_init_values = tf.math.divide(
            tf.random.normal(shape=[1, options.dimensions]), denom
        )

    def __getitem__(self, index):
        """Generate one batch; ``index`` is not used, every batch is fresh.

        Returns:
            ``(inputs, labels)`` where ``inputs`` is the 5-tuple
            ``(literal_hidden_state, literal_cell_state, clause_hidden_state,
            clause_cell_state, literal_to_clauses)`` produced with
            ``tf.ragged.stack`` and ``labels`` is a tensor with one 0/1 entry
            per problem. A batch holds ``2 * batch_size`` problems because
            every generated problem comes as a pair.
        """
        literal_hidden_state = []
        literal_cell_state = []
        clause_hidden_state = []
        clause_cell_state = []
        literal_to_clauses = []
        labels = []
        for _ in range(self.options.batch_size):
            self.generate_problem(
                literal_hidden_state,
                literal_cell_state,
                clause_hidden_state,
                clause_cell_state,
                literal_to_clauses,
                labels,
            )
        # Problem sizes differ from sample to sample, hence the ragged stack.
        inputs = (
            tf.ragged.stack(literal_hidden_state),
            tf.ragged.stack(literal_cell_state),
            tf.ragged.stack(clause_hidden_state),
            tf.ragged.stack(clause_cell_state),
            tf.ragged.stack(literal_to_clauses),
        )
        return inputs, tf.convert_to_tensor(labels)

    def generate_problem(self, literal_hidden_state, literal_cell_state, clause_hidden_state, clause_cell_state, literal_to_clauses, labels):
        """Append one unsat/sat problem pair to the given accumulator lists."""
        num_variables = random.randint(self.options.min_variables, self.options.max_variables)
        clauses = []
        with Minisat22() as solver:
            while True:
                # Clause width: 1 or 2 plus a geometric(0.4) draw.
                num_clause_variables = 1 if random.random() < 0.3 else 2
                num_clause_variables = num_clause_variables + np.random.geometric(0.4)
                clause = self.generate_clause(num_variables, num_clause_variables)
                solver.add_clause(clause)
                if not solver.solve():
                    # `clause` is the one that made the formula unsatisfiable.
                    break
                clauses.append(clause)
        unsatisfied_clause = clause
        satisfied_clause = [-clause[0]] + clause[1:]
        self.add_problem(clauses + [unsatisfied_clause], 0, num_variables, literal_hidden_state, literal_cell_state, clause_hidden_state, clause_cell_state, literal_to_clauses, labels)
        self.add_problem(clauses + [satisfied_clause], 1, num_variables, literal_hidden_state, literal_cell_state, clause_hidden_state, clause_cell_state, literal_to_clauses, labels)

    def add_problem(self, clauses, label, num_variables, literal_hidden_state, literal_cell_state, clause_hidden_state, clause_cell_state, literal_to_clauses, labels):
        """Append the tensors and label of one problem to the accumulators."""
        num_literals = num_variables * 2
        num_clauses = len(clauses)
        literal_hidden_state.append(self.build_literal_hidden_state(num_literals))
        literal_cell_state.append(self.build_literal_cell_state(num_literals))
        clause_hidden_state.append(self.build_clause_hidden_state(num_clauses))
        clause_cell_state.append(self.build_clause_cell_state(num_clauses))
        literal_to_clauses.append(self.build_literal_to_clause(clauses, num_variables))
        labels.append(label)

    def build_clause_cell_state(self, num_clauses):
        """Return a zero tensor of shape ``[num_clauses, dimensions]``."""
        return tf.zeros([num_clauses, self.options.dimensions])

    def build_clause_hidden_state(self, num_clauses):
        """Return the clause init row repeated ``num_clauses`` times."""
        return tf.tile(self.clause_state_init_values, [num_clauses, 1])

    def build_literal_cell_state(self, num_literals):
        """Return a zero tensor of shape ``[num_literals, dimensions]``."""
        return tf.zeros([num_literals, self.options.dimensions])

    def build_literal_hidden_state(self, num_literals):
        """Return the literal init row repeated ``num_literals`` times."""
        return tf.tile(self.literal_state_init_values, [num_literals, 1])

    def build_literal_to_clause(self, clauses, num_variables):
        """Return the dense ``[2 * num_variables, len(clauses)]`` incidence matrix.

        Entry ``[row, column]`` is 1.0 when the literal mapped to ``row`` by
        ``literal_to_variable`` occurs in clause number ``column``.
        """
        indices = [
            [self.literal_to_variable(literal, num_variables), clause_index]
            for clause_index, clause in enumerate(clauses)
            for literal in clause
        ]
        values = [1.0] * len(indices)
        literal_to_clause = tf.sparse.SparseTensor(
            indices=indices,
            values=values,
            dense_shape=[num_variables * 2, len(clauses)],
        )
        # Indices are collected clause by clause; reorder puts them in the
        # canonical row-major order before densifying.
        literal_to_clause = tf.sparse.reorder(literal_to_clause)
        return tf.sparse.to_dense(literal_to_clause)

    def get_variable_and_sign(self, literal):
        """Return ``(zero_based_variable_index, is_negated)`` for a literal."""
        return abs(literal) - 1, literal < 0

    def literal_to_variable(self, literal, num_variables):
        """Return the row index of a literal.

        Positive literals use rows ``0 .. num_variables - 1``; negated
        literals use rows ``num_variables .. 2 * num_variables - 1``.
        """
        variable, negated = self.get_variable_and_sign(literal)
        return variable + num_variables if negated else variable

    def generate_clause(self, num_variables, num_clause_variables):
        """Return a random clause as a list of signed 1-based variable ids.

        The clause has ``min(num_variables, num_clause_variables)`` distinct
        variables, each negated with probability 0.5.
        """
        clause_size = min(num_variables, num_clause_variables)
        clause_variables = np.random.choice(num_variables, size=clause_size, replace=False)
        clause_variables = [int(variable + 1) for variable in clause_variables]
        return [variable if random.random() < 0.5 else -variable for variable in clause_variables]

    def __len__(self):
        """Return the number of batches per epoch.

        Each batch contains ``batch_size * 2`` problems, hence the factor 2.
        """
        return self.options.num_problems // (self.options.batch_size * 2)
class MultilayerPerceptronLayer(tf.keras.layers.Layer):
    """Chain of ReLU ``Dense`` layers applied one after another.

    The chain has ``num_layers`` layers of ``input_dimensions`` units
    followed by a single layer of ``output_dimensions`` units; all of them
    use the ``relu`` activation.
    """

    def __init__(self, input_dimensions, num_layers, output_dimensions):
        super().__init__()
        self.input_dimensions = input_dimensions
        self.num_layers = num_layers
        self.output_dimensions = output_dimensions

    def build(self, input_shape):
        """Create the ``Dense`` sub-layers, hidden layers first."""
        hidden = []
        for _ in range(self.num_layers):
            hidden.append(Dense(units=self.input_dimensions, activation='relu'))
        head = Dense(units=self.output_dimensions, activation='relu')
        self.mlp_layers = hidden + [head]
        super().build(input_shape)

    def call(self, inputs):
        """Feed ``inputs`` through every sub-layer in order and return the result."""
        activations = inputs
        for dense in self.mlp_layers:
            activations = dense(activations)
        return activations
class MessagePassingLayer(tf.keras.layers.Layer):
    """Runs ``options.num_rounds`` rounds of literal <-> clause message passing.

    Each round sends MLP-transformed literal states to the clauses, updates
    the clause LSTM state, sends MLP-transformed clause states back to the
    literals (concatenated with the flipped literal states) and updates the
    literal LSTM state.

    Args:
        options: Object providing ``dimensions``, ``num_message_layers`` and
            ``num_rounds``.
    """

    def __init__(self, options):
        super().__init__()
        self.options = options
        self.literal_to_clause_mlp_layer = MultilayerPerceptronLayer(options.dimensions, options.num_message_layers, options.dimensions)
        self.clause_to_literal_mlp_layer = MultilayerPerceptronLayer(options.dimensions, options.num_message_layers, options.dimensions)
        self.literal_state_update_layer = RNN(LSTMCell(options.dimensions), return_sequences=True, return_state=True)
        self.clause_state_update_layer = RNN(LSTMCell(options.dimensions), return_sequences=True, return_state=True)
        denom = tf.sqrt(tf.cast(options.dimensions, tf.float32))
        # Not read anywhere in this class; kept so existing attribute access
        # keeps working.
        self.literal_state_init_values = tf.math.divide(tf.random.normal(shape=[1, options.dimensions]), denom)
        self.clause_state_init_values = tf.math.divide(tf.random.normal(shape=[1, options.dimensions]), denom)

    def call(self, inputs):
        """Run the message-passing rounds and return the final literal hidden state.

        Args:
            inputs: 5-tuple ``(literal_hidden_state, literal_cell_state,
                clause_hidden_state, clause_cell_state, literal_to_clauses)``.

        NOTE(review): the body treats every element as the rank-2 data of a
        single problem - TODO confirm how batched/ragged inputs should be
        split before reaching this point.
        """
        (
            literal_hidden_state,
            literal_cell_state,
            clause_hidden_state,
            clause_cell_state,
            literal_to_clauses,
        ) = inputs
        for _ in range(self.options.num_rounds):
            # Literals -> clauses.
            messages = self.literal_to_clause_mlp_layer(literal_hidden_state)
            # Fixed: the original referenced the undefined name
            # `literal_to_clause` here and below.
            # NOTE(review): sparse_dense_matmul wants exactly one SparseTensor
            # operand, while the generator in this file densifies the
            # adjacency matrix - confirm which representation is intended.
            messages = tf.sparse.sparse_dense_matmul(literal_to_clauses, messages, adjoint_a=True)
            # NOTE(review): axis=0 makes the RNN see batch=1 with one time
            # step per clause, while initial_state has one row per clause -
            # verify this is the intended batch/time layout.
            messages = tf.expand_dims(messages, axis=0)
            _outputs, clause_hidden_state, clause_cell_state = self.clause_state_update_layer(
                messages, initial_state=[clause_hidden_state, clause_cell_state]
            )

            # Clauses -> literals.
            messages = self.clause_to_literal_mlp_layer(clause_hidden_state)
            messages = tf.sparse.sparse_dense_matmul(literal_to_clauses, messages)
            messages = tf.concat([messages, self.flip(literal_hidden_state)], axis=1)
            messages = tf.expand_dims(messages, axis=0)
            # Fixed: the literal state was updated with the clause LSTM, which
            # left literal_state_update_layer unused and fed one LSTM inputs
            # of two different widths.
            _outputs, literal_hidden_state, literal_cell_state = self.literal_state_update_layer(
                messages, initial_state=[literal_hidden_state, literal_cell_state]
            )
        return literal_hidden_state

    def flip(self, tensor):
        """Return ``tensor`` with its two halves along axis 0 swapped.

        With the row layout used by ``ProblemDataGenerator.literal_to_variable``
        (positive literals first, negated literals second) this pairs every
        literal with its negation.
        """
        row_count = tensor.shape[0]
        if row_count is None:
            # Static size unknown (symbolic input): use the runtime size
            # instead of failing on `None // 2`.
            row_count = tf.shape(tensor)[0]
        half = row_count // 2
        return tf.concat([tensor[half:], tensor[:half]], axis=0)
class VotingLayer(tf.keras.layers.Layer):
    """Turns literal states into one scalar vote per element of the input.

    Args:
        options: Object providing ``dimensions`` and ``num_message_layers``.
    """

    def __init__(self, options):
        super().__init__()
        self.options = options
        # MLP with a single output unit: one vote per input row.
        self.voting_mlp_layer = MultilayerPerceptronLayer(options.dimensions, options.num_message_layers, 1)

    def call(self, inputs):
        """Return a Python list with the mean vote of each element of ``inputs``.

        NOTE(review): presumably each element is the literal-state matrix of
        one problem - confirm against the caller.
        """
        return [
            tf.reduce_mean(self.voting_mlp_layer(messages))
            for messages in inputs
        ]
class NeurosatModel(tf.keras.Model):
    """End-to-end model: message passing followed by voting.

    Args:
        options: Settings object handed to both sub-layers.
    """

    def __init__(self, options):
        super().__init__()
        self.options = options
        self.message_passing_layer = MessagePassingLayer(options)
        self.voting_layer = VotingLayer(options)

    def call(self, inputs, training=False):
        """Return the voting layer's output for ``inputs``.

        ``training`` is accepted for the Keras call convention but is not
        used by this method.
        """
        literal_states = self.message_passing_layer(inputs)
        votes = self.voting_layer(literal_states)
        return votes
# Build the settings, the data generator and the model (in this order, so
# the random initial values are drawn in the same sequence as before).
options = NeurosatOptions()
gen = ProblemDataGenerator(options)
model = NeurosatModel(options)

# Train for a single epoch on generated problems, then print the summary.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.fit(gen, epochs=1, verbose=1)
model.summary()