Spaces:
Runtime error
Runtime error
🎨 format
Browse filesSigned-off-by: peter szemraj <peterszemraj@gmail.com>
- converse.py +12 -11
converse.py
CHANGED
|
@@ -5,7 +5,10 @@
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import logging
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
import pprint as pp
|
| 10 |
import time
|
| 11 |
|
|
@@ -13,6 +16,7 @@ from grammar_improve import remove_trailing_punctuation
|
|
| 13 |
|
| 14 |
from constrained_generation import constrained_generation
|
| 15 |
|
|
|
|
| 16 |
def discussion(
|
| 17 |
prompt_text: str,
|
| 18 |
speaker: str,
|
|
@@ -65,10 +69,9 @@ def discussion(
|
|
| 65 |
if verbose:
|
| 66 |
print("overall prompt:\n")
|
| 67 |
pp.pprint(this_prompt, indent=4)
|
| 68 |
-
|
| 69 |
-
print("\n... generating...")
|
| 70 |
if constrained_beam_search:
|
| 71 |
-
logging.info("using constrained beam search")
|
| 72 |
response = constrained_generation(
|
| 73 |
prompt=this_prompt,
|
| 74 |
pipeline=pipeline,
|
|
@@ -85,15 +88,13 @@ def discussion(
|
|
| 85 |
|
| 86 |
bot_dialogue = consolidate_texts(
|
| 87 |
name_resp=responder,
|
| 88 |
-
model_resp=response.split(
|
| 89 |
-
"\n"
|
| 90 |
-
),
|
| 91 |
name_spk=speaker,
|
| 92 |
verbose=verbose,
|
| 93 |
print_debug=True,
|
| 94 |
)
|
| 95 |
else:
|
| 96 |
-
logging.info("using sampling")
|
| 97 |
bot_dialogue = gen_response(
|
| 98 |
this_prompt,
|
| 99 |
pipeline,
|
|
@@ -140,15 +141,15 @@ def gen_response(
|
|
| 140 |
speaker: str,
|
| 141 |
responder: str,
|
| 142 |
timeout=45,
|
| 143 |
-
min_length=
|
| 144 |
max_length=48,
|
| 145 |
top_p=0.95,
|
| 146 |
top_k=20,
|
| 147 |
temperature=0.5,
|
| 148 |
full_text=False,
|
| 149 |
num_return_sequences=1,
|
| 150 |
-
length_penalty:float=0.8,
|
| 151 |
-
repetition_penalty:float=3.5,
|
| 152 |
no_repeat_ngram_size=2,
|
| 153 |
device=-1,
|
| 154 |
verbose=False,
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import logging
|
| 8 |
+
|
| 9 |
+
logging.basicConfig(
|
| 10 |
+
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 11 |
+
)
|
| 12 |
import pprint as pp
|
| 13 |
import time
|
| 14 |
|
|
|
|
| 16 |
|
| 17 |
from constrained_generation import constrained_generation
|
| 18 |
|
| 19 |
+
|
| 20 |
def discussion(
|
| 21 |
prompt_text: str,
|
| 22 |
speaker: str,
|
|
|
|
| 69 |
if verbose:
|
| 70 |
print("overall prompt:\n")
|
| 71 |
pp.pprint(this_prompt, indent=4)
|
| 72 |
+
|
|
|
|
| 73 |
if constrained_beam_search:
|
| 74 |
+
logging.info("generating using constrained beam search ...")
|
| 75 |
response = constrained_generation(
|
| 76 |
prompt=this_prompt,
|
| 77 |
pipeline=pipeline,
|
|
|
|
| 88 |
|
| 89 |
bot_dialogue = consolidate_texts(
|
| 90 |
name_resp=responder,
|
| 91 |
+
model_resp=response.split("\n"),
|
|
|
|
|
|
|
| 92 |
name_spk=speaker,
|
| 93 |
verbose=verbose,
|
| 94 |
print_debug=True,
|
| 95 |
)
|
| 96 |
else:
|
| 97 |
+
logging.info("generating using sampling ...")
|
| 98 |
bot_dialogue = gen_response(
|
| 99 |
this_prompt,
|
| 100 |
pipeline,
|
|
|
|
| 141 |
speaker: str,
|
| 142 |
responder: str,
|
| 143 |
timeout=45,
|
| 144 |
+
min_length=12,
|
| 145 |
max_length=48,
|
| 146 |
top_p=0.95,
|
| 147 |
top_k=20,
|
| 148 |
temperature=0.5,
|
| 149 |
full_text=False,
|
| 150 |
num_return_sequences=1,
|
| 151 |
+
length_penalty: float = 0.8,
|
| 152 |
+
repetition_penalty: float = 3.5,
|
| 153 |
no_repeat_ngram_size=2,
|
| 154 |
device=-1,
|
| 155 |
verbose=False,
|