# test_training_simple.py
  1. import os
  2. import sys
  3. import pytest
  4. from PIL import Image
  5. import torch
  6. from training.main import main
  7. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  8. if hasattr(torch._C, '_jit_set_profiling_executor'):
  9. # legacy executor is too slow to compile large models for unit tests
  10. # no need for the fusion performance here
  11. torch._C._jit_set_profiling_executor(True)
  12. torch._C._jit_set_profiling_mode(False)
  13. @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
  14. def test_training():
  15. main([
  16. '--save-frequency', '1',
  17. '--zeroshot-frequency', '1',
  18. '--dataset-type', "synthetic",
  19. '--train-num-samples', '16',
  20. '--warmup', '1',
  21. '--batch-size', '4',
  22. '--lr', '1e-3',
  23. '--wd', '0.1',
  24. '--epochs', '1',
  25. '--workers', '2',
  26. '--model', 'RN50'
  27. ])
  28. @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
  29. def test_training_coca():
  30. main([
  31. '--save-frequency', '1',
  32. '--zeroshot-frequency', '1',
  33. '--dataset-type', "synthetic",
  34. '--train-num-samples', '16',
  35. '--warmup', '1',
  36. '--batch-size', '4',
  37. '--lr', '1e-3',
  38. '--wd', '0.1',
  39. '--epochs', '1',
  40. '--workers', '2',
  41. '--model', 'coca_ViT-B-32'
  42. ])
  43. @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
  44. def test_training_mt5():
  45. main([
  46. '--save-frequency', '1',
  47. '--zeroshot-frequency', '1',
  48. '--dataset-type', "synthetic",
  49. '--train-num-samples', '16',
  50. '--warmup', '1',
  51. '--batch-size', '4',
  52. '--lr', '1e-3',
  53. '--wd', '0.1',
  54. '--epochs', '1',
  55. '--workers', '2',
  56. '--model', 'mt5-base-ViT-B-32',
  57. '--lock-text',
  58. '--lock-text-unlocked-layers', '2'
  59. ])
  60. @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
  61. def test_training_unfreezing_vit():
  62. main([
  63. '--save-frequency', '1',
  64. '--zeroshot-frequency', '1',
  65. '--dataset-type', "synthetic",
  66. '--train-num-samples', '16',
  67. '--warmup', '1',
  68. '--batch-size', '4',
  69. '--lr', '1e-3',
  70. '--wd', '0.1',
  71. '--epochs', '1',
  72. '--workers', '2',
  73. '--model', 'ViT-B-32',
  74. '--lock-image',
  75. '--lock-image-unlocked-groups', '5',
  76. '--accum-freq', '2'
  77. ])
  78. @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
  79. def test_training_clip_with_jit():
  80. main([
  81. '--save-frequency', '1',
  82. '--zeroshot-frequency', '1',
  83. '--dataset-type', "synthetic",
  84. '--train-num-samples', '16',
  85. '--warmup', '1',
  86. '--batch-size', '4',
  87. '--lr', '1e-3',
  88. '--wd', '0.1',
  89. '--epochs', '1',
  90. '--workers', '2',
  91. '--model', 'ViT-B-32',
  92. '--torchscript'
  93. ])