utils.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import torch
  2. import os
  3. import functools
  4. import time
  5. import base64
  6. import numpy as np
  7. import gradio as gr
  8. import logging
  9. from typing import Any, Callable, Dict
  10. from scripts.logging import logger
  11. def load_state_dict(ckpt_path, location="cpu"):
  12. _, extension = os.path.splitext(ckpt_path)
  13. if extension.lower() == ".safetensors":
  14. import safetensors.torch
  15. state_dict = safetensors.torch.load_file(ckpt_path, device=location)
  16. else:
  17. state_dict = get_state_dict(
  18. torch.load(ckpt_path, map_location=torch.device(location))
  19. )
  20. state_dict = get_state_dict(state_dict)
  21. logger.info(f"Loaded state_dict from [{ckpt_path}]")
  22. return state_dict
  23. def get_state_dict(d):
  24. return d.get("state_dict", d)
  25. def ndarray_lru_cache(max_size: int = 128, typed: bool = False):
  26. """
  27. Decorator to enable caching for functions with numpy array arguments.
  28. Numpy arrays are mutable, and thus not directly usable as hash keys.
  29. The idea here is to wrap the incoming arguments with type `np.ndarray`
  30. as `HashableNpArray` so that `lru_cache` can correctly handles `np.ndarray`
  31. arguments.
  32. `HashableNpArray` functions exactly the same way as `np.ndarray` except
  33. having `__hash__` and `__eq__` overriden.
  34. """
  35. def decorator(func: Callable):
  36. """The actual decorator that accept function as input."""
  37. class HashableNpArray(np.ndarray):
  38. def __new__(cls, input_array):
  39. # Input array is an instance of ndarray.
  40. # The view makes the input array and returned array share the same data.
  41. obj = np.asarray(input_array).view(cls)
  42. return obj
  43. def __eq__(self, other) -> bool:
  44. return np.array_equal(self, other)
  45. def __hash__(self):
  46. # Hash the bytes representing the data of the array.
  47. return hash(self.tobytes())
  48. @functools.lru_cache(maxsize=max_size, typed=typed)
  49. def cached_func(*args, **kwargs):
  50. """This function only accepts `HashableNpArray` as input params."""
  51. return func(*args, **kwargs)
  52. # Preserves original function.__name__ and __doc__.
  53. @functools.wraps(func)
  54. def decorated_func(*args, **kwargs):
  55. """The decorated function that delegates the original function."""
  56. def convert_item(item: Any):
  57. return HashableNpArray(item) if isinstance(item, np.ndarray) else item
  58. args = [convert_item(arg) for arg in args]
  59. kwargs = {k: convert_item(arg) for k, arg in kwargs.items()}
  60. return cached_func(*args, **kwargs)
  61. return decorated_func
  62. return decorator
  63. def timer_decorator(func):
  64. """Time the decorated function and output the result to debug logger."""
  65. if logger.level != logging.DEBUG:
  66. return func
  67. @functools.wraps(func)
  68. def wrapper(*args, **kwargs):
  69. start_time = time.time()
  70. result = func(*args, **kwargs)
  71. end_time = time.time()
  72. duration = end_time - start_time
  73. # Only report function that are significant enough.
  74. if duration > 1e-3:
  75. logger.debug(f"{func.__name__} ran in: {duration} sec")
  76. return result
  77. return wrapper
  78. class TimeMeta(type):
  79. """ Metaclass to record execution time on all methods of the
  80. child class. """
  81. def __new__(cls, name, bases, attrs):
  82. for attr_name, attr_value in attrs.items():
  83. if callable(attr_value):
  84. attrs[attr_name] = timer_decorator(attr_value)
  85. return super().__new__(cls, name, bases, attrs)
  86. # svgsupports
  87. svgsupport = False
  88. try:
  89. import io
  90. from svglib.svglib import svg2rlg
  91. from reportlab.graphics import renderPM
  92. svgsupport = True
  93. except ImportError:
  94. pass
  95. def svg_preprocess(inputs: Dict, preprocess: Callable):
  96. if not inputs:
  97. return None
  98. if inputs["image"].startswith("data:image/svg+xml;base64,") and svgsupport:
  99. svg_data = base64.b64decode(
  100. inputs["image"].replace("data:image/svg+xml;base64,", "")
  101. )
  102. drawing = svg2rlg(io.BytesIO(svg_data))
  103. png_data = renderPM.drawToString(drawing, fmt="PNG")
  104. encoded_string = base64.b64encode(png_data)
  105. base64_str = str(encoded_string, "utf-8")
  106. base64_str = "data:image/png;base64," + base64_str
  107. inputs["image"] = base64_str
  108. return preprocess(inputs)
  109. def get_unique_axis0(data):
  110. arr = np.asanyarray(data)
  111. idxs = np.lexsort(arr.T)
  112. arr = arr[idxs]
  113. unique_idxs = np.empty(len(arr), dtype=np.bool_)
  114. unique_idxs[:1] = True
  115. unique_idxs[1:] = np.any(arr[:-1, :] != arr[1:, :], axis=-1)
  116. return arr[unique_idxs]