TensorFlow InvalidArgumentError: Concatenation dimension mismatch in ConcatOp - Shapes do not match

# Number of full training passes; kept at 1 for quick debugging runs.
epochs = 1
def generate_images(model, test_input):
  """Display the first input image next to the model's output for it."""
  generated = model(test_input)
  panels = [(test_input[0], 'Original'), (generated[0], 'Output (generated)')]
  plt.figure(figsize = (8,6))
  for idx, (img, caption) in enumerate(panels):
    plt.subplot(1, 2, idx + 1)
    plt.title(caption)
    # Map pixel values from [-1, 1] back to [0, 1] for display.
    plt.imshow(img * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

# Function to resize tensor to match a target tensor's shape
def resize_tensor_to_target(tensor, target_tensor):
    """Resize `tensor` to the spatial (H, W) size of `target_tensor`.

    Assumes NHWC layout -- dims 1 and 2 of the target are taken as
    height and width.  Uses tf.image.resize's default (bilinear) method.
    """
    height, width = target_tensor.shape[1], target_tensor.shape[2]
    return tf.image.resize(tensor, [height, width])


# Updated training_step function with resizing before concatenation and shape debugging
@tf.function
def training_step(real_x, real_y):
    """Run one CycleGAN optimization step on a pair of image batches.

    real_x / real_y: image batches from domains X and Y (presumably
    scaled to [-1, 1] by the input pipeline -- TODO confirm).
    Updates generator_g, generator_f, discriminator_x and discriminator_y
    in place through their optimizers; returns nothing.

    NOTE(review): the ConcatOp error in the traceback originates *inside*
    the generators (a skip-connection Concatenate), i.e. the input
    height/width is not divisible by the network's total downsampling
    factor.  Resizing here only aligns the loss inputs; the root fix is
    to resize/pad images to a compatible size in the dataset pipeline
    before they reach the generators.
    """
    with tf.GradientTape(persistent=True) as tape:
        # Generator G: X -> Y
        # Generator F: Y -> X

        # Generate fake images
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # Identity mapping (each generator applied to its own target domain)
        equal_x = generator_f(real_x, training=True)
        equal_y = generator_g(real_y, training=True)

        # Resize every generated tensor to the spatial size of the *real*
        # tensor it is compared against.  (The original resized y-side
        # tensors to the x-side generated shapes, which still mismatches
        # the loss targets whenever real and generated sizes differ.)
        fake_x_resized = resize_tensor_to_target(fake_x, real_x)
        fake_y_resized = resize_tensor_to_target(fake_y, real_y)
        cycled_x_resized = resize_tensor_to_target(cycled_x, real_x)
        cycled_y_resized = resize_tensor_to_target(cycled_y, real_y)
        equal_x_resized = resize_tensor_to_target(equal_x, real_x)
        equal_y_resized = resize_tensor_to_target(equal_y, real_y)

        # tf.print executes on every step inside a tf.function; a plain
        # Python print would only fire while the function is being traced.
        tf.print("Shapes:")
        tf.print("real_x:", tf.shape(real_x), "| real_y:", tf.shape(real_y))
        tf.print("fake_x:", tf.shape(fake_x_resized), "| fake_y:", tf.shape(fake_y_resized))
        tf.print("cycled_x:", tf.shape(cycled_x_resized), "| cycled_y:", tf.shape(cycled_y_resized))
        tf.print("equal_x:", tf.shape(equal_x_resized), "| equal_y:", tf.shape(equal_y_resized))

        # Discriminator outputs
        discriminator_real_x = discriminator_x(real_x, training=True)
        discriminator_fake_x = discriminator_x(fake_x_resized, training=True)
        discriminator_real_y = discriminator_y(real_y, training=True)
        discriminator_fake_y = discriminator_y(fake_y_resized, training=True)

        # Loss calculation (cycle loss is shared by both generators, as in
        # the standard CycleGAN formulation)
        generator_g_loss = generator_loss(discriminator_fake_y)
        generator_f_loss = generator_loss(discriminator_fake_x)
        cycle_loss_total = (cycle_loss(real_x, cycled_x_resized)
                            + cycle_loss(real_y, cycled_y_resized))
        total_generator_g_loss = generator_g_loss + cycle_loss_total + identity_loss(real_y, equal_y_resized)
        total_generator_f_loss = generator_f_loss + cycle_loss_total + identity_loss(real_x, equal_x_resized)
        discriminator_x_loss = discriminator_loss(discriminator_real_x, discriminator_fake_x)
        discriminator_y_loss = discriminator_loss(discriminator_real_y, discriminator_fake_y)

    # Compute gradients
    generator_g_gradients = tape.gradient(total_generator_g_loss, generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_generator_f_loss, generator_f.trainable_variables)
    discriminator_x_gradients = tape.gradient(discriminator_x_loss, discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(discriminator_y_loss, discriminator_y.trainable_variables)

    # Release the persistent tape's resources explicitly.
    del tape

    # Apply gradients
    optimizer_generator_g.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
    optimizer_generator_f.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
    optimizer_discriminator_x.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
    optimizer_discriminator_y.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))



