diopside committed on
Commit 098bca5 · verified · 1 Parent(s): 4f1e83c

Utilize HF's balanced device_map + move diffusion components to the relevant execution devices
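
The commit replaces a single-GPU `pipe.to("cuda")` with diffusers' pipeline-level device placement. As a minimal sketch of that pattern (assuming a multi-GPU host and a diffusers release that supports pipeline-level `device_map`; the model ID is taken from the diff below):

```python
# Minimal sketch: spread pipeline components across visible GPUs.
# Assumes >= 2 GPUs; "balanced" assigns whole components
# (transformer, text encoder, VAE) to devices, not sharded layers.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image",
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

# hf_device_map records where each component landed, e.g.
# {'transformer': 0, 'text_encoder': 1, 'vae': 2}
print(pipe.hf_device_map)
```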

Files changed (1)
  1. app.py +33 -8
app.py CHANGED
@@ -71,12 +71,37 @@ def use_output_as_input(output_image):
 base_model = "Qwen/Qwen-Image"
 controlnet_model = "InstantX/Qwen-Image-ControlNet-Inpainting"
 
-controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
-
+# First create the pipeline with device_map="balanced"
 pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
-    base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
+    base_model,
+    controlnet=None,  # We'll add the controlnet later
+    torch_dtype=torch.bfloat16,
+    device_map="balanced"
 )
-pipe.to("cuda")
+
+print("Initial device map:", pipe.hf_device_map)
+# Expected output: {'transformer': 0, 'text_encoder': 1, 'vae': 2}
+
+# Move the controlnet to the same device as the VAE (cuda:2)
+vae_device = "cuda:2"  # This is where the VAE is in the balanced config
+controlnet = QwenImageControlNetModel.from_pretrained(
+    controlnet_model,
+    torch_dtype=torch.bfloat16
+).to(vae_device)
+
+# Attach the controlnet to the pipeline
+pipe.controlnet = controlnet
+
+pipe.enable_vae_slicing()
+pipe.enable_vae_tiling()
+
+print("Controlnet device:", next(pipe.controlnet.parameters()).device)
+print("VAE device:", next(pipe.vae.parameters()).device)
+
+
+# Create a helper function to get a generator on the correct device
+def get_generator(seed):
+    return torch.Generator(device=vae_device).manual_seed(seed)
 
 
 @spaces.GPU(duration=150)
@@ -93,7 +118,7 @@ def infer(edit_images,
 
     image = edit_images["background"]
    mask = edit_images["layers"][0]
-
+
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
 
@@ -113,7 +138,7 @@ def infer(edit_images,
        width=image.size[0],
        height=image.size[1],
        true_cfg_scale=true_cfg_scale,
-        generator=torch.Generator(device="cuda").manual_seed(seed)
+        generator=get_generator(seed)
    ).images[0]
 
    return [image, result_image], seed
@@ -140,9 +165,9 @@ css = """
 
 
 with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
-    gr.HTML("<h1 style='text-align: center'>Qwen-Image with InstantX Inpainting ControlNet</h1>")
+    gr.HTML("<h1 style='text-align: center'>MultiGPU - Qwen-Image + InstantX Inpainting ControlNet - LB</h1>")
    gr.Markdown(
-        "Inpaint images with [InstantX/Qwen-Image-ControlNet-Inpainting](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting)"
+        "Runs on 4×L40S instead of ZeroGPU - faster, more efficient inpainting with [InstantX/Qwen-Image-ControlNet-Inpainting](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting)"
    )
    with gr.Row():
        with gr.Column():
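
One caveat in the new code: `vae_device = "cuda:2"` is hardcoded and only holds when the balanced map actually places the VAE on GPU 2. A more defensive variant (a sketch, not part of this commit) would read the placement back from the loaded pipeline:

```python
# Sketch (not in the commit): follow the VAE wherever the balanced
# device map actually placed it, instead of assuming cuda:2.
vae_device = next(pipe.vae.parameters()).device

controlnet = QwenImageControlNetModel.from_pretrained(
    controlnet_model, torch_dtype=torch.bfloat16
).to(vae_device)
pipe.controlnet = controlnet

def get_generator(seed):
    # keep the RNG on the same device as the VAE and controlnet
    return torch.Generator(device=vae_device).manual_seed(seed)
```

The `enable_vae_slicing()` / `enable_vae_tiling()` calls are independent of the device map: they reduce peak VRAM during VAE encode/decode at a small speed cost, which helps when the VAE shares a GPU with the controlnet.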