util_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. import os
  2. import random
  3. import numpy as np
  4. from PIL import Image
  5. import torch
# open_clip is imported eagerly only when this file is loaded as a module.
# When the file is run as a script, main() imports open_clip itself, after
# its optional git checkout step.
if __name__ != '__main__':
    import open_clip
# CUDA_VISIBLE_DEVICES is set to the empty string at import time.
# NOTE(review): presumably this hides all GPUs so data is generated on CPU
# -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
  9. def seed_all(seed = 0):
  10. torch.backends.cudnn.deterministic = True
  11. torch.backends.cudnn.benchmark = False
  12. torch.use_deterministic_algorithms(True, warn_only=False)
  13. random.seed(seed)
  14. np.random.seed(seed)
  15. torch.manual_seed(seed)
  16. def inference_text(model, model_name, batches):
  17. y = []
  18. tokenizer = open_clip.get_tokenizer(model_name)
  19. with torch.no_grad():
  20. for x in batches:
  21. x = tokenizer(x)
  22. y.append(model.encode_text(x))
  23. return torch.stack(y)
  24. def inference_image(model, preprocess_val, batches):
  25. y = []
  26. with torch.no_grad():
  27. for x in batches:
  28. x = torch.stack([preprocess_val(img) for img in x])
  29. y.append(model.encode_image(x))
  30. return torch.stack(y)
  31. def forward_model(model, model_name, preprocess_val, image_batch, text_batch):
  32. y = []
  33. tokenizer = open_clip.get_tokenizer(model_name)
  34. with torch.no_grad():
  35. for x_im, x_txt in zip(image_batch, text_batch):
  36. x_im = torch.stack([preprocess_val(im) for im in x_im])
  37. x_txt = tokenizer(x_txt)
  38. y.append(model(x_im, x_txt))
  39. if type(y[0]) == dict:
  40. out = {}
  41. for key in y[0].keys():
  42. out[key] = torch.stack([batch_out[key] for batch_out in y])
  43. else:
  44. out = []
  45. for i in range(len(y[0])):
  46. out.append(torch.stack([batch_out[i] for batch_out in y]))
  47. return out
  48. def random_image_batch(batch_size, size):
  49. h, w = size
  50. data = np.random.randint(255, size = (batch_size, h, w, 3), dtype = np.uint8)
  51. return [ Image.fromarray(d) for d in data ]
  52. def random_text_batch(batch_size, min_length = 75, max_length = 75):
  53. t = open_clip.tokenizer.SimpleTokenizer()
  54. # every token decoded as string, exclude SOT and EOT, replace EOW with space
  55. token_words = [
  56. x[1].replace('</w>', ' ')
  57. for x in t.decoder.items()
  58. if x[0] not in t.all_special_ids
  59. ]
  60. # strings of randomly chosen tokens
  61. return [
  62. ''.join(random.choices(
  63. token_words,
  64. k = random.randint(min_length, max_length)
  65. ))
  66. for _ in range(batch_size)
  67. ]
  68. def create_random_text_data(
  69. path,
  70. min_length = 75,
  71. max_length = 75,
  72. batches = 1,
  73. batch_size = 1
  74. ):
  75. text_batches = [
  76. random_text_batch(batch_size, min_length, max_length)
  77. for _ in range(batches)
  78. ]
  79. print(f"{path}")
  80. torch.save(text_batches, path)
  81. def create_random_image_data(path, size, batches = 1, batch_size = 1):
  82. image_batches = [
  83. random_image_batch(batch_size, size)
  84. for _ in range(batches)
  85. ]
  86. print(f"{path}")
  87. torch.save(image_batches, path)
  88. def get_data_dirs(make_dir = True):
  89. data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
  90. input_dir = os.path.join(data_dir, 'input')
  91. output_dir = os.path.join(data_dir, 'output')
  92. if make_dir:
  93. os.makedirs(input_dir, exist_ok = True)
  94. os.makedirs(output_dir, exist_ok = True)
  95. assert os.path.isdir(data_dir), f"data directory missing, expected at {input_dir}"
  96. assert os.path.isdir(data_dir), f"data directory missing, expected at {output_dir}"
  97. return input_dir, output_dir
  98. def create_test_data_for_model(
  99. model_name,
  100. pretrained = None,
  101. precision = 'fp32',
  102. jit = False,
  103. pretrained_hf = False,
  104. force_quick_gelu = False,
  105. create_missing_input_data = True,
  106. batches = 1,
  107. batch_size = 1,
  108. overwrite = False
  109. ):
  110. model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
  111. input_dir, output_dir = get_data_dirs()
  112. output_file_text = os.path.join(output_dir, f'{model_id}_random_text.pt')
  113. output_file_image = os.path.join(output_dir, f'{model_id}_random_image.pt')
  114. text_exists = os.path.exists(output_file_text)
  115. image_exists = os.path.exists(output_file_image)
  116. if not overwrite and text_exists and image_exists:
  117. return
  118. seed_all()
  119. model, _, preprocess_val = open_clip.create_model_and_transforms(
  120. model_name,
  121. pretrained = pretrained,
  122. precision = precision,
  123. jit = jit,
  124. force_quick_gelu = force_quick_gelu,
  125. pretrained_hf = pretrained_hf
  126. )
  127. # text
  128. if overwrite or not text_exists:
  129. input_file_text = os.path.join(input_dir, 'random_text.pt')
  130. if create_missing_input_data and not os.path.exists(input_file_text):
  131. create_random_text_data(
  132. input_file_text,
  133. batches = batches,
  134. batch_size = batch_size
  135. )
  136. assert os.path.isfile(input_file_text), f"missing input data, expected at {input_file_text}"
  137. input_data_text = torch.load(input_file_text)
  138. output_data_text = inference_text(model, model_name, input_data_text)
  139. print(f"{output_file_text}")
  140. torch.save(output_data_text, output_file_text)
  141. # image
  142. if overwrite or not image_exists:
  143. size = model.visual.image_size
  144. if not isinstance(size, tuple):
  145. size = (size, size)
  146. input_file_image = os.path.join(input_dir, f'random_image_{size[0]}_{size[1]}.pt')
  147. if create_missing_input_data and not os.path.exists(input_file_image):
  148. create_random_image_data(
  149. input_file_image,
  150. size,
  151. batches = batches,
  152. batch_size = batch_size
  153. )
  154. assert os.path.isfile(input_file_image), f"missing input data, expected at {input_file_image}"
  155. input_data_image = torch.load(input_file_image)
  156. output_data_image = inference_image(model, preprocess_val, input_data_image)
  157. print(f"{output_file_image}")
  158. torch.save(output_data_image, output_file_image)
  159. def create_test_data(
  160. models,
  161. batches = 1,
  162. batch_size = 1,
  163. overwrite = False
  164. ):
  165. models = list(set(models).difference({
  166. # not available with timm
  167. # see https://github.com/mlfoundations/open_clip/issues/219
  168. 'timm-convnext_xlarge',
  169. 'timm-vit_medium_patch16_gap_256'
  170. }).intersection(open_clip.list_models()))
  171. models.sort()
  172. print(f"generating test data for:\n{models}")
  173. for model_name in models:
  174. print(model_name)
  175. create_test_data_for_model(
  176. model_name,
  177. batches = batches,
  178. batch_size = batch_size,
  179. overwrite = overwrite
  180. )
  181. return models
  182. def _sytem_assert(string):
  183. assert os.system(string) == 0
  184. class TestWrapper(torch.nn.Module):
  185. output_dict: torch.jit.Final[bool]
  186. def __init__(self, model, model_name, output_dict=True) -> None:
  187. super().__init__()
  188. self.model = model
  189. self.output_dict = output_dict
  190. if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]:
  191. self.model.output_dict = self.output_dict
  192. config = open_clip.get_model_config(model_name)
  193. self.head = torch.nn.Linear(config["embed_dim"], 2)
  194. def forward(self, image, text):
  195. x = self.model(image, text)
  196. if self.output_dict:
  197. out = self.head(x["image_features"])
  198. else:
  199. out = self.head(x[0])
  200. return {"test_output": out}