# Load InceptionV3 feature extractor from TensorFlow Hub
# (tf2-preview feature_vector module; used by get_embeddings for the FID metric).
inception_v3 = hub.load('https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4')


def get_embeddings(images):
    """Return InceptionV3 feature-vector embeddings for a batch of images.

    images: batch of images, presumed to be in [-1, 1] (tanh generator
    output) -- TODO confirm against the caller's scaling.
    Returns a (batch, feature_dim) tensor of embeddings.
    """
    # Ensuring all images are resized to [299, 299] for compatibility with InceptionV3
    images = tf.image.resize(images, [299, 299])
    # The tf2-preview inception_v3 feature_vector module expects color
    # values in [0, 1] (per its TF Hub documentation), not [0, 255]:
    # map [-1, 1] -> [0, 1].
    images = (images + 1.0) / 2.0
    embeddings = inception_v3(images)

    # Log shape for debugging
    print(f"Embeddings shape: {embeddings.shape}")
    return embeddings



def calculate_fid(real_embeddings, generated_embeddings):
    """Compute the Frechet Inception Distance between two embedding sets.

    Each argument is a (num_samples, feature_dim) array (anything
    np.asarray accepts).  If the sample counts differ, both sets are
    truncated to the smaller count; feature dimensions must match --
    truncation along axis 0 cannot reconcile differing feature sizes.

    Returns the FID as a float (complex numerical noise from the matrix
    square root is discarded).
    """
    real_embeddings = np.asarray(real_embeddings)
    generated_embeddings = np.asarray(generated_embeddings)

    # Equalize sample counts only; the original compared full shapes but
    # could still only truncate along axis 0.
    n = min(real_embeddings.shape[0], generated_embeddings.shape[0])
    real_embeddings = real_embeddings[:n]
    generated_embeddings = generated_embeddings[:n]

    mu1 = np.mean(real_embeddings, axis=0)
    mu2 = np.mean(generated_embeddings, axis=0)
    sigma1 = np.cov(real_embeddings, rowvar=False)
    sigma2 = np.cov(generated_embeddings, rowvar=False)

    diff = mu1 - mu2
    # disp=False suppresses sqrtm's warning print; the second return
    # value is an error estimate we discard.
    covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)

    # sqrtm may return a complex matrix due to numerical noise; the FID
    # formula only needs the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return float(diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean))


def calculate_is(images, num_splits=10):
    """Compute the Inception Score for a batch of images.

    images: array of images compatible with InceptionV3 (299x299x3).
    num_splits: number of chunks the predictions are split into; the
    score statistics are computed over the per-chunk scores.

    Returns a (mean, std) tuple of the per-split scores.
    """
    # Build the ImageNet classifier once and reuse it across calls --
    # the original re-downloaded/reloaded the weights on every epoch.
    if not hasattr(calculate_is, "_model"):
        calculate_is._model = InceptionV3(include_top=True, weights='imagenet')
    inception_model = calculate_is._model

    # NOTE(review): keras preprocess_input for InceptionV3 expects pixel
    # values in [0, 255]; confirm the caller does not pass [-1, 1]
    # generator output directly.
    images = preprocess_input(images)
    preds = inception_model.predict(images)
    # Clip to avoid log(0) in the KL term below.
    preds = np.clip(preds, 1e-10, 1.0)

    # Flatten any unexpected extra dimensions down to (n_samples, n_classes).
    if preds.ndim != 2:
        print("Prediction shape mismatch. Resizing.")
        preds = preds.reshape(-1, preds.shape[-1])

    scores = []
    split_size = preds.shape[0] // num_splits
    for i in range(num_splits):
        part = preds[i * split_size:(i + 1) * split_size, :]
        py = np.mean(part, axis=0)  # marginal class distribution for this split
        # exp(mean KL(p(y|x) || p(y))) is the Inception Score of the split.
        scores.append(np.exp(np.mean([np.sum(p * np.log(p / py)) for p in part])))

    return np.mean(scores), np.std(scores)

