Hey there!
I want to extend the tf.Tensor class, but neither of the following options works. I specifically want to subclass it, though, rather than store a tf.Tensor instance as an attribute of my own class!
- Option 1: extend tf.Tensor
import tensorflow as tf

class MyTFTensor(tf.Tensor):
    @classmethod
    def _from_native(cls, value: tf.Tensor):
        # Re-label an existing tensor as an instance of this subclass
        value.__class__ = cls
        return value

y = MyTFTensor._from_native(value=tf.zeros((3, 224, 224)))
Fails with:
/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/515175960.py in _from_native(cls, value)
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):
----> 6         value.__class__ = cls
      7         return value
TypeError: __class__ assignment: 'MyTFTensor' object layout differs from 'tensorflow.python.framework.ops.EagerTensor'
- Option 2: extend EagerTensor
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor

class MyTFTensor(EagerTensor):
    @classmethod
    def _from_native(cls, value: tf.Tensor):
        # Re-label an existing tensor as an instance of this subclass
        value.__class__ = cls
        return value
Fails with:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/3632871733.py in <cell line: 2>()
      1 from tensorflow.python.framework.ops import EagerTensor
----> 2 class MyTFTensor(EagerTensor):
      3 
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):
TypeError: type 'tensorflow.python.framework.ops.EagerTensor' is not an acceptable base type
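For what it's worth, both errors look like generic CPython restrictions rather than anything TensorFlow-specific, which would explain why neither trick gets past the interpreter. Option 1 hits the __class__ assignment check, which refuses to switch an object to a class with a different instance memory layout (my assumption is that EagerTensor carries extra C-level fields that a pure-Python subclass of tf.Tensor doesn't have). Option 2 hits the base-class check: a type can only be subclassed if its Py_TPFLAGS_BASETYPE flag is set, and EagerTensor evidently doesn't set it. Here is a minimal sketch that reproduces both checks with built-in types only; PlainClass, OtherPlain, IntBacked and the helper functions are made-up names for illustration, and __flags__ is a CPython implementation detail:

```python
# Stand-alone analogy (no TensorFlow needed) for the two CPython checks
# behind the errors above. PlainClass, OtherPlain and IntBacked are
# hypothetical stand-ins, not TensorFlow classes.

# Py_TPFLAGS_BASETYPE from CPython's Include/object.h ("type can be subclassed").
PY_TPFLAGS_BASETYPE = 1 << 10


class PlainClass:
    """Stand-in for a pure-Python subclass such as MyTFTensor."""


class OtherPlain:
    """A second pure-Python class with the same instance layout as PlainClass."""


class IntBacked(int):
    """Stand-in for an object carrying extra C-level fields (like EagerTensor)."""


def try_relabel(obj, new_cls):
    """Attempt obj.__class__ = new_cls; return the TypeError message, or None on success."""
    try:
        obj.__class__ = new_cls
    except TypeError as exc:
        return str(exc)
    return None


def is_subclassable(t: type) -> bool:
    """True if CPython lets t be used as a base class (checks the type's flags)."""
    return bool(t.__flags__ & PY_TPFLAGS_BASETYPE)


def try_subclass(base: type):
    """Try to create a subclass of base at runtime; return the TypeError message, or None."""
    try:
        type("Sub", (base,), {})
    except TypeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    # Check 1 (option 1): the layouts differ, so __class__ assignment is refused.
    print(try_relabel(IntBacked(7), PlainClass))
    # Same layout, so this assignment goes through and None is printed.
    print(try_relabel(PlainClass(), OtherPlain))
    # Check 2 (option 2): int allows subclassing, bool does not.
    print(is_subclassable(int), try_subclass(int))
    print(is_subclassable(bool), try_subclass(bool))
```

I'd expect is_subclassable(EagerTensor) to return False in the same environment, which would match the second traceback above.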
Does anyone have a solution for this?