def main(args):
    """Command-line entry point: populate the test data directory.

    Parses ``args``, optionally checks out another git revision, imports
    ``open_clip``, generates the test data, restores the previous git state
    and optionally writes the list of processed models to a file.

    Args:
        args: list of command-line arguments (without the program name).

    Raises:
        AssertionError: re-raised when checking out the requested revision
            fails; also raised if switching back afterwards fails.
    """
    # open_clip is bound at module level from inside this function.
    global open_clip
    import importlib
    import shutil
    import subprocess
    import argparse
    parser = argparse.ArgumentParser(description = "Populate test data directory")
    parser.add_argument(
        '-a', '--all',
        action = 'store_true',
        help = "create test data for all models"
    )
    parser.add_argument(
        '-m', '--model',
        type = str,
        default = [],
        nargs = '+',
        help = "model(s) to create test data for"
    )
    parser.add_argument(
        '-f', '--model_list',
        type = str,
        help = "path to a text file containing a list of model names, one model per line"
    )
    parser.add_argument(
        '-s', '--save_model_list',
        type = str,
        help = "path to save the list of models that data was generated for"
    )
    parser.add_argument(
        '-g', '--git_revision',
        type = str,
        help = "git revision to generate test data for"
    )
    parser.add_argument(
        '--overwrite',
        action = 'store_true',
        help = "overwrite existing output data"
    )
    parser.add_argument(
        '-n', '--num_batches',
        default = 1,
        type = int,
        help = "amount of data batches to create (default: 1)"
    )
    parser.add_argument(
        '-b', '--batch_size',
        default = 1,
        type = int,
        help = "test data batch size (default: 1)"
    )
    args = parser.parse_args(args)
    # Model names can come from --model, from the --model_list file, or both.
    model_list = []
    if args.model_list is not None:
        with open(args.model_list, 'r') as f:
            model_list = f.read().splitlines()
    # Without --all, at least one explicit model name is required.
    if not args.all and len(args.model) < 1 and len(model_list) < 1:
        print("error: at least one model name is required")
        parser.print_help()
        parser.exit(1)
    if args.git_revision is not None:
        # Stash local changes; has_stash records whether git actually
        # stashed something, so it is popped again only in that case.
        stash_output = subprocess.check_output(['git', 'stash']).decode().splitlines()
        has_stash = len(stash_output) > 0 and stash_output[0] != 'No local changes to save'
        # Remember where to return to after the data has been generated.
        current_branch = subprocess.check_output(['git', 'branch', '--show-current'])
        if len(current_branch) < 1:
            # not on a branch -> detached head
            current_branch = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        current_branch = current_branch.splitlines()[0].decode()
        try:
            _sytem_assert(f'git checkout {args.git_revision}')
        except AssertionError as e:
            # Checkout failed: force the original state back, restore the
            # stash, then propagate the failure.
            _sytem_assert(f'git checkout -f {current_branch}')
            if has_stash:
                os.system(f'git stash pop')
            raise e
    # Imported only now, i.e. after the optional checkout above.
    # NOTE(review): presumably so the checked-out revision is the code that
    # gets imported -- confirm.
    open_clip = importlib.import_module('open_clip')
    models = open_clip.list_models() if args.all else args.model + model_list
    try:
        # create_test_data returns the filtered, sorted list it processed.
        models = create_test_data(
            models,
            batches = args.num_batches,
            batch_size = args.batch_size,
            overwrite = args.overwrite
        )
    finally:
        # Runs even when data generation fails, so the git state is restored.
        if args.git_revision is not None:
            test_dir = os.path.join(os.path.dirname(__file__), 'data')
            test_dir_ref = os.path.join(os.path.dirname(__file__), 'data_ref')
            # Drop any stale data_ref, then move the fresh data out of the
            # way before switching back.
            if os.path.exists(test_dir_ref):
                shutil.rmtree(test_dir_ref, ignore_errors = True)
            if os.path.exists(test_dir):
                os.rename(test_dir, test_dir_ref)
            _sytem_assert(f'git checkout {current_branch}')
            if has_stash:
                os.system(f'git stash pop')
            # NOTE(review): this moves data_ref back to data -- confirm that
            # 'data' is the intended final location of the generated files.
            os.rename(test_dir_ref, test_dir)
    if args.save_model_list is not None:
        print(f"Saving model list as {args.save_model_list}")
        with open(args.save_model_list, 'w') as f:
            # One model name per line.
            for m in models:
                print(m, file=f)
# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == '__main__':
    import sys
    main(sys.argv[1:])