# Function to plot FID and IS
def plot_metrics(fid_scores, is_scores):
    """Plot FID (left panel) and Inception Score with its std band (right)."""
    x_axis = range(1, len(fid_scores) + 1)

    plt.figure(figsize=(12, 5))

    # --- FID panel ---
    plt.subplot(1, 2, 1)
    plt.plot(x_axis, fid_scores, 'bo-', label='FID')
    plt.title('FID Score over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('FID Score')
    plt.legend()

    # --- IS panel: mean line plus a shaded +/- one-std band ---
    means = [mean for mean, _ in is_scores]
    stds = [std for _, std in is_scores]
    lower = [m - s for m, s in zip(means, stds)]
    upper = [m + s for m, s in zip(means, stds)]

    plt.subplot(1, 2, 2)
    plt.plot(x_axis, means, 'ro-', label='IS Mean')
    plt.fill_between(x_axis, lower, upper,
                     color='gray', alpha=0.2, label='IS Std Dev')
    plt.title('Inception Score (IS) over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Inception Score')
    plt.legend()

    plt.show()

# Helper function to format time
def format_time(seconds):
    """Render a duration in seconds as 'Hh Mm Ss' (components truncated)."""
    total = int(seconds)
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours}h {minutes}m {secs}s"

# Modify the train function to use the formatted time output
def train(training_A, training_B, generator_g, generator_f, epochs, sample_A, testing_A, testing_B):
    """Train the CycleGAN for `epochs` epochs and track FID / IS per epoch.

    training_A / training_B: tf.data datasets of image batches for the two domains.
    generator_g / generator_f: X->Y and Y->X generators (saved at the end).
    epochs: number of passes over the zipped training datasets.
    sample_A: fixed input used to visualize generator_g's progress.
    testing_A / testing_B: datasets sampled for FID / IS evaluation.

    Relies on module-level training_step, generate_images, get_embeddings,
    calculate_fid, calculate_is, plot_metrics and current_dir.
    """
    fid_scores = []
    is_scores = []

    for epoch in range(epochs):
        start = time.time()

        for img_x, img_y in tf.data.Dataset.zip((training_A, training_B)):
            training_step(img_x, img_y)

        clear_output(wait=True)
        generate_images(generator_g, sample_A)

        # Collect up to 100 (real, generated) batches for the metrics.
        real_images, generated_images = [], []

        for img_x, img_y in tf.data.Dataset.zip((testing_A, testing_B)).take(100):
            real_images.append(tf.image.resize(img_y, (299, 299)))  # Resize for InceptionV3
            generated_image = generator_g(img_x)
            generated_images.append(tf.image.resize(generated_image, (299, 299)))

        real_images_np = np.concatenate([img.numpy() for img in real_images], axis=0)
        generated_images_np = np.concatenate([img.numpy() for img in generated_images], axis=0)

        # Calculate FID
        real_embeddings = get_embeddings(real_images_np)
        generated_embeddings = get_embeddings(generated_images_np)
        fid_score = calculate_fid(real_embeddings, generated_embeddings)
        fid_scores.append(fid_score)

        # Calculate IS
        is_mean, is_std = calculate_is(generated_images_np)
        is_scores.append((is_mean, is_std))

        print(f'Epoch {epoch + 1} - FID: {fid_score}, IS: Mean -> {is_mean} STD -> {is_std}')
        # The original recorded `start` but never reported it; use the
        # format_time helper as the header comment intended.
        print(f'Time taken for epoch {epoch + 1}: {format_time(time.time() - start)}')

    # Plot FID and IS
    plot_metrics(fid_scores, is_scores)

    # Save the final models after training in the current directory
    generator_g.save(os.path.join(current_dir, 'generator_g_final.h5'))
    generator_f.save(os.path.join(current_dir, 'generator_f_final.h5'))
    print("Final model saved after training.")

# Example usage:
# Kick off training; all arguments are assumed to be defined earlier in the notebook.
train(training_A, training_B, generator_g, generator_f, epochs, sample_A, testing_A, testing_B)
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-50-d3274f3c8612> in <cell line: 2>()
      1 # Example usage:
----> 2 train(training_A, training_B, generator_g, generator_f, epochs, sample_A, testing_A, testing_B)

2 frames
/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   5981 def raise_from_not_ok_status(e, name) -> NoReturn:
   5982   e.message += (" name: " + str(name if name is not None else ""))
-> 5983   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   5984 
   5985 

InvalidArgumentError: Exception encountered when calling Concatenate.call().

{{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} ConcatOp : Dimension 1 in both shapes must be equal: shape[0] = [1,6,4,512] vs. shape[1] = [1,5,3,512] [Op:ConcatV2] name: concat

Arguments received by Concatenate.call():
  • inputs=['tf.Tensor(shape=(1, 6, 4, 512), dtype=float32)', 'tf.Tensor(shape=(1, 5, 3, 512), dtype=float32)']

Hi @mdiktushar, generally this error occurs when you try to concatenate tensors whose shapes differ along an axis other than the concatenation axis. Make sure the tensors you pass to the Concatenate layer have matching sizes on every dimension except the one you concatenate along — in a U-Net-style generator this usually means the input height and width must be divisible by the network's total downsampling factor. Thank you.