rm_pt.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import argparse
  7. import os
  8. import re
  9. import shutil
  10. import sys
  11. pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")
  12. pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt")
  13. pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt")
  14. def parse_checkpoints(files):
  15. entries = []
  16. for f in files:
  17. m = pt_regexp_epoch_based.fullmatch(f)
  18. if m is not None:
  19. entries.append((int(m.group(1)), m.group(0)))
  20. else:
  21. m = pt_regexp_update_based.fullmatch(f)
  22. if m is not None:
  23. entries.append((int(m.group(1)), m.group(0)))
  24. return entries
  25. def last_n_checkpoints(files, n):
  26. entries = parse_checkpoints(files)
  27. return [x[1] for x in sorted(entries, reverse=True)[:n]]
  28. def every_n_checkpoints(files, n):
  29. entries = parse_checkpoints(files)
  30. return [x[1] for x in sorted(sorted(entries)[::-n])]
  31. def main():
  32. parser = argparse.ArgumentParser(
  33. description=(
  34. "Recursively delete checkpoint files from `root_dir`, "
  35. "but preserve checkpoint_best.pt and checkpoint_last.pt"
  36. )
  37. )
  38. parser.add_argument("root_dirs", nargs="*")
  39. parser.add_argument(
  40. "--save-last", type=int, default=0, help="number of last checkpoints to save"
  41. )
  42. parser.add_argument(
  43. "--save-every", type=int, default=0, help="interval of checkpoints to save"
  44. )
  45. parser.add_argument(
  46. "--preserve-test",
  47. action="store_true",
  48. help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)",
  49. )
  50. parser.add_argument(
  51. "--delete-best", action="store_true", help="delete checkpoint_best.pt"
  52. )
  53. parser.add_argument(
  54. "--delete-last", action="store_true", help="delete checkpoint_last.pt"
  55. )
  56. parser.add_argument(
  57. "--no-dereference", action="store_true", help="don't dereference symlinks"
  58. )
  59. args = parser.parse_args()
  60. files_to_desymlink = []
  61. files_to_preserve = []
  62. files_to_delete = []
  63. for root_dir in args.root_dirs:
  64. for root, _subdirs, files in os.walk(root_dir):
  65. if args.save_last > 0:
  66. to_save = last_n_checkpoints(files, args.save_last)
  67. else:
  68. to_save = []
  69. if args.save_every > 0:
  70. to_save += every_n_checkpoints(files, args.save_every)
  71. for file in files:
  72. if not pt_regexp.fullmatch(file):
  73. continue
  74. full_path = os.path.join(root, file)
  75. if (
  76. not os.path.basename(root).startswith("test_") or args.preserve_test
  77. ) and (
  78. (file == "checkpoint_last.pt" and not args.delete_last)
  79. or (file == "checkpoint_best.pt" and not args.delete_best)
  80. or file in to_save
  81. ):
  82. if os.path.islink(full_path) and not args.no_dereference:
  83. files_to_desymlink.append(full_path)
  84. else:
  85. files_to_preserve.append(full_path)
  86. else:
  87. files_to_delete.append(full_path)
  88. if len(files_to_desymlink) == 0 and len(files_to_delete) == 0:
  89. print("Nothing to do.")
  90. sys.exit(0)
  91. files_to_desymlink = sorted(files_to_desymlink)
  92. files_to_preserve = sorted(files_to_preserve)
  93. files_to_delete = sorted(files_to_delete)
  94. print("Operations to perform (in order):")
  95. if len(files_to_desymlink) > 0:
  96. for file in files_to_desymlink:
  97. print(" - preserve (and dereference symlink): " + file)
  98. if len(files_to_preserve) > 0:
  99. for file in files_to_preserve:
  100. print(" - preserve: " + file)
  101. if len(files_to_delete) > 0:
  102. for file in files_to_delete:
  103. print(" - delete: " + file)
  104. while True:
  105. resp = input("Continue? (Y/N): ")
  106. if resp.strip().lower() == "y":
  107. break
  108. elif resp.strip().lower() == "n":
  109. sys.exit(0)
  110. print("Executing...")
  111. if len(files_to_desymlink) > 0:
  112. for file in files_to_desymlink:
  113. realpath = os.path.realpath(file)
  114. print("rm " + file)
  115. os.remove(file)
  116. print("cp {} {}".format(realpath, file))
  117. shutil.copyfile(realpath, file)
  118. if len(files_to_delete) > 0:
  119. for file in files_to_delete:
  120. print("rm " + file)
  121. os.remove(file)
  122. if __name__ == "__main__":
  123. main()