ui_tempdir.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import os
  2. import tempfile
  3. from collections import namedtuple
  4. from pathlib import Path
  5. import gradio as gr
  6. from PIL import PngImagePlugin
  7. from modules import shared
  8. Savedfile = namedtuple("Savedfile", ["name"])
  9. def register_tmp_file(gradio, filename):
  10. if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
  11. gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
  12. if hasattr(gradio, 'temp_dirs'): # gradio 3.9
  13. gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
  14. def check_tmp_file(gradio, filename):
  15. if hasattr(gradio, 'temp_file_sets'):
  16. return any([filename in fileset for fileset in gradio.temp_file_sets])
  17. if hasattr(gradio, 'temp_dirs'):
  18. return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
  19. return False
  20. def save_pil_to_file(pil_image, dir=None):
  21. already_saved_as = getattr(pil_image, 'already_saved_as', None)
  22. if already_saved_as and os.path.isfile(already_saved_as):
  23. register_tmp_file(shared.demo, already_saved_as)
  24. file_obj = Savedfile(already_saved_as)
  25. return file_obj
  26. if shared.opts.temp_dir != "":
  27. dir = shared.opts.temp_dir
  28. use_metadata = False
  29. metadata = PngImagePlugin.PngInfo()
  30. for key, value in pil_image.info.items():
  31. if isinstance(key, str) and isinstance(value, str):
  32. metadata.add_text(key, value)
  33. use_metadata = True
  34. file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
  35. pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
  36. return file_obj
  37. # override save to file function so that it also writes PNG info
  38. gr.processing_utils.save_pil_to_file = save_pil_to_file
  39. def on_tmpdir_changed():
  40. if shared.opts.temp_dir == "" or shared.demo is None:
  41. return
  42. os.makedirs(shared.opts.temp_dir, exist_ok=True)
  43. register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
  44. def cleanup_tmpdr():
  45. temp_dir = shared.opts.temp_dir
  46. if temp_dir == "" or not os.path.isdir(temp_dir):
  47. return
  48. for root, dirs, files in os.walk(temp_dir, topdown=False):
  49. for name in files:
  50. _, extension = os.path.splitext(name)
  51. if extension != ".png":
  52. continue
  53. filename = os.path.join(root, name)
  54. os.remove(filename)