-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathgpt2_transformers_example.py
More file actions
71 lines (61 loc) · 2.51 KB
/
gpt2_transformers_example.py
File metadata and controls
71 lines (61 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from asyncore import loop
import torch
import numpy as np
from eet import EETGPT2Model
from transformers import GPT2Model
import time
# Benchmark configuration (module-level constants read by main()).
using_half = True        # run both models in fp16 on GPU; False keeps fp32
prompt_seq_len = 512     # length of the prompt (context) fed on the first pass
batch = 5                # batch size for both the prompt and incremental steps
max_seq_len = 1024       # total sequence budget; generation runs max_seq_len - prompt_seq_len steps
loop = 10                # number of timed benchmark repetitions per model
                         # NOTE(review): this shadows `from asyncore import loop` above —
                         # that import is unused and asyncore is removed in Python 3.12.
def main():
    """Benchmark EET's GPT-2 implementation against HuggingFace's GPT2Model.

    Runs `loop` repetitions of a generation-style workload (one prompt pass of
    length ``prompt_seq_len`` followed by ``max_seq_len - prompt_seq_len``
    single-token incremental passes) for each backend, then prints the wall
    times and the speedup ratio. Requires a CUDA device; reads the module-level
    constants ``using_half``, ``prompt_seq_len``, ``batch``, ``max_seq_len``
    and ``loop``.
    """
    # Fabricate random token ids; a real workload would feed tokenizer output.
    # (Renamed from `input`/`inputs` — `input` shadows the builtin.)
    prompt_tokens = np.random.randint(1000, 9000, prompt_seq_len * batch, dtype="int64")
    step_tokens = np.random.randint(1000, 9000, 1 * batch, dtype="int64")
    # prompt context vs. single-token incremental input
    input_full_decoder = torch.from_numpy(prompt_tokens).long().reshape(batch, prompt_seq_len).cuda()
    input_inc_decoder = torch.from_numpy(step_tokens).long().reshape(batch, 1).cuda()

    data_type = torch.float16 if using_half else torch.float32

    # Load both models; EET pre-allocates buffers for max_batch / full_seq_len.
    eet_model = EETGPT2Model.from_pretrained('gpt2', max_batch=batch, full_seq_len=max_seq_len, data_type=data_type)
    torch_model = GPT2Model.from_pretrained('gpt2').cuda()
    if using_half:
        torch_model = torch_model.half()

    attention_mask = None

    # ---- EET timing ----
    # synchronize() before each timestamp so perf_counter measures completed
    # GPU work, not just kernel-launch latency.
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    # first_pass tells EET whether this call is the prompt (context) inference
    # (True) or an incremental generation step (False). EET keeps past
    # key/values internally and does not return them, so the phase cannot be
    # inferred from a past_key_values argument and must be signalled explicitly.
    for _ in range(loop):
        input_ids = input_full_decoder
        first_pass = True
        for _ in range(max_seq_len - prompt_seq_len):
            res_eet = eet_model(input_ids, first_pass=first_pass, attention_mask=attention_mask)
            if first_pass:
                # After the prompt pass, switch to single-token inputs.
                first_pass = False
                input_ids = input_inc_decoder
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print('Time for EET : ', t2 - t1)

    # ---- HuggingFace timing ----
    torch.cuda.synchronize()
    t3 = time.perf_counter()
    for _ in range(loop):
        input_ids = input_full_decoder
        past_key_values = None  # HF carries the KV cache explicitly between calls
        for _ in range(max_seq_len - prompt_seq_len):
            with torch.no_grad():
                res_torch = torch_model(input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
            past_key_values = res_torch.past_key_values
            input_ids = input_inc_decoder
    torch.cuda.synchronize()
    t4 = time.perf_counter()
    print('Time for torch : ', t4 - t3)
    print('SpeedUp is : ', (t4 - t3)/(t2- t1))
if __name__ == '__main__':
    main()