-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathvit_transformers_example.py
More file actions
64 lines (51 loc) · 1.88 KB
/
vit_transformers_example.py
File metadata and controls
64 lines (51 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import numpy as np
from torch.nn.parameter import Parameter
from eet.transformers.modeling_vit import EETViTModel
from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
import requests
import time
# Benchmark configuration, read by main().
using_half: bool = True  # run both models in float16 instead of float32
batch_size: int = 20     # images per forward pass; also passed to EET as max_batch
loop: int = 100          # forward passes per warm-up / timed run
def _timed_run(run_once, iterations):
    """Call ``run_once()`` ``iterations`` times and return elapsed seconds.

    CUDA kernels are launched asynchronously, so the device is synchronized
    before starting and before stopping the clock. Without this the timer
    mostly measures launch overhead, and work still queued from a previous
    run would be billed to the next one.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        run_once()
    torch.cuda.synchronize()
    return time.perf_counter() - start


def main():
    """Benchmark EET's ViT against Hugging Face ``ViTModel`` on random input.

    Both models are loaded from ``google/vit-base-patch16-224-in21k`` (in
    float16 when ``using_half`` is set, otherwise float32). Each model is
    warmed up with ``loop`` forward passes and then timed over another
    ``loop`` passes on a random ``(batch_size, 3, 224, 224)`` tensor. The
    script prints both wall-clock times and the speed-up of EET over
    Transformers.
    """
    # Inference only: disable autograd globally for the whole benchmark.
    torch.set_grad_enabled(False)

    # Optional: benchmark a real image instead of random noise.
    # url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    # image = Image.open(requests.get(url, stream=True).raw)
    # feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
    # inputs = feature_extractor(images=image, return_tensors="pt")

    # Build both models from the same pretrained checkpoint and precision.
    data_type = torch.float32
    ts_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').cuda()
    if using_half:
        ts_model = ts_model.half()
        data_type = torch.float16
    eet_model = EETViTModel.from_pretrained(
        'google/vit-base-patch16-224-in21k', max_batch=batch_size, data_type=data_type)

    # Random pixel input, cast to the model precision and moved to the GPU.
    input_states = torch.from_numpy(
        np.random.random((batch_size, 3, 224, 224))).to(data_type).cuda()
    attention_mask = None

    def run_ts():
        # ViTModel.forward has no attention mask; pass the pixels by name
        # rather than feeding None positionally into its second parameter.
        return ts_model(pixel_values=input_states)

    def run_eet():
        return eet_model(input_states, attention_mask=attention_mask)

    # The first inference is slow (one-off setup), so warm up BOTH models.
    # Previously only Transformers was warmed up, which charged EET's
    # setup cost to its timed run.
    _timed_run(run_ts, loop)
    _timed_run(run_eet, loop)

    time_ts = _timed_run(run_ts, loop)
    time_eet = _timed_run(run_eet, loop)

    print('Time for EET: ', time_eet)
    print('Time for Transformers: ', time_ts)
    print('SpeedUp is ', time_ts / time_eet)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()