#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Extracts random constraints from reference files."""

- import argparse
- import random
- import sys
def get_phrase(words, index, length):
    """Cut a phrase out of ``words`` and return it as a single string.

    The ``length`` items starting at position ``index`` are joined with
    single spaces to form the result, and ``words`` is modified in place
    by popping ``length`` items at ``index``.

    Args:
        words: mutable list of word strings; shortened by this call.
        index: position of the first word of the phrase.
        length: number of words in the phrase.

    Returns:
        The space-joined phrase.

    Raises:
        AssertionError: if the phrase would run past the end of ``words``.
    """
    # The phrase has to fit entirely inside the list.
    assert len(words) - length + 1 > index
    stop = index + length
    phrase = " ".join(words[index:stop])
    # Popping at a fixed position removes consecutive items, because the
    # remainder of the list shifts left after every pop.
    for _ in range(length):
        words.pop(index)
    return phrase
def _sample_phrases(tokens, number, length):
    """Randomly pick up to ``number`` non-overlapping phrases from ``tokens``.

    Each pick chooses a random remaining segment and a random start token
    inside it, then takes up to ``length`` tokens from there (fewer when the
    segment ends first). The untouched tokens on either side of the phrase go
    back into the pool as new segments.

    Args:
        tokens: list of target tokens.
        number: maximum number of phrases to pick.
        length: maximum number of tokens per phrase.

    Returns:
        The picked phrases as space-joined strings, ordered by their position
        in ``tokens``.
    """
    # A segment is (offset of its first token within ``tokens``, its tokens).
    # Tracking token offsets, rather than searching the target string, keeps
    # the ordering correct even when a phrase also occurs inside another word
    # or the target contains irregular spacing.
    segments = [(0, list(tokens))] if tokens else []
    picked = {}
    for _ in range(number):
        if not segments:
            break
        offset, segment = segments.pop(random.choice(range(len(segments))))
        start = random.choice(range(len(segment)))
        end = min(len(segment), start + length)
        picked[offset + start] = " ".join(segment[start:end])
        if start > 0:
            segments.append((offset, segment[:start]))
        # Keep everything to the right of the phrase, even a single token.
        if end < len(segment):
            segments.append((offset + end, segment[end:]))
    return [picked[position] for position in sorted(picked)]


def main(args):
    """Print each stdin line's source followed by random target phrases.

    Every input line is expected to be ``source<TAB>target``. The output line
    is the source and the sampled constraint phrases, tab-separated, with the
    phrases in the order they occur in the target. A line without a tab is
    echoed with trailing whitespace stripped and no constraints.

    Args:
        args: namespace providing ``seed`` (a falsy value leaves the RNG
            unseeded), ``number`` (phrases per line), ``len`` (tokens per
            phrase), ``add_sos`` and ``add_eos`` (wrap the target in
            ``<s>`` / ``</s>`` before sampling).

    Raises:
        ValueError: if a line contains more than one tab.
    """
    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        constraints = []
        source = line.rstrip()
        if "\t" in line:
            source, target = line.split("\t")
            if args.add_sos:
                target = f"<s> {target}"
            if args.add_eos:
                target = f"{target} </s>"

            tokens = target.split()
            # A non-positive phrase length is meaningless, and a target
            # shorter than one phrase is skipped; both yield no constraints.
            if args.len > 0 and len(tokens) >= args.len:
                constraints = _sample_phrases(tokens, args.number, args.len)

        print(source, *constraints, sep="\t")
if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--number",
        "-n",
        type=int,
        default=1,
        help="number of phrases",
    )
    cli.add_argument(
        "--len",
        "-l",
        type=int,
        default=1,
        help="phrase length",
    )
    # store_true flags already default to False.
    cli.add_argument("--add-sos", action="store_true", help="add <s> token")
    cli.add_argument("--add-eos", action="store_true", help="add </s> token")
    # main() seeds the RNG only when the seed is truthy, so the default of 0
    # leaves it unseeded.
    cli.add_argument("--seed", "-s", type=int, default=0)
    main(cli.parse_args())
|