diopside commited on
Commit
7786f81
·
verified ·
1 Parent(s): 847cdc6

dynamically get VAE mapping before pairing with ControlNet

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -79,11 +79,13 @@ pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
79
  device_map="balanced"
80
  )
81
 
82
- print("Initial device map:", pipe.hf_device_map)
 
83
  # Expected output: {'transformer': 0, 'text_encoder': 1, 'vae': 2}
84
 
85
  # Move the controlnet to the same device as the VAE (cuda:2)
86
- vae_device = "cuda:2" # This is where the VAE is in the balanced config
 
87
  controlnet = QwenImageControlNetModel.from_pretrained(
88
  controlnet_model,
89
  torch_dtype=torch.bfloat16
 
79
  device_map="balanced"
80
  )
81
 
82
+ pipe_device_map = pipe.hf_device_map
83
+ print("Initial device map:", pipe_device_map)
84
  # Expected output: {'transformer': 0, 'text_encoder': 1, 'vae': 2}
85
 
86
  # Move the controlnet to the same device as the VAE (cuda:2)
87
+ vae_device = pipe_device_map['vae']
88
+ vae_device = f"cuda:{vae_device}" # This is where the VAE is in the balanced config
89
  controlnet = QwenImageControlNetModel.from_pretrained(
90
  controlnet_model,
91
  torch_dtype=torch.bfloat16