euIaxs22 commited on
Commit
ed81d43
·
verified ·
1 Parent(s): 13d78bf

Create vince_server.py

Browse files
Files changed (1) hide show
  1. services/vince_server.py +43 -0
services/vince_server.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # services/vince_server.py
2
+ from common.config import load_config, create_object
3
+ from pathlib import Path
4
+ import gc, torch
5
+
6
+ class VinceSingleton:
7
+ def __init__(self, config_path: str, overrides: list[str]):
8
+ self.config = load_config(config_path, overrides)
9
+ self.gen = create_object(self.config)
10
+ self.gen.configure_persistence()
11
+ self.gen.configure_models()
12
+ self.gen.configure_diffusion()
13
+
14
+ def _set_steps(self, steps: int | None):
15
+ if steps and hasattr(self.gen, "sampler") and hasattr(self.gen.sampler, "timesteps"):
16
+ ts = self.gen.sampler.timesteps
17
+ if hasattr(ts, "__len__") and len(ts) > 0:
18
+ steps = min(int(steps), len(ts))
19
+ if steps < len(ts):
20
+ idx = torch.linspace(0, len(ts) - 1, steps).round().long().tolist()
21
+ self.gen.sampler.timesteps = [ts[i] for i in idx]
22
+
23
+ def generate_multi_turn(self, image_path, turns, out_dir, *, steps=None, cfg_scale=None, aspect_ratio=None, resolution=None):
24
+ g = self.gen.config.generation
25
+ g.output.dir = str(out_dir)
26
+ g.positive_prompt = {"image_path": [str(image_path)], "prompts": list(turns)}
27
+ if cfg_scale is not None: g.cfg_scale = float(cfg_scale)
28
+ if aspect_ratio is not None: g.aspect_ratio = str(aspect_ratio)
29
+ if resolution is not None: g.resolution = int(resolution)
30
+ self._set_steps(steps)
31
+ self.gen.inference_loop()
32
+ try:
33
+ torch.cuda.synchronize()
34
+ except Exception:
35
+ pass
36
+ gc.collect()
37
+ try:
38
+ torch.cuda.empty_cache()
39
+ torch.cuda.memory.reset_peak_memory_stats()
40
+ except Exception:
41
+ pass
42
+ return str(out_dir)
43
+