# test_num_shards.py (757 bytes)

  1. import pytest
  2. from training.data import get_dataset_size
  3. @pytest.mark.parametrize(
  4. "shards,expected_size",
  5. [
  6. ('/path/to/shard.tar', 1),
  7. ('/path/to/shard_{000..000}.tar', 1),
  8. ('/path/to/shard_{000..009}.tar', 10),
  9. ('/path/to/shard_{000..009}_{000..009}.tar', 100),
  10. ('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11),
  11. ('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20),
  12. (['/path/to/shard.tar'], 1),
  13. (['/path/to/shard.tar', '/path/to/other_shard.tar'], 2),
  14. ]
  15. )
  16. def test_num_shards(shards, expected_size):
  17. _, size = get_dataset_size(shards)
  18. assert size == expected_size, f'Expected {expected_size} for {shards} but found {size} instead.'