import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from IPython.display import clear_output
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

current_dir = os.getcwd()  # directory used when saving the final models

epochs = 1
def generate_images(model, test_input):
    generated = model(test_input)
    plt.figure(figsize=(8, 6))
    list_imgs = [test_input[0], generated[0]]
    title = ['Original', 'Output (generated)']
    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.title(title[i])
        plt.imshow(list_imgs[i] * 0.5 + 0.5)  # rescale from [-1, 1] to [0, 1] for display
        plt.axis('off')
    plt.show()
# Helper to resize a tensor's spatial dimensions to match a target tensor.
# tf.shape is used (rather than .shape) so the helper also works inside
# @tf.function when static dimensions are unknown.
def resize_tensor_to_target(tensor, target_tensor):
    return tf.image.resize(tensor, tf.shape(target_tensor)[1:3])
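# Quick sanity check of the helper on dummy tensors (illustrative; the shapes
# are taken from the ConcatV2 error reported further below):
a = tf.zeros([1, 5, 3, 512])
b = tf.zeros([1, 6, 4, 512])
print(resize_tensor_to_target(a, b).shape)  # -> (1, 6, 4, 512)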
# Updated training_step with resizing before the loss computations and shape debugging
@tf.function
def training_step(real_x, real_y):
    with tf.GradientTape(persistent=True) as tape:
        # Generator G: X -> Y
        # Generator F: Y -> X
        # Generate fake images
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)
        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)
        # Identity mapping
        equal_x = generator_f(real_x, training=True)
        equal_y = generator_g(real_y, training=True)
        # Resize the Y-domain outputs to match real_y, the tensor each of them
        # is compared against (resizing them to the X-domain shapes would
        # misalign the discriminator, cycle and identity losses)
        fake_y_resized = resize_tensor_to_target(fake_y, real_y)
        cycled_y_resized = resize_tensor_to_target(cycled_y, real_y)
        equal_y_resized = resize_tensor_to_target(equal_y, real_y)
        # Print tensor shapes for debugging; inside @tf.function a plain print
        # only runs while the function is traced, which suffices for a one-off check
        print("Shapes:")
        print("real_x:", real_x.shape, "| real_y:", real_y.shape)
        print("fake_x:", fake_x.shape, "| fake_y_resized:", fake_y_resized.shape)
        print("cycled_x:", cycled_x.shape, "| cycled_y_resized:", cycled_y_resized.shape)
        print("equal_x:", equal_x.shape, "| equal_y_resized:", equal_y_resized.shape)
        # Discriminator outputs
        discriminator_real_x = discriminator_x(real_x, training=True)
        discriminator_fake_x = discriminator_x(fake_x, training=True)
        discriminator_real_y = discriminator_y(real_y, training=True)
        discriminator_fake_y = discriminator_y(fake_y_resized, training=True)
        # Loss calculation
        generator_g_loss = generator_loss(discriminator_fake_y)
        generator_f_loss = generator_loss(discriminator_fake_x)
        cycle_loss_total = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y_resized)
        total_generator_g_loss = generator_g_loss + cycle_loss_total + identity_loss(real_y, equal_y_resized)
        total_generator_f_loss = generator_f_loss + cycle_loss_total + identity_loss(real_x, equal_x)
        discriminator_x_loss = discriminator_loss(discriminator_real_x, discriminator_fake_x)
        discriminator_y_loss = discriminator_loss(discriminator_real_y, discriminator_fake_y)
    # Compute gradients (the persistent tape allows several gradient calls)
    generator_g_gradients = tape.gradient(total_generator_g_loss, generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_generator_f_loss, generator_f.trainable_variables)
    discriminator_x_gradients = tape.gradient(discriminator_x_loss, discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(discriminator_y_loss, discriminator_y.trainable_variables)
    # Apply gradients
    optimizer_generator_g.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
    optimizer_generator_f.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
    optimizer_discriminator_x.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
    optimizer_discriminator_y.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))
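# training_step assumes generator_loss, discriminator_loss, cycle_loss and
# identity_loss are defined earlier. If they are not, here is a minimal sketch
# following the standard CycleGAN formulation (binary cross-entropy adversarial
# terms, L1 cycle/identity terms; LAMBDA = 10 is an assumed weight):
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
LAMBDA = 10  # assumed cycle-consistency weight

def generator_loss(disc_generated_output):
    # The generator wants the discriminator to label its output as real
    return loss_obj(tf.ones_like(disc_generated_output), disc_generated_output)

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_obj(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_obj(tf.zeros_like(disc_generated_output), disc_generated_output)
    return (real_loss + generated_loss) * 0.5

def cycle_loss(real_image, cycled_image):
    # L1 cycle-consistency loss
    return LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))

def identity_loss(real_image, same_image):
    return LAMBDA * 0.5 * tf.reduce_mean(tf.abs(real_image - same_image))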
# Load InceptionV3 feature extractor from TensorFlow Hub
inception_v3 = hub.load('https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4')

def get_embeddings(images):
    # Resize all images to [299, 299] for compatibility with InceptionV3
    images = tf.image.resize(images, [299, 299])
    # Map generator output from [-1, 1] to [0, 1], the range this hub module expects
    images = (images + 1.0) / 2.0
    embeddings = inception_v3(images)
    # Log shape for debugging
    print(f"Embeddings shape: {embeddings.shape}")
    return embeddings
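# For larger evaluation sets, the embeddings can be extracted in batches to
# avoid exhausting GPU memory (a sketch; the batch size is an assumption):
def get_embeddings_batched(images, batch_size=32):
    parts = [get_embeddings(images[i:i + batch_size])
             for i in range(0, len(images), batch_size)]
    return np.concatenate([p.numpy() for p in parts], axis=0)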
def calculate_fid(real_embeddings, generated_embeddings):
    # Truncate to the same number of samples so the statistics are comparable
    if real_embeddings.shape != generated_embeddings.shape:
        min_count = min(real_embeddings.shape[0], generated_embeddings.shape[0])
        real_embeddings = real_embeddings[:min_count]
        generated_embeddings = generated_embeddings[:min_count]
    mu1, sigma1 = np.mean(real_embeddings, axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = np.mean(generated_embeddings, axis=0), np.cov(generated_embeddings, rowvar=False)
    diff = mu1 - mu2
    covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard negligible imaginary parts from numerical error
    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid
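# Sanity check: FID of a sample against itself should be ~0, and two independent
# samples from the same distribution should score low (illustrative only; small
# 64-d embeddings keep the matrix square root cheap):
rng = np.random.default_rng(0)
emb_a = rng.normal(size=(100, 64))
emb_b = rng.normal(size=(100, 64))
print(calculate_fid(emb_a, emb_a))  # ~0 up to numerical error
print(calculate_fid(emb_a, emb_b))  # small but non-zero for finite samples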
def calculate_is(images, num_splits=10):
    # Note: reloading InceptionV3 on every call is wasteful; hoist it out
    # if IS is computed frequently
    inception_model = InceptionV3(include_top=True, weights='imagenet')
    # Map generator output from [-1, 1] to [0, 255] before Keras preprocessing,
    # which itself expects pixel values in [0, 255]
    images = (images + 1.0) * 127.5
    images = preprocess_input(images)
    preds = inception_model.predict(images)
    preds = np.clip(preds, 1e-10, 1.0)  # avoid log(0) in the KL term below
    # Flatten any extra dimensions so preds is (num_images, num_classes)
    if preds.ndim != 2:
        print("Prediction shape mismatch. Flattening.")
        preds = preds.reshape(-1, preds.shape[-1])
    scores = []
    split_size = preds.shape[0] // num_splits
    for i in range(num_splits):
        part = preds[i * split_size:(i + 1) * split_size, :]
        py = np.mean(part, axis=0)
        scores.append(np.exp(np.mean([np.sum(p * np.log(p / py)) for p in part])))
    return np.mean(scores), np.std(scores)
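# Illustrative usage: on random noise the Inception Score should land near 1,
# its theoretical minimum, because the predictions barely vary across samples
# (this downloads the ImageNet weights on first use):
noise = np.random.uniform(-1.0, 1.0, size=(50, 299, 299, 3)).astype(np.float32)
is_mean, is_std = calculate_is(noise, num_splits=5)
print(f"IS on noise: {is_mean:.2f} +/- {is_std:.2f}")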
# Function to plot FID and IS
def plot_metrics(fid_scores, is_scores):
    epoch_range = range(1, len(fid_scores) + 1)  # renamed to avoid shadowing the global `epochs`
    # Plot FID scores
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epoch_range, fid_scores, 'bo-', label='FID')
    plt.title('FID Score over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('FID Score')
    plt.legend()
    # Plot IS scores
    plt.subplot(1, 2, 2)
    plt.plot(epoch_range, [score[0] for score in is_scores], 'ro-', label='IS Mean')
    plt.fill_between(epoch_range,
                     [score[0] - score[1] for score in is_scores],
                     [score[0] + score[1] for score in is_scores],
                     color='gray', alpha=0.2, label='IS Std Dev')
    plt.title('Inception Score (IS) over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Inception Score')
    plt.legend()
    plt.show()
# Helper function to format time
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = int(seconds % 60)
    return f"{hours}h {minutes}m {seconds}s"
# Train loop, using the formatted time output for per-epoch timing
def train(training_A, training_B, generator_g, generator_f, epochs, sample_A, testing_A, testing_B):
    fid_scores = []
    is_scores = []
    for epoch in range(epochs):
        start = time.time()
        for img_x, img_y in tf.data.Dataset.zip((training_A, training_B)):
            training_step(img_x, img_y)
        clear_output(wait=True)
        generate_images(generator_g, sample_A)
        # Collect real and generated images for FID and IS
        real_images, generated_images = [], []
        for img_x, img_y in tf.data.Dataset.zip((testing_A, testing_B)).take(100):
            real_images.append(tf.image.resize(img_y, (299, 299)))  # resize for InceptionV3
            generated_image = generator_g(img_x, training=False)
            generated_images.append(tf.image.resize(generated_image, (299, 299)))
        real_images_np = np.concatenate([img.numpy() for img in real_images], axis=0)
        generated_images_np = np.concatenate([img.numpy() for img in generated_images], axis=0)
        # Calculate FID
        real_embeddings = get_embeddings(real_images_np)
        generated_embeddings = get_embeddings(generated_images_np)
        fid_score = calculate_fid(real_embeddings, generated_embeddings)
        fid_scores.append(fid_score)
        # Calculate IS
        is_mean, is_std = calculate_is(generated_images_np)
        is_scores.append((is_mean, is_std))
        print(f'Epoch {epoch + 1} - FID: {fid_score}, IS: Mean -> {is_mean} STD -> {is_std}')
        print(f'Time taken for epoch {epoch + 1}: {format_time(time.time() - start)}')
    # Plot FID and IS
    plot_metrics(fid_scores, is_scores)
    # Save the final models after training in the current directory
    generator_g.save(os.path.join(current_dir, 'generator_g_final.h5'))
    generator_f.save(os.path.join(current_dir, 'generator_f_final.h5'))
    print("Final model saved after training.")
# Example usage:
train(training_A, training_B, generator_g, generator_f, epochs, sample_A, testing_A, testing_B)
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-50-d3274f3c8612> in <cell line: 2>()
1 # Example usage:
----> 2 train(training_A, training_B, generator_g, generator_f, epochs, sample_A, testing_A, testing_B)
/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
5981 def raise_from_not_ok_status(e, name) -> NoReturn:
5982 e.message += (" name: " + str(name if name is not None else ""))
-> 5983 raise core._status_to_exception(e) from None # pylint: disable=protected-access
5984
5985
InvalidArgumentError: Exception encountered when calling Concatenate.call().
{{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} ConcatOp : Dimension 1 in both shapes must be equal: shape[0] = [1,6,4,512] vs. shape[1] = [1,5,3,512] [Op:ConcatV2] name: concat
Arguments received by Concatenate.call():
• inputs=['tf.Tensor(shape=(1, 6, 4, 512), dtype=float32)', 'tf.Tensor(shape=(1, 5, 3, 512), dtype=float32)']
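Note on the traceback: the ConcatV2 failure occurs inside a generator's skip connection (a Concatenate layer joining 512-channel feature maps), not in training_step, so resizing the generator outputs after the fact cannot prevent it. Mismatches like (1, 6, 4, 512) vs (1, 5, 3, 512) arise when the input height and width are not divisible by the network's total downsampling factor, so an upsampled feature map no longer lines up with its skip counterpart. One hedged fix, assuming a U-Net-style generator with n downsampling layers, is to resize every input to a multiple of 2**n before it reaches the model:

# Hypothetical pre-processing fix; FACTOR must match the generator's actual depth
FACTOR = 64  # assumed: 2**6 for six downsampling layers

def resize_to_multiple(image, factor=FACTOR):
    shape = tf.shape(image)
    h, w = shape[1], shape[2]
    new_h = ((h + factor - 1) // factor) * factor  # round up to a multiple
    new_w = ((w + factor - 1) // factor) * factor
    return tf.image.resize(image, [new_h, new_w])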