Cxxs commited on
Commit
e157e7f
·
1 Parent(s): 30c2304

add input check

Browse files
Files changed (2) hide show
  1. app.py +45 -32
  2. prompt_check.py +35 -0
app.py CHANGED
@@ -12,7 +12,9 @@ import spaces
12
  import torch
13
  from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
14
  from PIL import Image
15
- from transformers import AutoModel, AutoTokenizer
 
 
16
 
17
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
18
 
@@ -126,7 +128,7 @@ def load_models(model_path, enable_compile=False, attention_backend="native"):
126
  os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
127
  )
128
 
129
- text_encoder = AutoModel.from_pretrained(
130
  os.path.join(model_path, "text_encoder"),
131
  torch_dtype=torch.bfloat16,
132
  device_map="cuda",
@@ -402,42 +404,53 @@ def generate(
402
  - seed_str: String representation of the seed used for generation
403
  - seed_int: Integer representation of the seed used for generation
404
  """
405
- if pipe is None:
406
- raise gr.Error("Model not loaded.")
407
 
408
- final_prompt = prompt
 
409
 
410
- if enhance:
411
- final_prompt, _ = prompt_enhance(prompt, True)
412
- print(f"Enhanced prompt: {final_prompt}")
413
 
414
- if random_seed:
415
- new_seed = random.randint(1, 1000000)
416
- else:
417
- new_seed = seed if seed != -1 else random.randint(1, 1000000)
418
 
419
- try:
420
- resolution_str = resolution.split(" ")[0]
421
- except:
422
- resolution_str = "1024x1024"
423
-
424
- image = generate_image(
425
- pipe=pipe,
426
- prompt=final_prompt,
427
- resolution=resolution_str,
428
- seed=new_seed,
429
- guidance_scale=0.0,
430
- num_inference_steps=int(steps + 1),
431
- shift=shift,
432
- )
433
 
434
- safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
435
- _, has_nsfw_concept = pipe.safety_checker(
436
- images=[torch.zeros(1)], clip_input=safety_checker_input
437
- )
438
- has_nsfw_concept = has_nsfw_concept[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
- if has_nsfw_concept:
441
  image = Image.open("nsfw.png")
442
 
443
  if gallery_images is None:
 
12
  import torch
13
  from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
14
  from PIL import Image
15
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
16
+
17
+ from prompt_check import is_unsafe_prompt
18
 
19
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
20
 
 
128
  os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
129
  )
130
 
131
+ text_encoder = AutoModelForCausalLM.from_pretrained(
132
  os.path.join(model_path, "text_encoder"),
133
  torch_dtype=torch.bfloat16,
134
  device_map="cuda",
 
404
  - seed_str: String representation of the seed used for generation
405
  - seed_int: Integer representation of the seed used for generation
406
  """
 
 
407
 
408
+ class UnsafeContentError(Exception):
409
+ pass
410
 
411
+ try:
412
+ if pipe is None:
413
+ raise gr.Error("Model not loaded.")
414
 
415
+ has_nsfw_concept = is_unsafe_prompt(pipe.text_encoder, pipe.tokenizer, prompt)
416
+ if has_nsfw_concept:
417
+ raise UnsafeContentError("input unsafe")
 
418
 
419
+ final_prompt = prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
+ if enhance:
422
+ final_prompt, _ = prompt_enhance(prompt, True)
423
+ print(f"Enhanced prompt: {final_prompt}")
424
+
425
+ if random_seed:
426
+ new_seed = random.randint(1, 1000000)
427
+ else:
428
+ new_seed = seed if seed != -1 else random.randint(1, 1000000)
429
+
430
+ try:
431
+ resolution_str = resolution.split(" ")[0]
432
+ except:
433
+ resolution_str = "1024x1024"
434
+
435
+ image = generate_image(
436
+ pipe=pipe,
437
+ prompt=final_prompt,
438
+ resolution=resolution_str,
439
+ seed=new_seed,
440
+ guidance_scale=0.0,
441
+ num_inference_steps=int(steps + 1),
442
+ shift=shift,
443
+ )
444
+
445
+ safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
446
+ _, has_nsfw_concept = pipe.safety_checker(
447
+ images=[torch.zeros(1)], clip_input=safety_checker_input
448
+ )
449
+ has_nsfw_concept = has_nsfw_concept[0]
450
+ if has_nsfw_concept:
451
+ raise UnsafeContentError("input unsafe")
452
 
453
+ except UnsafeContentError:
454
  image = Image.open("nsfw.png")
455
 
456
  if gallery_images is None:
prompt_check.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def is_unsafe_prompt(model, tokenizer, user_prompt):
2
+ prompt = f"""You are a content safety filter. Your task is to analyze the following user input and determine if it contains unsafe content.
3
+
4
+ Unsafe content includes:
5
+ 1. **Sexually Explicit Material**: Descriptions of nudity or sexual acts.
6
+ 2. **Sensitive Political Content**: References to real-world political figures (e.g., heads of state, government officials), political parties, sensitive geopolitical events, or propaganda.
7
+
8
+ Analyze the user input below.
9
+ If the input falls into EITHER of these categories, answer "yes".
10
+ If the input is safe, answer "no".
11
+
12
+ User Input: "{user_prompt}"
13
+
14
+ Answer (only yes or no):"""
15
+ messages = [
16
+ {"role": "user", "content": prompt}
17
+ ]
18
+ text = tokenizer.apply_chat_template(
19
+ messages,
20
+ tokenize=False,
21
+ add_generation_prompt=True,
22
+ enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
23
+ )
24
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
25
+
26
+ # conduct text completion
27
+ generated_ids = model.generate(
28
+ **model_inputs,
29
+ max_new_tokens=10
30
+ )
31
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
32
+
33
+ content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
34
+
35
+ return "yes" in content.lower()