# hubconf.py (1.3 KB)

  1. from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models
  2. import re
  3. import string
  4. dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]
  5. # For compatibility (cannot include special characters in function name)
  6. model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}
  7. def _create_hub_entrypoint(model):
  8. def entrypoint(**kwargs):
  9. return _load(model, **kwargs)
  10. entrypoint.__doc__ = f"""Loads the {model} CLIP model
  11. Parameters
  12. ----------
  13. device : Union[str, torch.device]
  14. The device to put the loaded model
  15. jit : bool
  16. Whether to load the optimized JIT model or more hackable non-JIT model (default).
  17. download_root: str
  18. path to download the model files; by default, it uses "~/.cache/clip"
  19. Returns
  20. -------
  21. model : torch.nn.Module
  22. The {model} CLIP model
  23. preprocess : Callable[[PIL.Image], torch.Tensor]
  24. A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
  25. """
  26. return entrypoint
  27. def tokenize():
  28. return _tokenize
  29. _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}
  30. globals().update(_entrypoints)