123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- #!/usr/bin/env python3
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import argparse
- import os
- import re
- import shutil
- import sys
# Every checkpoint file name this script considers for preservation/deletion:
# "checkpoint<digits>.pt", "checkpoint_<digits>_<digits>.pt", or
# "checkpoint_<lowercase word>.pt" (e.g. checkpoint_best.pt, checkpoint_last.pt).
pt_regexp = re.compile(
    r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt"
)

# "checkpoint<digits>.pt"; group(1) holds the number (epoch, per the name).
pt_regexp_epoch_based = re.compile(
    r"checkpoint(\d+)\.pt"
)

# "checkpoint_<digits>_<digits>.pt"; group(1) holds the SECOND number
# (update count, per the name).
pt_regexp_update_based = re.compile(
    r"checkpoint_\d+_(\d+)\.pt"
)
def parse_checkpoints(files):
    """Extract sortable ``(number, filename)`` pairs from checkpoint names.

    ``checkpoint<N>.pt`` yields ``(N, name)`` and ``checkpoint_<A>_<B>.pt``
    yields ``(B, name)``. Names matching neither pattern, such as
    ``checkpoint_best.pt``, are skipped. Output order follows ``files``.

    Both kinds share one numeric key. If a directory mixes the two naming
    schemes, their numbers are compared directly against each other by callers
    that sort this result.
    """
    parsed = []
    for name in files:
        # Try the epoch-based form first, then the update-based form.
        for pattern in (pt_regexp_epoch_based, pt_regexp_update_based):
            match = pattern.fullmatch(name)
            if match is not None:
                parsed.append((int(match.group(1)), match.group(0)))
                break
    return parsed
def last_n_checkpoints(files, n):
    """Return the names of the ``n`` highest-numbered checkpoints in ``files``.

    Numbers come from ``parse_checkpoints``. The result is ordered from the
    highest number to the lowest.
    """
    newest_first = sorted(parse_checkpoints(files), reverse=True)
    return [name for _number, name in newest_first[:n]]
def every_n_checkpoints(files, n):
    """Return every ``n``-th checkpoint name, counting back from the highest.

    Entries from ``parse_checkpoints`` are sorted ascending. The last
    (highest-numbered) one is always kept, along with every ``n``-th entry
    before it, counted by position. The result is ordered lowest to highest.
    For example, checkpoints numbered 1..10 with ``n=3`` keep 1, 4, 7 and 10.
    A zero ``n`` raises ``ValueError`` (zero slice step).
    """
    ascending = sorted(parse_checkpoints(files))
    kept = ascending[::-n]
    kept.sort()
    return [name for _number, name in kept]
def _build_parser():
    """Return the argument parser for this script's command line."""
    parser = argparse.ArgumentParser(
        description=(
            "Recursively delete checkpoint files from `root_dir`, "
            "but preserve checkpoint_best.pt and checkpoint_last.pt"
        )
    )
    parser.add_argument("root_dirs", nargs="*")
    parser.add_argument(
        "--save-last", type=int, default=0, help="number of last checkpoints to save"
    )
    parser.add_argument(
        "--save-every", type=int, default=0, help="interval of checkpoints to save"
    )
    parser.add_argument(
        "--preserve-test",
        action="store_true",
        help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)",
    )
    parser.add_argument(
        "--delete-best", action="store_true", help="delete checkpoint_best.pt"
    )
    parser.add_argument(
        "--delete-last", action="store_true", help="delete checkpoint_last.pt"
    )
    parser.add_argument(
        "--no-dereference", action="store_true", help="don't dereference symlinks"
    )
    return parser


