Hello there I would like to load image patches in my tensorflow model. I followed pix2pix loading example and adapted it to my model. But I cannot load the patches. Thanks for the assistance. Here is my code:
def load2(image_file):
image_ts = tf.io.read_file(image_file)
image_ts = tf.image.decode_png(image_ts)
# Calculate the center of the image
w_center = 358
# Split the image into two parts
input_image = image_ts[:, w_center:, :]
real_image = image_ts[:, :w_center, :]
input_image = tf.cast(input_image, tf.float32)
real_image = tf.cast(real_image, tf.float32)
input_image, real_image = resize(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)
input_patches = []
gt_patches = []
# Extract patches using tf.image.extract_patches
input_patches = tf.image.extract_patches(input_image[tf.newaxis, ...],sizes=[1, patch_size[0], patch_size[1], 1],
strides=[1, stride[0], stride[1], 1],rates = [1, 1, 1, 1])
gt_patches = tf.image.extract_patches(real_image[tf.newaxis, ...],sizes=[1, patch_size[0], patch_size[1], 1],
strides=[1, stride[0], stride[1], 1],rates = [1, 1, 1, 1])
# Reshape patches to get a list of individual patches
input_patches = tf.reshape(input_patches, [-1, patch_size[0], patch_size[1], 3])
gt_patches = tf.reshape(gt_patches, [-1, patch_size[0], patch_size[1], 3])
return input_patches, gt_patches, image_file
LOAD TRAIN
def load_image_train(image_file):
input_image=
real_image=
input_image, real_image, image_file = load2(image_file)
input_image, real_image = random_jitter(input_image, real_image)
input_image, real_image = normalize_images(input_image, real_image)
return input_image, real_image, image_file
input_directory_train = â/home/rafael/Ărea de Trabalho/Dataset-PNGâ
input_directory_test = â/home/rafael/Ărea de Trabalho/Dataset-PNG/Testâ
input_directory_val = â/home/rafael/Ărea de Trabalho/Dataset-PNG/Valâ
Use glob to get a list of image file paths in the input directory
image_files = glob.glob(input_directory_train + â/.png") # Change the file extension as needed
image_files_test = glob.glob(input_directory_test + "/.pngâ) # Change the file extension as needed
image_files_val = glob.glob(input_directory_val + â/*.pngâ) # Change the file extension as needed
Create a dataset from the list of image files
dataset = tf.data.Dataset.from_tensor_slices(image_files)
dataset = dataset.map(load_image_train,num_parallel_calls=tf.data.AUTOTUNE)
Define batch size and other data pipeline operations (e.g., shuffle, repeat, etc.) batch_size = 32
dataset = dataset.shuffle(buffer_size=1000) # Shuffle the data
dataset = dataset.batch(BATCH_SIZE) # Batch the data
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) # Prefetch for better performanc
Skip the first N elements and take the next M elements
N = 0
M = 200
dataset = dataset.skip(N).take(M)