deepbooru.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import os
  2. import re
  3. import torch
  4. from PIL import Image
  5. import numpy as np
  6. from modules import modelloader, paths, deepbooru_model, devices, images, shared
  7. re_special = re.compile(r'([\\()])')
  8. class DeepDanbooru:
  9. def __init__(self):
  10. self.model = None
  11. def load(self):
  12. if self.model is not None:
  13. return
  14. files = modelloader.load_models(
  15. model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
  16. model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
  17. ext_filter=[".pt"],
  18. download_name='model-resnet_custom_v3.pt',
  19. )
  20. self.model = deepbooru_model.DeepDanbooruModel()
  21. self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
  22. self.model.eval()
  23. self.model.to(devices.cpu, devices.dtype)
  24. def start(self):
  25. self.load()
  26. self.model.to(devices.device)
  27. def stop(self):
  28. if not shared.opts.interrogate_keep_models_in_memory:
  29. self.model.to(devices.cpu)
  30. devices.torch_gc()
  31. def tag(self, pil_image):
  32. self.start()
  33. res = self.tag_multi(pil_image)
  34. self.stop()
  35. return res
  36. def tag_multi(self, pil_image, force_disable_ranks=False):
  37. threshold = shared.opts.interrogate_deepbooru_score_threshold
  38. use_spaces = shared.opts.deepbooru_use_spaces
  39. use_escape = shared.opts.deepbooru_escape
  40. alpha_sort = shared.opts.deepbooru_sort_alpha
  41. include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
  42. pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
  43. a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
  44. with torch.no_grad(), devices.autocast():
  45. x = torch.from_numpy(a).to(devices.device)
  46. y = self.model(x)[0].detach().cpu().numpy()
  47. probability_dict = {}
  48. for tag, probability in zip(self.model.tags, y):
  49. if probability < threshold:
  50. continue
  51. if tag.startswith("rating:"):
  52. continue
  53. probability_dict[tag] = probability
  54. if alpha_sort:
  55. tags = sorted(probability_dict)
  56. else:
  57. tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
  58. res = []
  59. filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
  60. for tag in [x for x in tags if x not in filtertags]:
  61. probability = probability_dict[tag]
  62. tag_outformat = tag
  63. if use_spaces:
  64. tag_outformat = tag_outformat.replace('_', ' ')
  65. if use_escape:
  66. tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
  67. if include_ranks:
  68. tag_outformat = f"({tag_outformat}:{probability:.3f})"
  69. res.append(tag_outformat)
  70. return ", ".join(res)
# Module-level shared instance. Constructing it is cheap: the network itself is
# not loaded until DeepDanbooru.start() (which calls load()) is first invoked.
model = DeepDanbooru()