def _plan(args):
    """Classify every checkpoint file found under ``args.root_dirs``.

    Only file names fully matching ``pt_regexp`` are considered. Everything
    else is left alone and not reported.

    Returns three sorted lists of paths:
    ``(files_to_desymlink, files_to_preserve, files_to_delete)``. The first
    holds preserved files that are symlinks and will be replaced by a copy of
    their target (unless ``--no-dereference``).
    """
    files_to_desymlink = []
    files_to_preserve = []
    files_to_delete = []
    for root_dir in args.root_dirs:
        for root, _subdirs, files in os.walk(root_dir):
            # --save-last / --save-every are evaluated per directory, using
            # only that directory's own file names.
            to_save = set()
            if args.save_last > 0:
                to_save.update(last_n_checkpoints(files, args.save_last))
            if args.save_every > 0:
                to_save.update(every_n_checkpoints(files, args.save_every))
            # Unless --preserve-test is given, nothing in a test_* directory
            # is kept.
            dir_is_protected = args.preserve_test or not os.path.basename(
                root
            ).startswith("test_")
            for file in files:
                if not pt_regexp.fullmatch(file):
                    continue
                full_path = os.path.join(root, file)
                keep = dir_is_protected and (
                    (file == "checkpoint_last.pt" and not args.delete_last)
                    or (file == "checkpoint_best.pt" and not args.delete_best)
                    or file in to_save
                )
                if not keep:
                    files_to_delete.append(full_path)
                elif os.path.islink(full_path) and not args.no_dereference:
                    files_to_desymlink.append(full_path)
                else:
                    files_to_preserve.append(full_path)
    return (
        sorted(files_to_desymlink),
        sorted(files_to_preserve),
        sorted(files_to_delete),
    )


def _print_plan(files_to_desymlink, files_to_preserve, files_to_delete):
    """Print every planned operation in the order it will be executed."""
    print("Operations to perform (in order):")
    for file in files_to_desymlink:
        print(" - preserve (and dereference symlink): " + file)
    for file in files_to_preserve:
        print(" - preserve: " + file)
    for file in files_to_delete:
        print(" - delete: " + file)


def _confirm():
    """Ask until the user answers Y or N (case-insensitive); return True for Y.

    End of input (stdin closed or empty) counts as N. A non-interactive run
    therefore aborts cleanly instead of crashing with an ``EOFError``
    traceback.
    """
    while True:
        try:
            resp = input("Continue? (Y/N): ")
        except EOFError:
            print()
            return False
        answer = resp.strip().lower()
        if answer == "y":
            return True
        if answer == "n":
            return False


def _dereference_symlink(path):
    """Replace the symlink ``path`` with a regular-file copy of its target.

    The target is copied to a uniquely named temporary file in the same
    directory, which is then moved over the link with ``os.replace``. If the
    copy fails (for example with ENOSPC), the original symlink is left
    untouched and the partial temp file is removed. The checkpoint is never
    left missing or truncated. The copy takes the target's permission bits.
    """
    import tempfile

    realpath = os.path.realpath(path)
    print("rm " + path)
    print("cp {} {}".format(realpath, path))
    # Same directory as the link so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(
        dir=os.path.dirname(path) or os.curdir,
        prefix=os.path.basename(path) + ".",
        suffix=".tmp",
    )
    os.close(fd)
    try:
        shutil.copyfile(realpath, tmp_path)
        shutil.copymode(realpath, tmp_path)
        # Renaming onto a symlink replaces the link itself, not its target.
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial copy; the original error is
        # re-raised either way, aborting before any deletion happens.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def main():
    """Delete unneeded checkpoint files after listing them and asking to confirm.

    Run with ``--help`` for the options. Operations run in the printed order:
    preserved symlinks are dereferenced first, while their targets still
    exist, and then the remaining checkpoints are deleted.
    """
    args = _build_parser().parse_args()
    files_to_desymlink, files_to_preserve, files_to_delete = _plan(args)
    if not files_to_desymlink and not files_to_delete:
        print("Nothing to do.")
        sys.exit(0)

    _print_plan(files_to_desymlink, files_to_preserve, files_to_delete)
    if not _confirm():
        sys.exit(0)

    print("Executing...")
    for file in files_to_desymlink:
        _dereference_symlink(file)
    for file in files_to_delete:
        print("rm " + file)
        os.remove(file)


if __name__ == "__main__":
    main()
|