test_hf_hub.py 796 B

1234567891011121314151617181920212223242526272829
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import unittest
  7. import torch
  8. try:
  9. import huggingface_hub
  10. except ImportError:
  11. huggingface_hub = None
  12. from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
  @unittest.skipIf(not huggingface_hub, "Requires huggingface_hub install")
  class TestHuggingFaceHub(unittest.TestCase):
      """Smoke tests for loading fairseq checkpoints published on the Hugging Face Hub.

      The whole class is skipped when the optional ``huggingface_hub`` package
      could not be imported.
      """

      # Disable autograd tracking; nothing in this test needs gradients.
      @torch.no_grad()
      def test_hf_fastspeech2(self):
          """Load ``facebook/fastspeech2-en-ljspeech`` and check a non-empty ensemble.

          NOTE(review): ``load_model_ensemble_and_task_from_hf_hub`` presumably
          fetches the checkpoint remotely, so this test likely needs network
          access -- confirm before enabling it in offline CI.
          """
          hf_model_id = "facebook/fastspeech2-en-ljspeech"
          # Only the model ensemble is checked; the returned config and task are unused.
          models, _cfg, _task = load_model_ensemble_and_task_from_hf_hub(hf_model_id)
          # assertGreater reports the actual count on failure, unlike
          # assertTrue(len(models) > 0), which only says "False is not true".
          self.assertGreater(len(models), 0)
  # Script entry point: run every TestCase in this module when executed directly
  # (importing the module, e.g. under a test runner, does not trigger this).
  if __name__ == "__main__":
      unittest.main(module="__main__")