Hello everyone,
As you may already know, NVIDIA recently published (20 June 2022) its paper, GCViT: Global Context Vision Transformer, which outperforms ConvNeXt and Swin Transformer.
I’ve implemented this model in TensorFlow 2 and released it as an open-source library, gcvit-tf. I’ve also made a notebook explaining it. I hope this helps the community.
I’m also planning to publish it to TF Hub; it would be really helpful if I could get some directions on that.
Here are links to the project:
- Github: gcvit-tf
- Paper Explanation: GUIE: Global Context ViT (GCViT) in Kaggle
- Live-Demo: Gradio Demo in HuggingFace Space.
Features of gcvit-tf:
- Loads ImageNet pretrained weights from the official repo.
- Provides `timm`-like features such as `forward_features`, `forward_head`, and `reset_classifier`, which might come in handy.
- Works on both GPU and TPU.
Supported Models
The official codebase had some issues, which were fixed recently (27 July 2022). Here are the results of the ported weights on the ImageNetV2 test set:
| Model | Acc@1 | Acc@5 | #Params |
|---|---|---|---|
| GCViT-XXTiny | 63 | 85 | 12M |
| GCViT-XTiny | 66 | 87 | 20M |
| GCViT-Tiny | 69 | 89 | 28M |
| GCViT-Small | 69 | 89 | 51M |
| GCViT-Base | 71 | 90 | 90M |
Usage
Install Library
```shell
pip install gcvit
```
Load the model with:
```python
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=True)
```
A simple check of the model’s prediction:
```python
import tensorflow as tf
from skimage.data import chelsea  # sample image of a cat

img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch')
img = tf.image.resize(img, (224, 224))[None,]  # resize & add batch dim
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])
```
Prediction:
```
[('n02124075', 'Egyptian_cat', 0.9194835),
 ('n02123045', 'tabby', 0.009686623),
 ('n02123159', 'tiger_cat', 0.0061576385),
 ('n02127052', 'lynx', 0.0011503297),
 ('n02883205', 'bow_tie', 0.00042479983)]
```
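As an aside, `mode='torch'` in the preprocessing step scales pixel values to [0, 1] and then normalizes each channel with the ImageNet mean and standard deviation, matching how the original PyTorch weights were trained. A minimal check on a dummy all-white image (no gcvit dependency, just Keras utilities):

```python
import numpy as np
import tensorflow as tf

# A dummy white image (all 255s) to illustrate what mode='torch' does:
# pixels are scaled to [0, 1], then normalized with ImageNet mean/std.
img = np.full((1, 224, 224, 3), 255.0, dtype=np.float32)
out = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')

# For a white pixel, the red channel becomes (1.0 - 0.485) / 0.229.
expected_r = (1.0 - 0.485) / 0.229
print(np.allclose(out[0, 0, 0, 0], expected_r, atol=1e-4))  # prints True
```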
For feature extraction:
```python
model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)  # drop the classification head
feature = model(img)
print(feature.shape)
```
Feature:
```
(None, 512)
```
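These 512-d feature vectors are handy for retrieval or clustering tasks. A sketch of comparing two embeddings with cosine similarity, using dummy vectors in place of real model outputs:

```python
import numpy as np

# Dummy 512-d embeddings standing in for the model's pooled features;
# in practice you would use feat = model(img).numpy() after reset_classifier.
rng = np.random.default_rng(0)
a = rng.normal(size=512)
b = rng.normal(size=512)

def cosine_similarity(u, v):
    # Cosine similarity: dot product of the two L2-normalized vectors.
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

print(cosine_similarity(a, a))  # 1.0 for identical vectors
print(cosine_similarity(a, b))  # some value in [-1, 1]
```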
For feature map:
```python
model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)  # spatial features before pooling
print(feature.shape)
```
Feature map:
```
(None, 7, 7, 512)
```
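If you need a single vector from the spatial feature map (e.g. for attention-free downstream heads), one common option is global average pooling over the spatial dimensions. A minimal sketch on a dummy tensor of the same shape (this is an illustration, not necessarily what the library's `forward_head` does internally):

```python
import tensorflow as tf

# Dummy (1, 7, 7, 512) tensor standing in for model.forward_features(img).
fmap = tf.random.normal((1, 7, 7, 512))

# Global average pooling collapses the 7x7 spatial grid into a single
# 512-d vector per image.
pooled = tf.keras.layers.GlobalAveragePooling2D()(fmap)
print(pooled.shape)  # (1, 512)
```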
Note: The official repo had some issues that caused a performance drop. It was updated on 27 July 2022, but one issue still persists, so the ImageNet weights may be updated in the future.