Peter committed
commit 0b3d061 · 1 Parent(s): 235585a

🐛 fix input len bug

Signed-off-by: Peter <74869040+pszemraj@users.noreply.github.com>
- app.py +8 -5
- converse.py +12 -8
- grammar_improve.py +5 -3
app.py CHANGED

@@ -101,11 +101,13 @@ def ask_gpt(
     st = time.perf_counter()
     prompt = clean(message)  # clean user input
     prompt = prompt.strip()  # get rid of any extra whitespace
-    in_len = len(prompt)
+    in_len = len(chat_pipe.tokenizer(prompt).input_ids)
     if in_len > 512:
-
-
-
+        # truncate to last 512 tokens
+        tokens = chat_pipe.tokenizer(prompt).input_ids
+        trunc_tokens = tokens[-512:]
+        prompt = chat_pipe.tokenizer.decode(trunc_tokens)
+        print(f"truncated prompt to {len(trunc_tokens)} tokens, input length: {in_len}")

     resp = discussion(
         prompt_text=prompt,
@@ -115,7 +117,8 @@ def ask_gpt(
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
-        max_length=
+        max_length=max_length,
+        min_length=min_length,
     )
     gpt_et = time.perf_counter()
     gpt_rt = round(gpt_et - st, 2)
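The fix above swaps a character count (len(prompt)) for a token count, since the 512 limit is a token budget and one token often spans several characters. A minimal standalone sketch of the same truncation pattern, assuming a Hugging Face transformers text-generation pipeline; the gpt2 checkpoint and the truncate_prompt helper are placeholders, the Space builds its own pipeline as chat_pipe:

from transformers import pipeline

# placeholder checkpoint; the Space loads its own conversational model
chat_pipe = pipeline("text-generation", model="gpt2")

def truncate_prompt(prompt: str, max_tokens: int = 512) -> str:
    """Keep only the most recent max_tokens tokens of the prompt."""
    tokens = chat_pipe.tokenizer(prompt).input_ids
    if len(tokens) <= max_tokens:
        return prompt
    trunc_tokens = tokens[-max_tokens:]  # keep the tail: the newest context
    print(f"truncated prompt to {len(trunc_tokens)} tokens, input length: {len(tokens)}")
    return chat_pipe.tokenizer.decode(trunc_tokens)

Keeping the tail (tokens[-512:]) rather than the head preserves the most recent turns of the conversation, which matter most for the next reply.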
converse.py CHANGED

@@ -17,7 +17,8 @@ def discussion(
     responder: str,
     pipeline,
     timeout=45,
-
+    min_length=4,
+    max_length=64,
     top_p=0.95,
     top_k=50,
     temperature=0.7,
@@ -104,7 +105,8 @@ def gen_response(
     speaker: str,
     responder: str,
     timeout=45,
-
+    min_length=4,
+    max_length=64,
     top_p=0.95,
     top_k=50,
     temperature=0.7,
@@ -125,7 +127,8 @@ def gen_response(
     responder : str, the name of the person who is responding to the prompt
     pipeline : transformers.Pipeline, the pipeline to use for generating the response
     timeout : int, optional, the number of seconds to wait before timing out, by default 45
-
+    min_length : int, optional, the minimum number of tokens to generate, defaults to 4
+    max_length : int, optional, the maximum number of tokens to generate, defaults to 64
     top_p : float, optional, the top probability to use for sampling, defaults to 0.95
     top_k : int, optional, the top k to use for sampling, defaults to 50
     temperature : float, optional, the temperature to use for sampling, defaults to 0.7
@@ -139,15 +142,16 @@ def gen_response(
     str, the generated text

     """
-
-    if max_length > 1024:
-        max_length = 1024
-        print("max_length
+    input_len = len(pipeline.tokenizer(query).input_ids)
+    if max_length + input_len > 1024:
+        max_length = max(1024 - input_len, 8)
+        print(f"max_length too large, setting to {max_length}")
     st = time.perf_counter()

     response = pipeline(
         query,
-
+        min_length=min_length + input_len,
+        max_length=max_length + input_len,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
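The gen_response change does two related things: it clamps max_length so prompt plus generation fits a 1024-token context, and it offsets both length arguments by the prompt length, because the transformers text-generation pipeline counts min_length/max_length over the full sequence, prompt included. A sketch of that budgeting, with the 1024 context and the floor of 8 new tokens taken from the diff (the generation_budget helper name is mine):

def generation_budget(pipe, query: str, min_length: int = 4,
                      max_length: int = 64, context_size: int = 1024) -> dict:
    """Compute min_length/max_length kwargs for a text-generation pipeline.

    min_length and max_length are desired counts of *new* tokens; the
    pipeline measures them over prompt + generation, so both are offset
    by the prompt's token length.
    """
    input_len = len(pipe.tokenizer(query).input_ids)
    if max_length + input_len > context_size:
        # shrink the generation budget, but always allow a few new tokens
        max_length = max(context_size - input_len, 8)
        print(f"max_length too large, setting to {max_length}")
    return {
        "min_length": min_length + input_len,
        "max_length": max_length + input_len,
    }

# usage inside gen_response (sketch):
# response = pipeline(query, **generation_budget(pipeline, query), ...)

The max(..., 8) floor mirrors the diff: for prompts longer than context_size - 8 tokens the total can still exceed the context, but the 512-token truncation in app.py keeps prompts well under that in practice.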
grammar_improve.py CHANGED

@@ -137,10 +137,11 @@ def synthesize_grammar(
     """
     st = time.perf_counter()
     input_text = clean(message, lower=False)
+    input_len = len(corrector.tokenizer(input_text).input_ids)
     results = corrector(
         input_text,
-        max_length=int(1.1 *
-        min_length=2 if
+        max_length=int(1.1 * input_len),
+        min_length=2 if input_len < 64 else int(0.2 * input_len),
         num_beams=num_beams,
         repetition_penalty=repetition_penalty,
         length_penalty=length_penalty,
@@ -479,7 +480,8 @@ def correct_grammar(
     """
     st = time.perf_counter()

-    if len(input_text) <
+    if len(tokenizer(input_text).input_ids) < 4:
+        print(f"input text of {input_text} is too short to be corrected")
         return input_text
     max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
     batch = tokenizer(
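Both grammar functions now size their length limits from the input rather than from constants: a corrected sentence should be roughly as long as the original, so max_length tracks the input token count, and a short-input guard skips correction entirely. A combined sketch of those heuristics, assuming a transformers text2text-generation pipeline; t5-small is a placeholder for the Space's actual corrector checkpoint, and the correct wrapper is mine:

from transformers import pipeline

# placeholder checkpoint; the Space loads its own grammar-correction model
corrector = pipeline("text2text-generation", model="t5-small")

def correct(input_text: str) -> str:
    input_len = len(corrector.tokenizer(input_text).input_ids)
    if input_len < 4:
        # guard from correct_grammar: too short to correct meaningfully
        print(f"input text of {input_text} is too short to be corrected")
        return input_text
    results = corrector(
        input_text,
        max_length=int(1.1 * input_len),  # output stays close to input length
        min_length=2 if input_len < 64 else int(0.2 * input_len),
    )
    return results[0]["generated_text"]

Scaling min_length only for inputs of 64+ tokens keeps short sentences from being padded, while the 1.1× ceiling stops beam search from rambling past the original text.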