main.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. import io
  3. import json
  4. import numpy as np
  5. import cv2
  6. import gradio as gr
  7. import modules.scripts as scripts
  8. from modules import script_callbacks
  9. from modules.shared import opts
  10. from modules.paths import models_path
  11. from basicsr.utils.download_util import load_file_from_url
  12. from scripts.openpose.body import Body
  13. from PIL import Image
  14. body_estimation = None
  15. presets_file = os.path.join(scripts.basedir(), "presets.json")
  16. presets = {}
  17. try:
  18. with open(presets_file) as file:
  19. presets = json.load(file)
  20. except FileNotFoundError:
  21. pass
  22. def pil2cv(in_image):
  23. out_image = np.array(in_image, dtype=np.uint8)
  24. if out_image.shape[2] == 3:
  25. out_image = cv2.cvtColor(out_image, cv2.COLOR_RGB2BGR)
  26. return out_image
  27. def candidate2li(li):
  28. res = []
  29. for x, y, *_ in li:
  30. res.append([x, y])
  31. return res
  32. def subset2li(li):
  33. res = []
  34. for r in li:
  35. for c in r:
  36. res.append(c)
  37. return res
  38. class Script(scripts.Script):
  39. def __init__(self) -> None:
  40. super().__init__()
  41. def title(self):
  42. return "OpenPose Editor"
  43. def show(self, is_img2img):
  44. return scripts.AlwaysVisible
  45. def ui(self, is_img2img):
  46. return ()
  47. def on_ui_tabs():
  48. with gr.Blocks(analytics_enabled=False) as openpose_editor:
  49. with gr.Row():
  50. with gr.Column():
  51. width = gr.Slider(label="width", minimum=64, maximum=2048, value=512, step=64, interactive=True)
  52. height = gr.Slider(label="height", minimum=64, maximum=2048, value=512, step=64, interactive=True)
  53. with gr.Row():
  54. add = gr.Button(value="Add", variant="primary")
  55. # delete = gr.Button(value="Delete")
  56. with gr.Row():
  57. reset_btn = gr.Button(value="Reset")
  58. json_input = gr.UploadButton(label="Load from JSON", file_types=[".json"], elem_id="openpose_json_button")
  59. png_input = gr.UploadButton(label="Detect from Image", file_types=["image"], type="bytes", elem_id="openpose_detect_button")
  60. bg_input = gr.UploadButton(label="Add Background Image", file_types=["image"], elem_id="openpose_bg_button")
  61. with gr.Row():
  62. preset_list = gr.Dropdown(label="Presets", choices=sorted(presets.keys()), interactive=True)
  63. preset_load = gr.Button(value="Load Preset")
  64. preset_save = gr.Button(value="Save Preset")
  65. with gr.Column():
  66. # gradioooooo...
  67. canvas = gr.HTML('<canvas id="openpose_editor_canvas" width="512" height="512" style="margin: 0.25rem; border-radius: 0.25rem; border: 0.5px solid"></canvas>')
  68. jsonbox = gr.Text(label="json", elem_id="jsonbox", visible=False)
  69. with gr.Row():
  70. json_output = gr.Button(value="Save JSON")
  71. png_output = gr.Button(value="Save PNG")
  72. send_t2t = gr.Button(value="Send to txt2img")
  73. send_i2i = gr.Button(value="Send to img2img")
  74. control_net_max_models_num = getattr(opts, 'control_net_max_models_num', 0)
  75. select_target_index = gr.Dropdown([str(i) for i in range(control_net_max_models_num)], label="Send to", value="0", interactive=True, visible=(control_net_max_models_num > 1))
  76. def estimate(file):
  77. global body_estimation
  78. if body_estimation is None:
  79. model_path = os.path.join(models_path, "openpose", "body_pose_model.pth")
  80. if not os.path.isfile(model_path):
  81. body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
  82. load_file_from_url(body_model_path, model_dir=os.path.join(models_path, "openpose"))
  83. body_estimation = Body(model_path)
  84. stream = io.BytesIO(file)
  85. img = Image.open(stream)
  86. candidate, subset = body_estimation(pil2cv(img))
  87. result = {
  88. "candidate": candidate2li(candidate),
  89. "subset": subset2li(subset),
  90. }
  91. return str(result).replace("'", '"')
  92. def savePreset(name, data):
  93. if name:
  94. presets[name] = json.loads(data)
  95. with open(presets_file, "w") as file:
  96. json.dump(presets, file)
  97. return gr.update(choices=sorted(presets.keys()), value=name), json.dumps(data)
  98. return gr.update(), gr.update()
  99. dummy_component = gr.Label(visible=False)
  100. preset = gr.Text(visible=False)
  101. width.change(None, [width, height], None, _js="(w, h) => {resizeCanvas(w, h)}")
  102. height.change(None, [width, height], None, _js="(w, h) => {resizeCanvas(w, h)}")
  103. png_output.click(None, [], None, _js="savePNG")
  104. bg_input.upload(None, [bg_input], [width, height], _js="addBackground")
  105. png_input.upload(estimate, png_input, jsonbox)
  106. png_input.upload(None, png_input, [width, height], _js="addBackground")
  107. add.click(None, [], None, _js="addPose")
  108. send_t2t.click(None, select_target_index, None, _js="(i) => {sendImage('txt2img', i)}")
  109. send_i2i.click(None, select_target_index, None, _js="(i) => {sendImage('img2img', i)}")
  110. reset_btn.click(None, [], None, _js="resetCanvas")
  111. json_input.upload(None, json_input, [width, height], _js="loadJSON")
  112. json_output.click(None, None, None, _js="saveJSON")
  113. preset_save.click(savePreset, [dummy_component, dummy_component], [preset_list, preset], _js="savePreset")
  114. preset_load.click(None, preset, [width, height], _js="loadPreset")
  115. preset_list.change(lambda selected: json.dumps(presets[selected]), preset_list, preset)
  116. return [(openpose_editor, "OpenPose Editor", "openpose_editor")]
  117. script_callbacks.on_ui_tabs(on_ui_tabs)