123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import collections
- import unittest
- import numpy as np
- from fairseq.data import ListDataset, ResamplingDataset
class TestResamplingDataset(unittest.TestCase):
    """Tests for ``ResamplingDataset`` wrapped around a small ``ListDataset``.

    The checks use ``unittest`` assertion methods rather than bare ``assert``
    statements so that they are still executed under ``python -O`` and report
    the offending values when they fail.
    """

    def setUp(self):
        """Create the string dataset, sampling weights and size ratio."""
        self.strings = ["ab", "c", "def", "ghij"]
        self.weights = [4.0, 2.0, 7.0, 1.5]
        self.size_ratio = 2
        self.dataset = ListDataset(
            self.strings, np.array([len(s) for s in self.strings])
        )

    def _test_common(self, resampling_dataset, iters):
        """Run ``iters`` epochs over ``resampling_dataset`` and gather stats.

        For every epoch the dataset's epoch is set, its ordered indices are
        walked, and each item is checked for self-consistency and for a size
        matching ``resampling_dataset.size(i)``.

        Args:
            resampling_dataset: dataset under test; must provide ``__len__``,
                ``__getitem__``, ``set_epoch``, ``ordered_indices`` and
                ``size``.
            iters: number of epochs to iterate.

        Returns:
            dict with two keys:
              * ``"ordered_by_size"`` -- ``True`` if, in every epoch, the
                sizes along ``ordered_indices()`` never decreased.
              * ``"max_distribution_diff"`` -- largest absolute difference
                between an item's observed sampling frequency (over all
                epochs) and its normalized weight.
        """
        self.assertEqual(len(self.dataset), len(self.strings))
        self.assertEqual(len(self.strings), len(self.weights))
        self.assertEqual(
            len(resampling_dataset), self.size_ratio * len(self.strings)
        )

        results = {"ordered_by_size": True, "max_distribution_diff": 0.0}

        totalfreqs = 0
        freqs = collections.Counter()

        for epoch_num in range(iters):
            resampling_dataset.set_epoch(epoch_num)

            indices = resampling_dataset.ordered_indices()
            self.assertEqual(len(indices), len(resampling_dataset))

            prev_size = -1

            for i in indices:
                cur_size = resampling_dataset.size(i)
                item = resampling_dataset[i]

                # Make sure indices map to same sequences within an epoch
                # (a second, independent lookup must return the same item).
                self.assertEqual(item, resampling_dataset[i])

                # Make sure length of sequence is correct
                self.assertEqual(cur_size, len(item))

                freqs[item] += 1
                totalfreqs += 1

                if prev_size > cur_size:
                    results["ordered_by_size"] = False

                prev_size = cur_size

        # Every source string must have been sampled at least once, and
        # nothing outside the source strings may appear.
        self.assertEqual(set(freqs), set(self.strings))

        total_weight = sum(self.weights)
        for s, weight in zip(self.strings, self.weights):
            freq = freqs[s] / totalfreqs
            expected_freq = weight / total_weight
            results["max_distribution_diff"] = max(
                results["max_distribution_diff"], abs(expected_freq - freq)
            )

        return results

    def test_resampling_dataset_batch_by_size_false(self):
        """With ``batch_by_size=False`` indices are not sorted by size."""
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=False,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = False, the batches should be returned in
        # arbitrary order of size.
        self.assertFalse(results["ordered_by_size"])

        # Allow tolerance in distribution error of 2%.
        self.assertLess(results["max_distribution_diff"], 0.02)

    def test_resampling_dataset_batch_by_size_true(self):
        """With ``batch_by_size=True`` indices come in increasing size order."""
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=True,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        self.assertTrue(results["ordered_by_size"])

        # Allow tolerance in distribution error of 2%.
        self.assertLess(results["max_distribution_diff"], 0.02)
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
|