upscaler.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import os
  2. from abc import abstractmethod
  3. import PIL
  4. import numpy as np
  5. import torch
  6. from PIL import Image
  7. import modules.shared
  8. from modules import modelloader, shared
  9. LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
  10. NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
  11. class Upscaler:
  12. name = None
  13. model_path = None
  14. model_name = None
  15. model_url = None
  16. enable = True
  17. filter = None
  18. model = None
  19. user_path = None
  20. scalers: []
  21. tile = True
  22. def __init__(self, create_dirs=False):
  23. self.mod_pad_h = None
  24. self.tile_size = modules.shared.opts.ESRGAN_tile
  25. self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
  26. self.device = modules.shared.device
  27. self.img = None
  28. self.output = None
  29. self.scale = 1
  30. self.half = not modules.shared.cmd_opts.no_half
  31. self.pre_pad = 0
  32. self.mod_scale = None
  33. if self.model_path is None and self.name:
  34. self.model_path = os.path.join(shared.models_path, self.name)
  35. if self.model_path and create_dirs:
  36. os.makedirs(self.model_path, exist_ok=True)
  37. try:
  38. import cv2
  39. self.can_tile = True
  40. except:
  41. pass
  42. @abstractmethod
  43. def do_upscale(self, img: PIL.Image, selected_model: str):
  44. return img
  45. def upscale(self, img: PIL.Image, scale, selected_model: str = None):
  46. self.scale = scale
  47. dest_w = int(img.width * scale)
  48. dest_h = int(img.height * scale)
  49. for i in range(3):
  50. shape = (img.width, img.height)
  51. img = self.do_upscale(img, selected_model)
  52. if shape == (img.width, img.height):
  53. break
  54. if img.width >= dest_w and img.height >= dest_h:
  55. break
  56. if img.width != dest_w or img.height != dest_h:
  57. img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
  58. return img
  59. @abstractmethod
  60. def load_model(self, path: str):
  61. pass
  62. def find_models(self, ext_filter=None) -> list:
  63. return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
  64. def update_status(self, prompt):
  65. print(f"\nextras: {prompt}", file=shared.progress_print_out)
  66. class UpscalerData:
  67. name = None
  68. data_path = None
  69. scale: int = 4
  70. scaler: Upscaler = None
  71. model: None
  72. def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
  73. self.name = name
  74. self.data_path = path
  75. self.local_data_path = path
  76. self.scaler = upscaler
  77. self.scale = scale
  78. self.model = model
  79. class UpscalerNone(Upscaler):
  80. name = "None"
  81. scalers = []
  82. def load_model(self, path):
  83. pass
  84. def do_upscale(self, img, selected_model=None):
  85. return img
  86. def __init__(self, dirname=None):
  87. super().__init__(False)
  88. self.scalers = [UpscalerData("None", None, self)]
  89. class UpscalerLanczos(Upscaler):
  90. scalers = []
  91. def do_upscale(self, img, selected_model=None):
  92. return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
  93. def load_model(self, _):
  94. pass
  95. def __init__(self, dirname=None):
  96. super().__init__(False)
  97. self.name = "Lanczos"
  98. self.scalers = [UpscalerData("Lanczos", None, self)]
  99. class UpscalerNearest(Upscaler):
  100. scalers = []
  101. def do_upscale(self, img, selected_model=None):
  102. return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
  103. def load_model(self, _):
  104. pass
  105. def __init__(self, dirname=None):
  106. super().__init__(False)
  107. self.name = "Nearest"
  108. self.scalers = [UpscalerData("Nearest", None, self)]