main.py 9.0 KB

import os
import sys
import time
import hashlib
import asyncssh
import asyncio
from tqdm import tqdm
from loguru import logger
from file import get_files_to_upload

# Send log output to a file and to the console
logger.add("output.log")
logger.add(sys.stderr, level="WARNING")

# Upload a single file and verify it afterwards
async def upload_file(local_path, local_md5, remote_path, sftp):
    file_size = os.path.getsize(local_path)
    async with await sftp.open(remote_path, 'wb') as file:
        bytes_uploaded = 0
        with open(local_path, 'rb') as local_file:
            with tqdm(total=file_size, unit='B', unit_scale=True, ncols=80,
                      desc=os.path.basename(local_path)) as progress_bar:
                while True:
                    chunk = local_file.read(65536)  # read a 64 KB chunk of the local file
                    if not chunk:
                        break
                    await file.write(chunk)  # write the chunk to the remote file
                    bytes_uploaded += len(chunk)
                    progress_bar.update(len(chunk))
    # The remote handle is closed here, so the MD5 check sees the complete file
    if bytes_uploaded == file_size:
        remote_md5 = await asyncio.wait_for(calculate_remote_md5(sftp, remote_path),
                                            timeout=calculate_timeout)
        if remote_md5 == local_md5:
            await write_remote_meta_md5(remote_path, remote_md5, sftp)
            logger.info("File uploaded successfully")
        else:
            logger.info("Local and remote MD5 values do not match; please check whether an error occurred during the upload!")
    else:
        logger.info("File upload failed; please check the file and try again.")

# Recursively create a remote directory
async def create_remote_directory_recursive(directory, sftp):
    parent_dir = os.path.dirname(directory)
    try:
        await sftp.stat(parent_dir)
    except asyncssh.SFTPError as e:
        if "No such file" in str(e):
            await create_remote_directory_recursive(parent_dir, sftp)
        else:
            logger.error(f"Failed to check remote directory: {parent_dir}. Error: {e}")
            return
    try:
        await sftp.mkdir(directory)
    except asyncssh.SFTPError as e:
        logger.error(f"Failed to create remote directory: {directory}. Error: {e}")

# Lock used to serialize remote directory checks/creation across coroutines
lock = asyncio.Lock()

# Make sure the remote directory exists, then upload the file
async def create_directory_and_upload(local_path, local_md5, remote_path, sftp):
    directory = os.path.dirname(remote_path)
    # Check whether the remote directory exists and create it if it does not
    async with lock:
        attempts = 5  # number of attempts
        while attempts > 0:
            try:
                await sftp.stat(directory)
                break  # the directory already exists, exit the loop
            except asyncssh.SFTPError as e:
                if "No such file" in str(e):
                    await create_remote_directory_recursive(directory, sftp)
                    break  # the directory has been created, exit the loop
                elif "Failure" in str(e):  # adjust this check to match your server's error text if needed
                    await asyncio.sleep(5)  # wait 5 seconds before retrying
                    attempts -= 1
                else:
                    logger.error(f"Failed to check remote directory: {directory}. Error: {e}")
                    return
    # Upload the file
    for attempt in range(3):  # try the upload at most 3 times
        try:
            await upload_file(local_path, local_md5, remote_path, sftp)
            break  # upload succeeded, exit the loop
        except Exception as e:
            logger.error(f"Failed to upload file: {os.path.basename(local_path)}. Error: {e}")
            if attempt < 2:
                logger.info("Retrying the upload...")
            else:
                logger.info("Upload failed after multiple attempts.")
                return

# Compare the local and remote MD5 values and upload/re-upload when needed
async def compare_md5(local_path, remote_path, sftp):
    logger.info(f"Comparing file: {os.path.basename(local_path)}")
    try:
        local_md5 = calculate_md5(local_path)
        # Read the MD5 recorded in the remote meta file
        remote_meta_md5 = await get_remote_meta_md5(remote_path, sftp)
        if remote_meta_md5 and local_md5 == remote_meta_md5:
            logger.info("Meta file check passed - skipping remote MD5 calculation.")
            return
        # Calculate the remote MD5 with a generous timeout
        remote_md5 = await asyncio.wait_for(calculate_remote_md5(sftp, remote_path),
                                            timeout=calculate_timeout)
        if remote_md5 is None:
            # The remote file does not exist yet, so upload it
            logger.info(f'Preparing to upload file: {os.path.basename(local_path)}')
            await create_directory_and_upload(local_path, local_md5, remote_path, sftp)
        elif local_md5 != remote_md5:
            logger.info(f"MD5 check failed - values: {local_md5} - {remote_md5}")
            # Delete the stale remote file
            for attempt in range(3):  # try the deletion at most 3 times
                try:
                    await sftp.remove(remote_path)
                    logger.info(f'Remote file deleted: {remote_path}')
                    break  # deletion succeeded, exit the loop
                except Exception as e:
                    logger.error(f"Failed to delete remote file: {remote_path}. Error: {e}")
                    if attempt < 2:
                        logger.info("Retrying remote file deletion...")
                    else:
                        return
            # Re-upload the file
            logger.info(f'Preparing to upload file: {os.path.basename(local_path)}')
            await create_directory_and_upload(local_path, local_md5, remote_path, sftp)
        else:
            logger.info(f"MD5 check passed - values: {local_md5} - {remote_md5}")
            await write_remote_meta_md5(remote_path, remote_md5, sftp)
    except asyncio.TimeoutError:
        logger.error(f"Timeout occurred while comparing MD5 for {local_path} and {remote_path}")
        return

# Read the MD5 stored in the remote .meta file, if it exists
async def get_remote_meta_md5(remote_path, sftp):
    meta_path = f"{remote_path}.meta"
    try:
        async with await sftp.open(meta_path, 'r') as meta_file:
            meta_content = await meta_file.read()
            if 'MD5:' in meta_content:
                remote_md5 = meta_content.split('MD5:')[1].strip()
                return remote_md5
    except asyncssh.SFTPNoSuchFile:
        logger.warning(f"Meta file does not exist: {os.path.basename(meta_path)}")
    return None

# Record the verified MD5 in a remote .meta file next to the uploaded file
async def write_remote_meta_md5(remote_path, remote_md5, sftp):
    try:
        meta_content = f"MD5: {remote_md5}"
        meta_path = f"{remote_path}.meta"
        async with await sftp.open(meta_path, 'w') as meta_file:
            await meta_file.write(meta_content)
            logger.info(f'Meta file created: {os.path.basename(meta_path)}')
    except Exception as e:
        logger.error(f"Failed to write remote .meta file: {remote_path}. Error: {e}")

# Calculate the MD5 of a remote file by streaming it over SFTP
async def calculate_remote_md5(sftp, remote_path):
    hasher = hashlib.md5()
    try:
        async with await sftp.open(remote_path, 'rb') as file:
            while True:
                data = await file.read(65536)  # read the file in 64 KB chunks
                if not data:
                    break
                hasher.update(data)
    except asyncssh.SFTPNoSuchFile:
        logger.warning(f"Remote file does not exist: {os.path.basename(remote_path)}")
        return None
    return hasher.hexdigest()

def calculate_md5(file_path):
    hasher = hashlib.md5()
    with open(file_path, 'rb') as file:
        while True:
            data = file.read(65536)  # read the file in 64 KB chunks
            if not data:
                break
            hasher.update(data)
    return hasher.hexdigest()

# Example usage:
hostname = '192.168.188.252'
username = 'neozhang'
password = '58920912a'
local_directory = 'E:\\AI\\stable-diffusion-webui-master\\models\\Stable-diffusion'
remote_directory = '/home/neozhang/work/project/stable-diffusion-webui/models/Stable-diffusion'
file_extension = '*'
depth = -1  # how many levels of subdirectories to include (-1 means all subdirectories)
calculate_timeout = 1000  # timeout in seconds for calculating the remote MD5
coroutine_count = 5
files_to_upload = get_files_to_upload(local_directory, remote_directory, depth, file_extension)
logger.info(files_to_upload)

async def main():
    logger.info('======================================================================')
    async with asyncssh.connect(hostname, username=username, password=password, known_hosts=None) as conn:
        sftp = await conn.start_sftp_client()
        sem = asyncio.Semaphore(coroutine_count)  # limit concurrency to coroutine_count coroutines

        async def limited_compare_md5(local_path, remote_path):
            async with sem:
                await compare_md5(local_path, remote_path, sftp)

        coroutines = [limited_compare_md5(file_info['local_path'], file_info['remote_path'])
                      for file_info in files_to_upload]
        await asyncio.gather(*coroutines)

# Run the main coroutine
asyncio.run(main())
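
Note: file.py, which provides get_files_to_upload(), is not part of this listing. Purely for reference, below is a minimal sketch of what that helper could look like, inferred only from how main.py calls it (it must return a list of dicts with 'local_path' and 'remote_path' keys); the depth and pattern handling here are assumptions, not the author's actual implementation.

# Hypothetical sketch of file.py (the real module is not shown above)
import os
import fnmatch

def get_files_to_upload(local_directory, remote_directory, depth, file_extension):
    """Walk local_directory and build the upload list that main.py expects:
    a list of dicts with 'local_path' and 'remote_path' keys.
    depth limits how many subdirectory levels are included (-1 means no limit);
    file_extension is a glob pattern such as '*' or '*.safetensors'."""
    files = []
    base_depth = local_directory.rstrip(os.sep).count(os.sep)
    for root, dirs, filenames in os.walk(local_directory):
        current_depth = root.rstrip(os.sep).count(os.sep) - base_depth
        if depth != -1 and current_depth >= depth:
            dirs[:] = []  # do not descend any further
        for name in filenames:
            if not fnmatch.fnmatch(name, file_extension):
                continue
            local_path = os.path.join(root, name)
            relative = os.path.relpath(local_path, local_directory)
            # The remote side is POSIX, so normalize Windows path separators
            remote_path = remote_directory + '/' + relative.replace(os.sep, '/')
            files.append({'local_path': local_path, 'remote_path': remote_path})
    return files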