-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathdistilbert_transformers_example.py
More file actions
51 lines (43 loc) · 1.54 KB
/
distilbert_transformers_example.py
File metadata and controls
51 lines (43 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import numpy as np
from eet import EETDistilBertModel
from transformers import DistilBertModel
import time
# Benchmark configuration read by main().
# When True, the Transformers model is cast to half precision and EET is
# given torch.float16 as its data type; otherwise both use torch.float32.
using_half = True
# Number of token ids per sequence in the randomly generated input batch.
seq_len = 32
# Number of sequences in the input batch; also passed to EET as max_batch.
batch = 4
# Number of forward passes timed for each model.
loop = 100
# Untimed forward passes run before each measurement so that one-time
# start-up work is not charged to whichever model is measured first.
_WARMUP_ITERATIONS = 10


def _time_inference(model, input_ids, attention_mask, iterations):
    """Return the seconds taken by `iterations` forward passes of `model`.

    A few untimed warm-up passes run first. The GPU is synchronized right
    before the clock starts and again before it stops, so that previously
    queued work is excluded and the timed passes have actually finished.
    """
    for _ in range(_WARMUP_ITERATIONS):
        model(input_ids, attention_mask=attention_mask)

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        model(input_ids, attention_mask=attention_mask)
    torch.cuda.synchronize()
    return time.perf_counter() - start


def main():
    """Benchmark EET's DistilBERT against the Transformers implementation.

    Builds one random batch of token ids, loads the same checkpoint into both
    models, times `loop` forward passes of each on the GPU, and prints both
    timings together with the resulting speed-up ratio.
    """
    # Inference only: gradient tracking is disabled for the whole process,
    # which is why no separate torch.no_grad() context is needed below.
    torch.set_grad_enabled(False)

    # Construct the input; in an actual project the input should be tokens
    # produced by a tokenizer rather than random ids.
    token_ids = np.random.randint(1000, 9000, seq_len * batch, dtype="int64")
    input_ids = torch.from_numpy(token_ids).long().reshape(batch, seq_len).cuda()

    # Load the models. EET needs the maximum batch size and the data type
    # to be passed in when the model is created.
    checkpoint = 'distilbert-base-cased-distilled-squad'
    data_type = torch.float32
    ts_model = DistilBertModel.from_pretrained(checkpoint).cuda()
    if using_half:
        ts_model = ts_model.half()
        data_type = torch.float16
    eet_model = EETDistilBertModel.from_pretrained(
        checkpoint, max_batch=batch, data_type=data_type
    )

    # Inference: both models receive the same inputs and no attention mask.
    attention_mask = None
    time_eet = _time_inference(eet_model, input_ids, attention_mask, loop)
    time_ts = _time_inference(ts_model, input_ids, attention_mask, loop)

    print('Time for EET : ', time_eet)
    print('Time for Transformers: ', time_ts)
    print('SpeedUp is ', time_ts / time_eet)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()