spm_encode.py 3.4 KB

  1. #!/usr/bin/env python
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. from __future__ import absolute_import, division, print_function, unicode_literals
  8. import argparse
  9. import contextlib
  10. import sys
  11. import sentencepiece as spm
  12. def main():
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument(
  15. "--model", required=True, help="sentencepiece model to use for encoding"
  16. )
  17. parser.add_argument(
  18. "--inputs", nargs="+", default=["-"], help="input files to filter/encode"
  19. )
  20. parser.add_argument(
  21. "--outputs", nargs="+", default=["-"], help="path to save encoded outputs"
  22. )
  23. parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
  24. parser.add_argument(
  25. "--min-len",
  26. type=int,
  27. metavar="N",
  28. help="filter sentence pairs with fewer than N tokens",
  29. )
  30. parser.add_argument(
  31. "--max-len",
  32. type=int,
  33. metavar="N",
  34. help="filter sentence pairs with more than N tokens",
  35. )
  36. args = parser.parse_args()
  37. assert len(args.inputs) == len(
  38. args.outputs
  39. ), "number of input and output paths should match"
  40. sp = spm.SentencePieceProcessor()
  41. sp.Load(args.model)
  42. if args.output_format == "piece":
  43. def encode(input):
  44. return sp.EncodeAsPieces(input)
  45. elif args.output_format == "id":
  46. def encode(input):
  47. return list(map(str, sp.EncodeAsIds(input)))
  48. else:
  49. raise NotImplementedError
  50. if args.min_len is not None or args.max_len is not None:
  51. def valid(line):
  52. return (args.min_len is None or len(line) >= args.min_len) and (
  53. args.max_len is None or len(line) <= args.max_len
  54. )
  55. else:
  56. def valid(lines):
  57. return True
  58. with contextlib.ExitStack() as stack:
  59. inputs = [
  60. stack.enter_context(open(input, "r", encoding="utf-8"))
  61. if input != "-"
  62. else sys.stdin
  63. for input in args.inputs
  64. ]
  65. outputs = [
  66. stack.enter_context(open(output, "w", encoding="utf-8"))
  67. if output != "-"
  68. else sys.stdout
  69. for output in args.outputs
  70. ]
  71. stats = {
  72. "num_empty": 0,
  73. "num_filtered": 0,
  74. }
  75. def encode_line(line):
  76. line = line.strip()
  77. if len(line) > 0:
  78. line = encode(line)
  79. if valid(line):
  80. return line
  81. else:
  82. stats["num_filtered"] += 1
  83. else:
  84. stats["num_empty"] += 1
  85. return None
  86. for i, lines in enumerate(zip(*inputs), start=1):
  87. enc_lines = list(map(encode_line, lines))
  88. if not any(enc_line is None for enc_line in enc_lines):
  89. for enc_line, output_h in zip(enc_lines, outputs):
  90. print(" ".join(enc_line), file=output_h)
  91. if i % 10000 == 0:
  92. print("processed {} lines".format(i), file=sys.stderr)
  93. print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
  94. print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
  95. if __name__ == "__main__":
  96. main()