Hello! I’m trying to replace all uses of matrix multiplication. For the purposes of this simple example and for verification, I’m just implementing something that produces zeros whenever tf.matmul is used.
I’m currently using TF 2.7.0 on macOS with Python 3.9.0 for development. I’ve based the code below on the TensorFlow Core guide "Extension types".
import sys
from typing import Tuple, List, Mapping, Union, Optional

import numpy as np
import tensorflow as tf
class MyCustomTensor(tf.experimental.BatchableExtensionType):
    """Simple custom tensor that wraps a single ``tf.Tensor``.

    It does nothing special by itself; it exists so that TensorFlow APIs with
    a dispatcher registered for this type can intercept calls made with it.
    """

    # Registered name for this extension type.
    __name__ = 'replace.tf.MyCustomTensor'

    values: tf.Tensor

    # Forward shape/dtype to the wrapped tensor.
    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    class Spec:
        """Type spec for MyCustomTensor, built from a shape and a dtype."""

        def __init__(self, shape, dtype=tf.float32):
            self.values = tf.TensorSpec(shape, dtype)

        shape = property(lambda self: self.values.shape)
        dtype = property(lambda self: self.values.dtype)

        def with_shape(self, shape=None):
            """Return a copy of this spec with its shape replaced by ``shape``.

            Keras calls this as ``spec.with_shape(shape)`` when it needs a
            spec with an updated shape, so ``shape`` must be accepted as an
            argument. When ``shape`` is None, the current shape is kept.

            Args:
              shape: the new shape (anything ``tf.TensorSpec`` accepts).

            Returns:
              A new ``MyCustomTensor.Spec`` with the same dtype as this one.
            """
            if shape is None:
                shape = self.shape
            # Spec.__init__ expects a shape, not a TensorSpec.
            return MyCustomTensor.Spec(shape, self.values.dtype)
def convert_to_custom_tensor(x):
    """Return ``x`` unchanged if it is a MyCustomTensor, else wrap it in one."""
    already_wrapped = isinstance(x, MyCustomTensor)
    return x if already_wrapped else MyCustomTensor(x)
@tf.experimental.dispatch_for_unary_elementwise_apis(MyCustomTensor)
def unary_elementwise_op_handler(op, x):
    """Handle unary elementwise APIs called on a MyCustomTensor.

    Applies ``op`` to the wrapped tensor and wraps the result again.
    """
    inner = x.values
    transformed = op(inner)
    return MyCustomTensor(transformed)
@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MyCustomTensor, tf.Tensor],
    Union[MyCustomTensor, tf.Tensor],
)
def binary_elementwise_op_handler(op, x, y):
    """Handle binary elementwise APIs where either operand may be custom.

    Both operands are first brought to MyCustomTensor (x before y), then
    ``op`` is applied to the wrapped tensors and the result is wrapped.
    """
    lhs, rhs = (convert_to_custom_tensor(operand) for operand in (x, y))
    combined = op(lhs.values, rhs.values)
    return MyCustomTensor(combined)
@tf.experimental.dispatch_for_api(tf.matmul)
def custom_matmul(a: MyCustomTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None, name=None):
    """Replacement for ``tf.matmul`` when ``a`` is a MyCustomTensor.

    For this experiment, every MyCustomTensor operand is replaced by zeros
    with the same shape and dtype as its wrapped tensor. A message is printed
    to stdout, and then the regular ``tf.matmul`` runs on the zeroed operands.

    Args:
      a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse,
      b_is_sparse, output_type, name: same meaning as for ``tf.matmul``.
        The parameter list must mirror ``tf.matmul`` for dispatch to accept it.

    Returns:
      The result of ``tf.matmul`` on the (possibly zeroed) operands.
      NOTE(review): the result is not re-wrapped in MyCustomTensor, so
      matmuls in later layers presumably take the normal path. Wrap it if
      every matmul in the model should be replaced.
    """
    if isinstance(a, MyCustomTensor):
        # zeros_like keeps the operand's dtype and works with unknown (None)
        # dimensions; tf.zeros(a.shape) is always float32 and needs a fully
        # defined static shape.
        a = tf.zeros_like(a.values)
    if isinstance(b, MyCustomTensor):
        b = tf.zeros_like(b.values)
    tf.print("Matmul replaced!", output_stream=sys.stdout)
    # Pass the options by keyword (robust to positional-order mistakes) and
    # forward `name`.
    return tf.matmul(
        a, b,
        transpose_a=transpose_a, transpose_b=transpose_b,
        adjoint_a=adjoint_a, adjoint_b=adjoint_b,
        a_is_sparse=a_is_sparse, b_is_sparse=b_is_sparse,
        output_type=output_type, name=name)
