average_checkpoints.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import argparse
  7. import collections
  8. import os
  9. import re
  10. import torch
  11. from fairseq.file_io import PathManager
  12. def average_checkpoints(inputs):
  13. """Loads checkpoints from inputs and returns a model with averaged weights.
  14. Args:
  15. inputs: An iterable of string paths of checkpoints to load from.
  16. Returns:
  17. A dict of string keys mapping to various values. The 'model' key
  18. from the returned dict should correspond to an OrderedDict mapping
  19. string parameter names to torch Tensors.
  20. """
  21. params_dict = collections.OrderedDict()
  22. params_keys = None
  23. new_state = None
  24. num_models = len(inputs)
  25. for fpath in inputs:
  26. with PathManager.open(fpath, "rb") as f:
  27. state = torch.load(
  28. f,
  29. map_location=(
  30. lambda s, _: torch.serialization.default_restore_location(s, "cpu")
  31. ),
  32. )
  33. # Copies over the settings from the first checkpoint
  34. if new_state is None:
  35. new_state = state
  36. model_params = state["model"]
  37. model_params_keys = list(model_params.keys())
  38. if params_keys is None:
  39. params_keys = model_params_keys
  40. elif params_keys != model_params_keys:
  41. raise KeyError(
  42. "For checkpoint {}, expected list of params: {}, "
  43. "but found: {}".format(f, params_keys, model_params_keys)
  44. )
  45. for k in params_keys:
  46. p = model_params[k]
  47. if isinstance(p, torch.HalfTensor):
  48. p = p.float()
  49. if k not in params_dict:
  50. params_dict[k] = p.clone()
  51. # NOTE: clone() is needed in case of p is a shared parameter
  52. else:
  53. params_dict[k] += p
  54. averaged_params = collections.OrderedDict()
  55. for k, v in params_dict.items():
  56. averaged_params[k] = v
  57. if averaged_params[k].is_floating_point():
  58. averaged_params[k].div_(num_models)
  59. else:
  60. averaged_params[k] //= num_models
  61. new_state["model"] = averaged_params
  62. return new_state
  63. def last_n_checkpoints(paths, n, update_based, upper_bound=None):
  64. assert len(paths) == 1
  65. path = paths[0]
  66. if update_based:
  67. pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
  68. else:
  69. pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
  70. files = PathManager.ls(path)
  71. entries = []
  72. for f in files:
  73. m = pt_regexp.fullmatch(f)
  74. if m is not None:
  75. sort_key = int(m.group(1))
  76. if upper_bound is None or sort_key <= upper_bound:
  77. entries.append((sort_key, m.group(0)))
  78. if len(entries) < n:
  79. raise Exception(
  80. "Found {} checkpoint files but need at least {}", len(entries), n
  81. )
  82. return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
  83. def main():
  84. parser = argparse.ArgumentParser(
  85. description="Tool to average the params of input checkpoints to "
  86. "produce a new checkpoint",
  87. )
  88. # fmt: off
  89. parser.add_argument('--inputs', required=True, nargs='+',
  90. help='Input checkpoint file paths.')
  91. parser.add_argument('--output', required=True, metavar='FILE',
  92. help='Write the new checkpoint containing the averaged weights to this path.')
  93. num_group = parser.add_mutually_exclusive_group()
  94. num_group.add_argument('--num-epoch-checkpoints', type=int,
  95. help='if set, will try to find checkpoints with names checkpoint_xx.pt in the '
  96. 'path specified by input, and average last this many of them.')
  97. num_group.add_argument('--num-update-checkpoints', type=int,
  98. help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by'
  99. ' input, and average last this many of them.')
  100. num_group.add_argument('--num-best-checkpoints', type=int, default=0,
  101. help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the path specified by'
  102. ' input, and average last this many of them.')
  103. parser.add_argument('--checkpoint-upper-bound', type=int,
  104. help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, '
  105. 'when using --num-update-checkpoints, this will set an upper bound on which update to use'
  106. 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be'
  107. ' averaged.'
  108. 'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would'
  109. ' be averaged assuming --save-interval-updates 500'
  110. )
  111. # fmt: on
  112. args = parser.parse_args()
  113. print(args)
  114. num = None
  115. is_update_based = False
  116. if args.num_update_checkpoints is not None:
  117. num = args.num_update_checkpoints
  118. is_update_based = True
  119. elif args.num_epoch_checkpoints is not None:
  120. num = args.num_epoch_checkpoints
  121. assert args.checkpoint_upper_bound is None or (
  122. args.num_epoch_checkpoints is not None
  123. or args.num_update_checkpoints is not None
  124. ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
  125. assert (
  126. args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
  127. ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"
  128. if num is not None:
  129. args.inputs = last_n_checkpoints(
  130. args.inputs,
  131. num,
  132. is_update_based,
  133. upper_bound=args.checkpoint_upper_bound,
  134. )
  135. print("averaging checkpoints: ", args.inputs)
  136. if args.num_best_checkpoints > 0:
  137. args.inputs = list(
  138. sorted(
  139. args.inputs,
  140. key=lambda x: float(
  141. os.path.basename(x).split("_")[-1].replace(".pt", "")
  142. ),
  143. )
  144. )
  145. args.inputs = args.inputs[: args.num_best_checkpoints]
  146. for path in args.inputs:
  147. print(os.path.basename(path))
  148. new_state = average_checkpoints(args.inputs)
  149. with PathManager.open(args.output, "wb") as f:
  150. torch.save(new_state, f)
  151. print("Finished writing averaged checkpoint to {}".format(args.output))
  152. if __name__ == "__main__":
  153. main()