| from typing import List | |
| from PIL import Image | |
| from .common import CommonUpscaler, OfflineUpscaler | |
| from .waifu2x import Waifu2xUpscaler | |
| from .esrgan import ESRGANUpscaler | |
| from .esrgan_pytorch import ESRGANUpscalerPytorch | |
| UPSCALERS = { | |
| 'waifu2x': Waifu2xUpscaler, | |
| 'esrgan': ESRGANUpscaler, | |
| '4xultrasharp': ESRGANUpscalerPytorch, | |
| } | |
| upscaler_cache = {} | |
| def get_upscaler(key: str, *args, **kwargs) -> CommonUpscaler: | |
| if key not in UPSCALERS: | |
| raise ValueError(f'Could not find upscaler for: "{key}". Choose from the following: %s' % ','.join(UPSCALERS)) | |
| if not upscaler_cache.get(key): | |
| upscaler = UPSCALERS[key] | |
| upscaler_cache[key] = upscaler(*args, **kwargs) | |
| return upscaler_cache[key] | |
| async def prepare(upscaler_key: str): | |
| upscaler = get_upscaler(upscaler_key) | |
| if isinstance(upscaler, OfflineUpscaler): | |
| await upscaler.download() | |
| async def dispatch(upscaler_key: str, image_batch: List[Image.Image], upscale_ratio: int, device: str = 'cpu') -> List[Image.Image]: | |
| if upscale_ratio == 1: | |
| return image_batch | |
| upscaler = get_upscaler(upscaler_key) | |
| if isinstance(upscaler, OfflineUpscaler): | |
| await upscaler.load(device) | |
| return await upscaler.upscale(image_batch, upscale_ratio) | |