test_inference.py

import os
import pytest
import torch
import open_clip
import util_test

# run on CPU only; hide any GPUs from the test process
os.environ['CUDA_VISIBLE_DEVICES'] = ''

if hasattr(torch._C, '_jit_set_profiling_executor'):
    # legacy executor is too slow to compile large models for unit tests
    # no need for the fusion performance here
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(False)

models_to_test = set(open_clip.list_models())

# testing exemptions
models_to_test = models_to_test.difference({
    # not available with timm yet
    # see https://github.com/mlfoundations/open_clip/issues/219
    'convnext_xlarge',
    'convnext_xxlarge',
    'convnext_xxlarge_320',
    'vit_medium_patch16_gap_256',
    # exceeds GH runner memory limit
    'ViT-bigG-14',
    'ViT-e-14',
    'mt5-xl-ViT-H-14',
    'coca_base',
    'coca_ViT-B-32',
    'coca_roberta-ViT-B-32',
})

if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ:
    external_model_list = os.environ['OPEN_CLIP_TEST_REG_MODELS']
    with open(external_model_list, 'r') as f:
        models_to_test = set(f.read().splitlines()).intersection(models_to_test)
    print(f"Selected models from {external_model_list}: {models_to_test}")

# TODO: add "coca_ViT-B-32" once https://github.com/pytorch/pytorch/issues/92073 gets fixed
models_to_test = list(models_to_test)
models_to_test.sort()
models_to_test = [(model_name, False) for model_name in models_to_test]

models_to_jit_test = {"ViT-B-32"}
models_to_jit_test = list(models_to_jit_test)
models_to_jit_test = [(model_name, True) for model_name in models_to_jit_test]
models_to_test_fully = models_to_test + models_to_jit_test
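# Each selected model is exercised eagerly (jit=False); "ViT-B-32" is additionally run
# with jit=True, e.g. ('RN50', False), ..., ('ViT-B-32', True).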


@pytest.mark.regression_test
@pytest.mark.parametrize("model_name,jit", models_to_test_fully)
def test_inference_with_data(
        model_name,
        jit,
        pretrained = None,
        pretrained_hf = False,
        precision = 'fp32',
        force_quick_gelu = False,
):
    util_test.seed_all()
    model, _, preprocess_val = open_clip.create_model_and_transforms(
        model_name,
        pretrained = pretrained,
        precision = precision,
        jit = jit,
        force_quick_gelu = force_quick_gelu,
        pretrained_hf = pretrained_hf,
    )
    model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
    input_dir, output_dir = util_test.get_data_dirs()

    # text
    input_text_path = os.path.join(input_dir, 'random_text.pt')
    gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt')
    if not os.path.isfile(input_text_path):
        pytest.skip(reason = f"missing test data, expected at {input_text_path}")
    if not os.path.isfile(gt_text_path):
        pytest.skip(reason = f"missing test data, expected at {gt_text_path}")
    input_text = torch.load(input_text_path)
    gt_text = torch.load(gt_text_path)
    y_text = util_test.inference_text(model, model_name, input_text)
    assert (y_text == gt_text).all(), f"text output differs @ {input_text_path}"

    # image
    image_size = model.visual.image_size
    if not isinstance(image_size, tuple):
        image_size = (image_size, image_size)
    input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt')
    gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt')
    if not os.path.isfile(input_image_path):
        pytest.skip(reason = f"missing test data, expected at {input_image_path}")
    if not os.path.isfile(gt_image_path):
        pytest.skip(reason = f"missing test data, expected at {gt_image_path}")
    input_image = torch.load(input_image_path)
    gt_image = torch.load(gt_image_path)
    y_image = util_test.inference_image(model, preprocess_val, input_image)
    assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}"

    if not jit:
        model.eval()
        model_out = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text)
        if type(model) not in [open_clip.CLIP, open_clip.CustomTextCLIP]:
            # models that always return a dict (e.g. CoCa)
            assert type(model_out) == dict
        else:
            # CLIP / CustomTextCLIP return a tuple by default; check that the
            # dict output (output_dict=True) matches the tuple output
            model.output_dict = True
            model_out_dict = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text)
            assert (model_out_dict["image_features"] == model_out[0]).all()
            assert (model_out_dict["text_features"] == model_out[1]).all()
            assert (model_out_dict["logit_scale"] == model_out[2]).all()
            model.output_dict = None
    else:
        # re-create the model without jit so it can be wrapped and scripted below
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            model_name,
            pretrained = pretrained,
            precision = precision,
            jit = False,
            force_quick_gelu = force_quick_gelu,
            pretrained_hf = pretrained_hf,
        )

        test_model = util_test.TestWrapper(model, model_name, output_dict=False)
        test_model = torch.jit.script(test_model)
        model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text)
        assert model_out["test_output"].shape[-1] == 2

        test_model = util_test.TestWrapper(model, model_name, output_dict=True)
        test_model = torch.jit.script(test_model)
        model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text)
        assert model_out["test_output"].shape[-1] == 2
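
# Note: the random input tensors and per-model ground-truth outputs are looked up under the
# directories returned by util_test.get_data_dirs(); when a file is missing the test is
# skipped rather than failed, so the regression data has to be generated beforehand.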