Spaces:
Runtime error
Runtime error
File size: 3,198 Bytes
6600352 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import spacy
# download the En model firstly by:python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")
# This is only a test method for color prompt, like "a red car and a blue bench"
def split_prompt(prompt):
# Step 1: parse
doc = nlp(prompt)
chunks = list(doc.noun_chunks)
# If only one chunk and there's "and", try splitting on " and "
if len(chunks) < 2 and " and " in prompt:
sub_prompts = [seg.strip() for seg in prompt.split(" and ")]
new_chunks = []
for sub in sub_prompts:
sub_doc = nlp(sub)
sub_noun_chunks = list(sub_doc.noun_chunks)
if not sub_noun_chunks:
new_chunks.append(sub_doc[:])
else:
new_chunks.extend(sub_noun_chunks)
chunks = new_chunks
if len(chunks) < 2:
sps = [prompt]
nps = [prompt]
return (sps, nps)
# Step 2: split on "of" within each chunk
final_chunks = []
for chunk in chunks:
chunk_tokens = list(chunk)
of_prep_token = None
for token in chunk_tokens:
if token.lemma_ == "of" and token.dep_ == "prep":
of_prep_token = token
break
if of_prep_token:
of_subtree = list(of_prep_token.subtree)
chunk1_tokens = [t for t in chunk_tokens if t not in of_subtree]
pobj_tokens = [t for t in of_prep_token.children if t.dep_ == "pobj"]
if pobj_tokens:
pobj_root = pobj_tokens[0]
chunk2_tokens = list(pobj_root.subtree)
else:
chunk2_tokens = []
if chunk1_tokens and chunk2_tokens:
doc_obj = chunk1_tokens[0].doc
chunk1_span = doc_obj[chunk1_tokens[0].i : chunk1_tokens[-1].i + 1]
chunk2_span = doc_obj[chunk2_tokens[0].i : chunk2_tokens[-1].i + 1]
final_chunks.append(chunk1_span)
final_chunks.append(chunk2_span)
continue
else:
final_chunks.append(chunk)
else:
final_chunks.append(chunk)
if len(final_chunks) < 2:
sps = [prompt]
nps = [prompt]
return (sps, nps)
def process_chunk(ch):
full = ch.text
det = ""
for token in ch:
if token.dep_ == "det":
det = token.text
break
head = ch.root.text
stripped = (det + " " if det else "") + head
return (full, stripped)
processed = [process_chunk(ch) for ch in final_chunks]
variants = []
for i in range(len(processed)):
parts = []
for j, (full, stripped) in enumerate(processed):
if j == i:
parts.append(full)
else:
parts.append(stripped)
variants.append(" and ".join(parts))
# simple_version = " and ".join(ch.root.text for ch in final_chunks)
simple_version = " and ".join(stripped for full, stripped in processed)
subjects = [ch.root.text for ch in final_chunks]
sps = [simple_version]
sps.extend(variants)
nps = subjects
return (sps, nps)
|