-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassification.py
More file actions
127 lines (96 loc) · 3.63 KB
/
classification.py
File metadata and controls
127 lines (96 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import re
import nltk
from chat_tools.patterns import patterns as pt
patterns = pt
from variables import suicide_classes
from transformers import pipeline
# Shared zero-shot classification pipeline using the 'facebook/bart-large-mnli'
# checkpoint. It is constructed once, at import time, and reused by zsc,
# general_zsc and cause_zsc below.
# NOTE(review): building the pipeline at import presumably loads the model
# weights, making this import slow — consider lazy initialisation if needed.
classifier = pipeline(task='zero-shot-classification', model='facebook/bart-large-mnli')
def classify(text):
    """
    Match the text against the regex rules held in ``patterns``.

    Rules are tried in the mapping's iteration order, and matching is
    case-insensitive. The first rule whose pattern occurs anywhere in the
    text decides the result; later rules are not evaluated.

    Parameters:
    - text: The input text to classify.
    Returns:
    - category: The category mapped to the first matching pattern, or
      'unknown' when no pattern matches.
    """
    # Lazily evaluated, so re.search stops running at the first hit.
    matched = (
        label
        for regex, label in patterns.items()
        if re.search(regex, text, re.IGNORECASE)
    )
    return next(matched, 'unknown')
# Only messages shorter than this many characters get the gerund-based
# crisis check; longer ones go straight to regular classification.
_CRISIS_CHECK_MAX_CHARS = 150


def _needs_immediate_help(user_input):
    """
    Heuristic crisis check shared by ``zsc`` and ``general_zsc``.

    The check runs only on messages shorter than ``_CRISIS_CHECK_MAX_CHARS``.
    Such a message is tokenized and POS-tagged with NLTK. It is flagged when
    at least half of its tokens carry the ``VBG`` tag (gerund / present
    participle) AND the zero-shot model scores the whole message above 0.5
    for the label "death".

    Args:
        user_input (str): User input message.
    Returns:
        bool: True when the message should be answered with 'immediate-help'.
    """
    if len(user_input) >= _CRISIS_CHECK_MAX_CHARS:
        return False
    tokens = nltk.word_tokenize(user_input)
    # Without this guard an empty message would satisfy 0 >= 0 / 2 and be
    # escalated as a crisis.
    if not tokens:
        return False
    gerund_count = sum(1 for _, tag in nltk.pos_tag(tokens) if tag == 'VBG')
    # Compare a token count with a token count. Using the character length
    # here would make the threshold practically unreachable.
    if gerund_count < len(tokens) / 2:
        return False
    # The "death" score depends only on the whole message, so the model is
    # queried once instead of once per gerund token.
    death_score = classifier(user_input, candidate_labels=["death"])['scores'][0]
    return death_score > 0.5


def zsc(user_input):
    """
    Performs zero-shot classification on the user input to detect the potential topics it might belong to.

    Parameters:
    - user_input: The input text to classify.
    Returns:
    - classification: 'immediate-help' when the crisis heuristic fires;
      otherwise the zero-shot pipeline result for the labels
      suicide / death / self-harm / anger / sad / guilt / fear / happy,
      scored independently (multi_label=True).
    """
    if _needs_immediate_help(user_input):
        return 'immediate-help'
    return classifier(user_input,
        candidate_labels=["suicide", "death", "self-harm", "anger", "sad", "guilt", "fear", "happy"],
        multi_label=True
    )


def general_zsc(user_input, cand_labels):
    """
    General purpose zero-shot classification.

    Args:
        user_input (str): User input message.
        cand_labels (list): List of candidate labels for zero-shot classification.
    Returns:
        float | str: The first entry of the pipeline's ``scores``, or
        'immediate-help' when the crisis heuristic fires. The pipeline
        orders labels by descending score, so the float is the best label's
        score.
    """
    if _needs_immediate_help(user_input):
        return 'immediate-help'
    return classifier(user_input,
        candidate_labels=cand_labels,
        multi_label=True
    )['scores'][0]
# Candidate cause labels for each feeling with a dedicated label set.
# Looked up before the suicide_classes fallback, matching the original
# if/elif order.
_CAUSE_LABELS = {
    'sad': ('death', 'life-difficulties', 'unknown'),
    'anger': ('friends', 'family', 'social'),
    'guilt': ('unmoral', 'fail'),
    'fear': ('illegal', 'evil-power'),
}
# Cause labels used for any feeling listed in ``suicide_classes``.
_SUICIDE_CAUSE_LABELS = ('hopeless', 'overwhelmed')


def cause_zsc(user_input, feeling):
    """
    Performs zero-shot classification on the user input to detect the potential causes related to a specific feeling.

    Parameters:
    - user_input: The input text to classify.
    - feeling: The feeling associated with the user input ('sad', 'anger',
      'guilt', 'fear', or a member of ``suicide_classes``).
    Returns:
    - classification: The zero-shot pipeline result over the cause labels
      for that feeling (multi_label=True).
    Raises:
    - ValueError: if no cause labels are defined for ``feeling``. Without
      labels the pipeline cannot run, so fail early with a clear message.
    """
    cand_labels = _CAUSE_LABELS.get(feeling)
    if cand_labels is None and feeling in suicide_classes:
        cand_labels = _SUICIDE_CAUSE_LABELS
    if cand_labels is None:
        raise ValueError(f"cause_zsc: no cause labels defined for feeling {feeling!r}")
    # Pass a fresh list so the module-level label sets can never be mutated.
    return classifier(user_input, candidate_labels=list(cand_labels), multi_label=True)
def get_class(msg):
    """
    Classify the user input using zero-shot classification (if not classified using regex).

    The regex rules (``classify``) are tried first because they are cheap. Only
    when they yield 'unknown' does the message go to ``zsc``.

    Args:
        msg (str): User input message.
    Returns:
        str | object: The regex category when a pattern matches; otherwise
        whatever ``zsc`` returns ('immediate-help' or the zero-shot pipeline
        result).
    """
    category = classify(msg)
    if category != 'unknown':
        return category
    return zsc(msg)