import spacy

# Download the English model first: python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")

# Test helper for color prompts such as "a red car and a blue bench".
def split_prompt(prompt):
    """Return (sps, nps): sps = [simplified prompt] + one variant per noun
    chunk keeping only that chunk's modifiers; nps = each chunk's head noun.
    """
    # Step 1: parse the prompt and collect its base noun chunks.
    doc = nlp(prompt)
    chunks = list(doc.noun_chunks)

    # Fewer than two chunks but the prompt contains "and": split manually and re-parse each segment
    if len(chunks) < 2 and " and " in prompt:
        sub_prompts = [seg.strip() for seg in prompt.split(" and ")]
        new_chunks = []
        for sub in sub_prompts:
            sub_doc = nlp(sub)
            sub_noun_chunks = list(sub_doc.noun_chunks)
            if not sub_noun_chunks:
                # No noun chunk found; fall back to the whole segment span.
                new_chunks.append(sub_doc[:])
            else:
                new_chunks.extend(sub_noun_chunks)
        chunks = new_chunks

    if len(chunks) < 2:
        # Still fewer than two objects: return the prompt unchanged.
        sps = [prompt]
        nps = [prompt]
        return (sps, nps)

    # Step 2: split any chunk that contains an "of" prepositional phrase,
    # e.g. "a bowl of apples" -> "a bowl" + "apples".
    final_chunks = []
    for chunk in chunks:
        chunk_tokens = list(chunk)
        of_prep_token = None
        for token in chunk_tokens:
            if token.lemma_ == "of" and token.dep_ == "prep":
                of_prep_token = token
                break

        if of_prep_token:
            # Tokens outside the "of" subtree form the first sub-chunk;
            # the subtree of "of"'s object (pobj) forms the second.
            of_subtree = list(of_prep_token.subtree)
            chunk1_tokens = [t for t in chunk_tokens if t not in of_subtree]
            pobj_tokens = [t for t in of_prep_token.children if t.dep_ == "pobj"]
            if pobj_tokens:
                pobj_root = pobj_tokens[0]
                chunk2_tokens = list(pobj_root.subtree)
            else:
                chunk2_tokens = []

            if chunk1_tokens and chunk2_tokens:
                doc_obj = chunk1_tokens[0].doc
                chunk1_span = doc_obj[chunk1_tokens[0].i : chunk1_tokens[-1].i + 1]
                chunk2_span = doc_obj[chunk2_tokens[0].i : chunk2_tokens[-1].i + 1]
                final_chunks.append(chunk1_span)
                final_chunks.append(chunk2_span)
                continue
            else:
                final_chunks.append(chunk)
        else:
            final_chunks.append(chunk)

    if len(final_chunks) < 2:
        sps = [prompt]
        nps = [prompt]
        return (sps, nps)

    def process_chunk(ch):
        # Return (full text, determiner + head noun),
        # e.g. "a red car" -> ("a red car", "a car").
        full = ch.text
        det = ""
        for token in ch:
            if token.dep_ == "det":
                det = token.text
                break
        head = ch.root.text
        stripped = (det + " " if det else "") + head
        return (full, stripped)

    processed = [process_chunk(ch) for ch in final_chunks]

    # One variant per chunk: chunk i keeps its modifiers, the others are
    # stripped to determiner + head noun.
    variants = []
    for i in range(len(processed)):
        parts = [full if j == i else stripped
                 for j, (full, stripped) in enumerate(processed)]
        variants.append(" and ".join(parts))

    # Simplified prompt: every chunk stripped to determiner + head noun.
    simple_version = " and ".join(stripped for full, stripped in processed)
    subjects = [ch.root.text for ch in final_chunks]

    sps = [simple_version]
    sps.extend(variants)
    nps = subjects

    return (sps, nps)
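

# Minimal usage sketch (assumes en_core_web_sm is installed; exact splits
# can vary with the spaCy model/version):
if __name__ == "__main__":
    sps, nps = split_prompt("a red car and a blue bench")
    print(sps)  # e.g. ['a car and a bench', 'a red car and a bench', 'a car and a blue bench']
    print(nps)  # e.g. ['car', 'bench']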