I then create a simple model:
# Symbolic input: a MyCustomTensor wrapping a float32 tensor of shape [1, 2].
dense_input_spec = MyCustomTensor.Spec([1, 2], tf.float32)
# Two bias-free Dense layers. With the tf.matmul dispatcher (custom_matmul)
# registered, calling this model on a MyCustomTensor is expected to give 0.
dense_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=dense_input_spec),
tf.keras.layers.Dense(16, activation="relu", use_bias=False),
tf.keras.layers.Dense(1, use_bias=False)])
# NOTE(review): np.ones defaults to float64, unlike the float32 spec above;
# presumably Keras casts it before the matmul — confirm.
dense_model(MyCustomTensor(np.ones((1,2))))
dense_model(np.ones((1,2))) gives me a nonzero value, whereas dense_model(MyCustomTensor(np.ones((1,2)))) gives me 0, as expected. Awesome!
However, when I try:
# Symbolic input: a MyCustomTensor wrapping a float32 tensor of shape
# [1, 224, 224, 3].
conv_input_spec = MyCustomTensor.Spec([1,224, 224, 3], tf.float32)
# A single bias-free Conv2D layer. Building this model raises the TypeError
# shown in the traceback below (inside layer "conv2d"): the convolution op
# has no dispatcher for MyCustomTensor unless tf.nn.convolution is dispatched.
conv_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=conv_input_spec),
tf.keras.layers.Conv2D(3, 3, use_bias=False)])
I get the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/var/folders/d7/c4dm46h53_d9vhx9_q_y12fm0000gn/T/ipykernel_9694/354279810.py in <module>
1 conv_input_spec = MyCustomTensor.Spec([1,224, 224, 3], tf.float32)
----> 2 conv_model = tf.keras.Sequential([
3 tf.keras.layers.Input(type_spec=conv_input_spec),
4 tf.keras.layers.Conv2D(3, 3, use_bias=False),
5 ]
~/.pyenv/versions/3.9.0/envs/idiom-ml-tf27/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
528 self._self_setattr_tracking = False # pylint: disable=protected-access
529 try:
--> 530 result = method(self, *args, **kwargs)
531 finally:
532 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~/.pyenv/versions/3.9.0/envs/idiom-ml-tf27/lib/python3.9/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
~/.pyenv/versions/3.9.0/envs/idiom-ml-tf27/lib/python3.9/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
547 str_values = [compat.as_bytes(x) for x in proto_values]
548 except TypeError:
--> 549 raise TypeError(f"Failed to convert elements of {values} to Tensor. "
550 "Consider casting elements to a supported type. See "
551 "https://www.tensorflow.org/api_docs/python/tf/dtypes "
TypeError: Exception encountered when calling layer "conv2d" (type Conv2D).
Failed to convert elements of MyCustomTensor(values=<tf.Tensor 'Placeholder:0' shape=(1, 224, 224, 3) dtype=float32>) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
Call arguments received:
• inputs=MyCustomTensor(values=<tf.Tensor 'Placeholder:0' shape=(1, 224, 224, 3) dtype=float32>)
It turns out this requires dispatching tf.nn.convolution, with something like:
@tf.experimental.dispatch_for_api(tf.nn.convolution)
def custom_convolution(input: MyCustomTensor,
                       filters,
                       strides=None,
                       padding="VALID",
                       data_format=None,
                       dilations=None,
                       name=None):
    """Replacement for ``tf.nn.convolution`` when ``input`` is a MyCustomTensor.

    Without the dispatch_for_api registration above, this function is never
    called. It prints a message, replaces the input with zeros of the same
    shape and dtype, and then runs the regular ``tf.nn.convolution``.

    Args:
      input, filters, strides, padding, data_format, dilations, name: same
        meaning as for ``tf.nn.convolution``. The names (including ``input``,
        which shadows the builtin) must mirror that API's signature for
        dispatch to accept this function.

    Returns:
      The result of ``tf.nn.convolution`` on the zeroed input.
    """
    tf.print("Conv replaced!")
    # zeros_like keeps the dtype and works with unknown (None) dimensions,
    # unlike tf.zeros(input.shape).
    input = tf.zeros_like(input.values)
    return tf.nn.convolution(input,
                             filters,
                             strides=strides,
                             padding=padding,
                             data_format=data_format,
                             dilations=dilations,
                             name=name)
(Really silly and simple, I know, but this was just to test replacing ops.) With that, I get the expected zeros. Looking through the source code, it seems that Conv operations are dispatched to C++ code — is that right? Do I really need to register a dispatcher for every Conv op? Is there any way to replace the matrix multiplication inside all ops at once?
Thank you in advance for any pointers!