Spaces:
Runtime error
Runtime error
dynamically get VAE mapping before pairing with ControlNet
Browse files
app.py
CHANGED
|
@@ -79,11 +79,13 @@ pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
|
|
| 79 |
device_map="balanced"
|
| 80 |
)
|
| 81 |
|
| 82 |
-
|
|
|
|
| 83 |
# Expected output: {'transformer': 0, 'text_encoder': 1, 'vae': 2}
|
| 84 |
|
| 85 |
# Move the controlnet to the same device as the VAE (cuda:2)
|
| 86 |
-
vae_device =
|
|
|
|
| 87 |
controlnet = QwenImageControlNetModel.from_pretrained(
|
| 88 |
controlnet_model,
|
| 89 |
torch_dtype=torch.bfloat16
|
|
|
|
| 79 |
device_map="balanced"
|
| 80 |
)
|
| 81 |
|
| 82 |
+
pipe_device_map = pipe.hf_device_map
|
| 83 |
+
print("Initial device map:", pipe_device_map)
|
| 84 |
# Expected output: {'transformer': 0, 'text_encoder': 1, 'vae': 2}
|
| 85 |
|
| 86 |
# Move the controlnet to the same device as the VAE (cuda:2)
|
| 87 |
+
vae_device = pipe_device_map['vae']
|
| 88 |
+
vae_device = f"cuda:{vae_device}" # This is where the VAE is in the balanced config
|
| 89 |
controlnet = QwenImageControlNetModel.from_pretrained(
|
| 90 |
controlnet_model,
|
| 91 |
torch_dtype=torch.bfloat16
|