#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Extracts random constraints from reference files."""

import argparse
import random
import sys


  11. def get_phrase(words, index, length):
  12. assert index < len(words) - length + 1
  13. phr = " ".join(words[index : index + length])
  14. for i in range(index, index + length):
  15. words.pop(index)
  16. return phr
  17. def main(args):
  18. if args.seed:
  19. random.seed(args.seed)
  20. for line in sys.stdin:
  21. constraints = []
  22. def add_constraint(constraint):
  23. constraints.append(constraint)
  24. source = line.rstrip()
  25. if "\t" in line:
  26. source, target = line.split("\t")
  27. if args.add_sos:
  28. target = f"<s> {target}"
  29. if args.add_eos:
  30. target = f"{target} </s>"
  31. if len(target.split()) >= args.len:
  32. words = [target]
  33. num = args.number
  34. choices = {}
  35. for i in range(num):
  36. if len(words) == 0:
  37. break
  38. segmentno = random.choice(range(len(words)))
  39. segment = words.pop(segmentno)
  40. tokens = segment.split()
  41. phrase_index = random.choice(range(len(tokens)))
  42. choice = " ".join(
  43. tokens[phrase_index : min(len(tokens), phrase_index + args.len)]
  44. )
  45. for j in range(
  46. phrase_index, min(len(tokens), phrase_index + args.len)
  47. ):
  48. tokens.pop(phrase_index)
  49. if phrase_index > 0:
  50. words.append(" ".join(tokens[0:phrase_index]))
  51. if phrase_index + 1 < len(tokens):
  52. words.append(" ".join(tokens[phrase_index:]))
  53. choices[target.find(choice)] = choice
  54. # mask out with spaces
  55. target = target.replace(choice, " " * len(choice), 1)
  56. for key in sorted(choices.keys()):
  57. add_constraint(choices[key])
  58. print(source, *constraints, sep="\t")
  59. if __name__ == "__main__":
  60. parser = argparse.ArgumentParser()
  61. parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases")
  62. parser.add_argument("--len", "-l", type=int, default=1, help="phrase length")
  63. parser.add_argument(
  64. "--add-sos", default=False, action="store_true", help="add <s> token"
  65. )
  66. parser.add_argument(
  67. "--add-eos", default=False, action="store_true", help="add </s> token"
  68. )
  69. parser.add_argument("--seed", "-s", default=0, type=int)
  70. args = parser.parse_args()
  71. main(args)