interrogate.py
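
"""Image interrogation for the web UI: generates a BLIP caption for an image and
appends the best-matching terms from CLIP-ranked category word lists (artists,
flavors, mediums, movements) to build a prompt."""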

import os
import sys
import traceback
from collections import namedtuple
from pathlib import Path
import re

import torch
import torch.hub

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors

blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")
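# a ".topN." marker in a category file's stem sets how many top-ranked terms that
# category contributes; without it, categories() falls back to topn == 1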


def category_types():
    return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]


def download_default_clip_interrogate_categories(content_dir):
    print("Downloading CLIP categories...")

    tmpdir = content_dir + "_tmp"
    category_types = ["artists", "flavors", "mediums", "movements"]

    try:
        os.makedirs(tmpdir)
        for category_type in category_types:
            torch.hub.download_url_to_file(f"https://www.cool33.com/clip/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
        os.rename(tmpdir, content_dir)
    except Exception as e:
        errors.display(e, "downloading default CLIP interrogate categories")
    finally:
        if os.path.exists(tmpdir):
            os.removedirs(tmpdir)  # os.remove() cannot delete a directory; clean up the leftover temp dir instead


class InterrogateModels:
    """Captions an image with BLIP, then extends the caption with CLIP-ranked terms from the category word lists."""

    blip_model = None
    clip_model = None
    clip_preprocess = None
    dtype = None
    running_on_cpu = None

    def __init__(self, content_dir):
        self.loaded_categories = None
        self.skip_categories = []
        self.content_dir = content_dir
        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")

    def categories(self):
        if not os.path.exists(self.content_dir):
            download_default_clip_interrogate_categories(self.content_dir)

        if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
            return self.loaded_categories

        self.loaded_categories = []

        if os.path.exists(self.content_dir):
            self.skip_categories = shared.opts.interrogate_clip_skip_categories
            category_types = []
            for filename in Path(self.content_dir).glob('*.txt'):
                category_types.append(filename.stem)
                if filename.stem in self.skip_categories:
                    continue
                m = re_topn.search(filename.stem)
                topn = 1 if m is None else int(m.group(1))
                with open(filename, "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))

        return self.loaded_categories
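
    # Hypothetical shape of a loaded entry, assuming content_dir contains an
    # "artists.txt" word list (names are illustrative only):
    #   Category(name="artists", topn=1, items=["Pablo Picasso", "Claude Monet", ...])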

    def create_fake_fairscale(self):
        # BLIP imports checkpoint_wrapper from fairscale; registering this stub lets
        # the import succeed without the fairscale package being installed.
        class FakeFairscale:
            def checkpoint_wrapper(self):
                pass

        sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale

    def load_blip_model(self):
        self.create_fake_fairscale()
        import models.blip

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "BLIP"),
            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
            ext_filter=[".pth"],
            download_name='model_base_caption_capfilt_large.pth',
        )

        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        import clip

        if self.running_on_cpu:
            model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
        else:
            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)

        model.eval()
        model = model.to(devices.device_interrogate)

        return model, preprocess

    def load(self):
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(devices.device_interrogate)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(devices.device_interrogate)

        self.dtype = next(self.clip_model.parameters()).dtype
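
    # Note: both models run in half precision on devices.device_interrogate unless
    # --no-half is passed or interrogation runs on the CPU.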

    def send_clip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()
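
    # With interrogate_keep_models_in_memory disabled, unload() parks both models in
    # system RAM rather than freeing them; load() moves them back to
    # devices.device_interrogate on the next call.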

    def rank(self, image_features, text_array, top_count=1):
        import clip

        devices.torch_gc()

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # average the softmax-scaled cosine similarity over all image feature rows
        similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
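
    # Hypothetical call, assuming image_features comes from encode_image and has been
    # normalised as in interrogate() below (scores are in percent and illustrative):
    #   self.rank(image_features, ["oil painting", "photograph"], top_count=1)
    #   -> [("photograph", 97.3)]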

    def generate_caption(self, pil_image):
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]

    def interrogate(self, pil_image):
        res = ""
        shared.state.begin()
        shared.state.job = 'interrogate'
        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            # the BLIP caption forms the base of the prompt
            caption = self.generate_caption(pil_image)
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

            with torch.no_grad(), devices.autocast():
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
                image_features /= image_features.norm(dim=-1, keepdim=True)

                # append the best CLIP-ranked terms from each category word list
                for name, topn, items in self.categories():
                    matches = self.rank(image_features, items, top_count=topn)
                    for match, score in matches:
                        if shared.opts.interrogate_return_ranks:
                            res += f", ({match}:{score/100:.3f})"
                        else:
                            res += ", " + match
        except Exception:
            print("Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            res += "<error>"

        self.unload()
        shared.state.end()

        return res
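

# Usage sketch (illustrative only; in the web UI this object is created and driven by
# the application itself, and the content_dir value below is an assumption):
#
#   from PIL import Image
#   interrogator = InterrogateModels("interrogate")
#   prompt = interrogator.interrogate(Image.open("example.png").convert("RGB"))
#   print(prompt)  # e.g. "a painting of a castle, fantasy art"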