test_download_pretrained.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import requests
  2. import torch
  3. from PIL import Image
  4. import hashlib
  5. import tempfile
  6. import unittest
  7. from io import BytesIO
  8. from pathlib import Path
  9. from unittest.mock import patch
  10. from urllib3 import HTTPResponse
  11. from urllib3._collections import HTTPHeaderDict
  12. import open_clip
  13. from open_clip.pretrained import download_pretrained_from_url
  14. class DownloadPretrainedTests(unittest.TestCase):
  15. def create_response(self, data, status_code=200, content_type='application/octet-stream'):
  16. fp = BytesIO(data)
  17. headers = HTTPHeaderDict({
  18. 'Content-Type': content_type,
  19. 'Content-Length': str(len(data))
  20. })
  21. raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code)
  22. return raw
  23. @patch('open_clip.pretrained.urllib')
  24. def test_download_pretrained_from_url_from_openaipublic(self, urllib):
  25. file_contents = b'pretrained model weights'
  26. expected_hash = hashlib.sha256(file_contents).hexdigest()
  27. urllib.request.urlopen.return_value = self.create_response(file_contents)
  28. with tempfile.TemporaryDirectory() as root:
  29. url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
  30. download_pretrained_from_url(url, root)
  31. urllib.request.urlopen.assert_called_once()
  32. @patch('open_clip.pretrained.urllib')
  33. def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
  34. file_contents = b'pretrained model weights'
  35. expected_hash = hashlib.sha256(file_contents).hexdigest()
  36. urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
  37. with tempfile.TemporaryDirectory() as root:
  38. url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
  39. with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
  40. download_pretrained_from_url(url, root)
  41. urllib.request.urlopen.assert_called_once()
  42. @patch('open_clip.pretrained.urllib')
  43. def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
  44. file_contents = b'pretrained model weights'
  45. expected_hash = hashlib.sha256(file_contents).hexdigest()
  46. urllib.request.urlopen.return_value = self.create_response(file_contents)
  47. with tempfile.TemporaryDirectory() as root:
  48. local_file = Path(root) / 'RN50.pt'
  49. local_file.write_bytes(file_contents)
  50. url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
  51. download_pretrained_from_url(url, root)
  52. urllib.request.urlopen.assert_not_called()
  53. @patch('open_clip.pretrained.urllib')
  54. def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
  55. file_contents = b'pretrained model weights'
  56. expected_hash = hashlib.sha256(file_contents).hexdigest()
  57. urllib.request.urlopen.return_value = self.create_response(file_contents)
  58. with tempfile.TemporaryDirectory() as root:
  59. local_file = Path(root) / 'RN50.pt'
  60. local_file.write_bytes(b'corrupted pretrained model')
  61. url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
  62. download_pretrained_from_url(url, root)
  63. urllib.request.urlopen.assert_called_once()
  64. @patch('open_clip.pretrained.urllib')
  65. def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
  66. file_contents = b'pretrained model weights'
  67. expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
  68. urllib.request.urlopen.return_value = self.create_response(file_contents)
  69. with tempfile.TemporaryDirectory() as root:
  70. url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
  71. download_pretrained_from_url(url, root)
  72. urllib.request.urlopen.assert_called_once()
  73. @patch('open_clip.pretrained.urllib')
  74. def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
  75. file_contents = b'pretrained model weights'
  76. expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
  77. urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
  78. with tempfile.TemporaryDirectory() as root:
  79. url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
  80. with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
  81. download_pretrained_from_url(url, root)
  82. urllib.request.urlopen.assert_called_once()
  83. @patch('open_clip.pretrained.urllib')
  84. def test_download_pretrained_from_hfh(self, urllib):
  85. model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model')
  86. tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model')
  87. img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
  88. image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0)
  89. text = tokenizer(["a diagram", "a dog", "a cat"])
  90. with torch.no_grad():
  91. image_features = model.encode_image(image)
  92. text_features = model.encode_text(text)
  93. image_features /= image_features.norm(dim=-1, keepdim=True)
  94. text_features /= text_features.norm(dim=-1, keepdim=True)
  95. text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
  96. self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3))