Hi everyone, I’m new to the TensorFlow Keras API, and thought I would use it with Apple's tensorflow-metal plugin to train a custom MobileNetV3Small model on my M1 Pro MacBook for image classification. This is for my app DeTeXt, which classifies drawings into LaTeX symbols. Currently I’m using a MobileNetV2 model that I trained on a GPU cluster using the PyTorch API (code here).
Here is the code I use to train my custom network from scratch on the images I have:
```python
import tensorflow as tf
import pdb

EPOCHS = 5
BATCH_SIZE = 128
LEARNING_RATE = 0.003
SEED = 1220

if __name__ == '__main__':
    # Load train and validation data
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        '/Volumes/detext/drawings/',
        color_mode="grayscale",
        seed=SEED,
        batch_size=BATCH_SIZE,
        labels='inferred',
        label_mode='int',
        image_size=(200, 300),
        validation_split=0.1,
        subset='training')
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        '/Volumes/detext/drawings/',
        color_mode="grayscale",
        seed=SEED,
        batch_size=BATCH_SIZE,
        labels='inferred',
        label_mode='int',
        image_size=(200, 300),
        validation_split=0.1,
        subset='validation')

    # Get the class names
    class_names = train_ds.class_names
    num_classes = len(class_names)

    # Create model
    model = tf.keras.applications.MobileNetV3Small(
        input_shape=(200, 300, 1), alpha=1.0, minimalistic=False,
        include_top=True, weights=None, input_tensor=None, classes=num_classes,
        pooling=None, dropout_rate=0.2, classifier_activation="softmax",
        include_preprocessing=True)

    # Compile model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    # Training
    model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
    model.save('./saved_model3/')
```
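For comparison, here is a minimal sketch of what rescaling the inputs outside the model has to reproduce, assuming the goal is to match the built-in preprocessing exactly. With `include_preprocessing=True`, Keras' MobileNetV3 adds a `Rescaling(scale=1/127.5, offset=-1)` layer, so raw [0, 255] pixels end up in [-1, 1]. The helper name `rescale_like_mobilenetv3` is hypothetical and only illustrates that transform with NumPy.

```python
import numpy as np


def rescale_like_mobilenetv3(images):
    """Map pixel values from [0, 255] to [-1, 1].

    Same affine transform as the Rescaling(scale=1/127.5, offset=-1) layer
    MobileNetV3Small adds when include_preprocessing=True: y = x / 127.5 - 1.
    """
    return np.asarray(images, dtype=np.float32) / 127.5 - 1.0


# Black, mid-grey and white pixels map to -1, 0 and 1.
print(rescale_like_mobilenetv3([0.0, 127.5, 255.0]))
```

In the `tf.data` pipeline the equivalent would be something like `ds.map(lambda x, y: (x / 127.5 - 1.0, y))`, applied identically to both `train_ds` and `val_ds`, with the model built using `include_preprocessing=False`. If only the training set is rescaled, validation images reach the network on a different scale, which on its own can drive validation accuracy toward chance.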
While the training runs smoothly and fast with the Metal plugin, the validation accuracy is very low after 5 epochs. I suspect the model is either predicting the same class every time, or there is an error somewhere in my setup above. I have tried rescaling the inputs myself (and removing the rescaling layer from the model), but no matter what I try, the reported validation accuracy stays very low. Here is the output (warnings and all) after 2 epochs:
```
Found 210454 files belonging to 1098 classes.
Using 189409 files for training.
Metal device set to: Apple M1 Pro
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
2021-12-16 10:02:46.369476: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2021-12-16 10:02:46.369603: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Found 210454 files belonging to 1098 classes.
Using 21045 files for validation.
Epoch 1/2
2021-12-16 10:02:50.610564: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2021-12-16 10:02:50.619328: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2021-12-16 10:02:50.619628: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
1480/1480 [==============================] - ETA: 0s - loss: 1.7621 - sparse_categorical_accuracy: 0.5702
2021-12-16 10:12:58.720162: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
1480/1480 [==============================] - 626s 422ms/step - loss: 1.7621 - sparse_categorical_accuracy: 0.5702 - val_loss: 9.5837 - val_sparse_categorical_accuracy: 0.0052
Epoch 2/2
1480/1480 [==============================] - 622s 420ms/step - loss: 1.0791 - sparse_categorical_accuracy: 0.6758 - val_loss: 7.3651 - val_sparse_categorical_accuracy: 0.0423
2021-12-16 10:23:40.260143: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
/Users/venkat/miniforge3/envs/tf-metal/lib/python3.9/site-packages/keras/utils/generic_utils.py:494: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  warnings.warn('Custom mask layers require a config and must override '
```
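To check the suspicion that the model predicts the same class every time, one option is to look at how the argmax predictions are spread across classes on the validation set. This is a minimal NumPy sketch, assuming `probs` holds the softmax outputs with shape `(num_samples, num_classes)`, for example from `model.predict(val_ds)`. The function name `prediction_summary` is hypothetical.

```python
import numpy as np


def prediction_summary(probs, top_k=5):
    """Summarise how argmax predictions are distributed over classes.

    Returns (number of distinct predicted classes,
             fraction of samples assigned to the most frequent class,
             list of the top_k most frequent (class_index, count) pairs).
    """
    preds = np.argmax(np.asarray(probs), axis=1)           # predicted class per sample
    classes, counts = np.unique(preds, return_counts=True)
    order = np.argsort(counts)[::-1]                        # most frequent first
    top = [(int(classes[i]), int(counts[i])) for i in order[:top_k]]
    majority_fraction = float(counts.max() / preds.size)
    return len(classes), majority_fraction, top


# Toy example: 4 samples, 3 classes, class 0 predicted 3 times.
toy = np.array([[0.9, 0.1, 0.0],
                [0.8, 0.1, 0.1],
                [0.2, 0.7, 0.1],
                [0.6, 0.3, 0.1]])
print(prediction_summary(toy))
```

With 1098 classes, a majority fraction close to 1.0 (or only a handful of distinct predicted classes) would confirm a collapsed prediction; a broad spread would suggest the low validation accuracy has some other cause.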
For reference, I was getting a validation micro-F1 (which is the same as accuracy here) of over 60% with MobileNetV2 in PyTorch. Does anyone have any idea what I’m doing wrong?
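The micro-F1/accuracy equivalence holds for single-label multiclass classification: every misclassified sample counts once as a false positive (for the predicted class) and once as a false negative (for its true class), so the pooled precision and recall both equal the accuracy. A small pure-Python check, with hypothetical helper names `micro_f1` and `accuracy`:

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1: pool TP/FP/FN over all classes, then compute F1."""
    tp = fp = fn = 0
    for c in set(y_true) | set(y_pred):
        for t, p in zip(y_true, y_pred):
            if p == c and t == c:
                tp += 1  # correct prediction of class c
            elif p == c:
                fp += 1  # predicted c, but the true label was another class
            elif t == c:
                fn += 1  # true label was c, but another class was predicted
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 2, 1]
print(f"{micro_f1(y_true, y_pred):.3f} {accuracy(y_true, y_pred):.3f}")  # 0.800 0.800
```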