诸葛温侯 1 年間 前
コミット
1d600d1e80
2 ファイル変更294 行追加0 行削除
  1. 46 0
      file.py
  2. 248 0
      main.py

+ 46 - 0
file.py

@@ -0,0 +1,46 @@
+import os
+
+
+def get_files_to_upload(local_directory, remote_directory, depth, file_extension='*'):
+    """
+    获取目录下文件并转化为目标目录地址
+    :param local_directory: 本地目录
+    :param remote_directory: 远程目录
+    :param depth: 读取深度,-1为全部
+    :param file_extension: 文件后缀
+    :return:
+    """
+    files_to_upload = []
+    for root, dirs, files in os.walk(local_directory):
+        if depth >= 0:
+            # 计算当前目录的深度
+            current_depth = root[len(local_directory) + len(os.sep):].count(os.sep)
+            if current_depth > depth:
+                continue
+        destination_dir = os.path.join(remote_directory, os.path.relpath(root, local_directory))
+        for file in files:
+            if file_extension == '*' or file.endswith(file_extension):
+                local_path = os.path.join(root, file)  # 本地文件路径
+                remote_path = os.path.join(destination_dir, file)  # 远程文件路径
+                remote_path = remote_path.replace('\\','/')
+                files_to_upload.append({
+                    'local_path': local_path,
+                    'remote_path': remote_path
+                })
+    return files_to_upload
+
+
+def example():
+    local_directory = "D:\\Download\\downloads"  # 本地目录路径
+    remote_directory = "/home/ai/work/project/stable-diffusion-webui/extensions/sd-webui-controlnet/downloads"  # 远程目录路径
+    depth = -1  # 包括多少级子目录下的文件(-1代表包括所有子目录)
+
+    files_to_upload = get_files_to_upload(local_directory, remote_directory, depth)
+    print(files_to_upload)
+
+
+# Example usage:
+if __name__ == '__main__':
+    example()
+
+

+ 248 - 0
main.py

@@ -0,0 +1,248 @@
+import os
+import sys
+import time
+import hashlib
+import asyncssh
+import asyncio
+from tqdm import tqdm
+from loguru import logger
+from file import get_files_to_upload
+
+
+# 配置日志输出到文件和控制台
+logger.add("output.log")
+logger.add(sys.stderr, level="WARNING")
+
+
+# 上传文件的函数
+async def upload_file(local_path, local_md5, remote_path, sftp):
+    file_size = os.path.getsize(local_path)
+
+    async with await sftp.open(remote_path, 'wb') as file:
+        bytes_uploaded = 0
+
+        with open(local_path, 'rb') as local_file:
+            with tqdm(total=file_size, unit='B', unit_scale=True, ncols=80,
+                      desc=os.path.basename(local_path)) as progress_bar:
+                while True:
+                    chunk = local_file.read(65536)  # 读取文件块
+
+                    if not chunk:
+                        break
+
+                    await file.write(chunk)  # 写入文件块
+                    bytes_uploaded += len(chunk)
+                    progress_bar.update(len(chunk))
+
+        if bytes_uploaded == file_size:
+            remote_md5 = await asyncio.wait_for(calculate_remote_md5(sftp, remote_path), timeout=calculate_timeout)  # 设置超时时间为5秒
+
+            if remote_md5 == local_md5:
+                await write_remote_meta_md5(remote_path, remote_md5, sftp)
+                logger.info("文件上传成功")
+            else:
+                logger.info("本地与远程MD5不一致,请检查上传过程中是否发生错误!")
+        else:
+            logger.info("文件上传失败,请检查文件后再试验。")
+
+
+# 递归创建远程目录
+async def create_remote_directory_recursive(directory, sftp):
+    parent_dir = os.path.dirname(directory)
+
+    try:
+        await sftp.stat(parent_dir)
+    except asyncssh.SFTPError as e:
+        if "No such file" in str(e):
+            await create_remote_directory_recursive(parent_dir, sftp)
+        else:
+            logger.error(f"检查远程目录失败: {parent_dir}. 错误信息: {e}")
+            return
+
+    try:
+        await sftp.mkdir(directory)
+    except asyncssh.SFTPError as e:
+        logger.error(f"创建远程目录失败: {directory}. 错误信息: {e}")
+
+
+# 创建一个锁对象
+lock = asyncio.Lock()
+
+
+# 创建目录并上传文件的函数
+async def create_directory_and_upload(local_path, local_md5, remote_path, sftp):
+    directory = os.path.dirname(remote_path)
+
+    # 检查远程目录是否存在,如果不存在则创建
+    try:
+        await lock.acquire()  # 获取锁对象
+        attempts = 5  # 尝试次数
+        while attempts > 0:
+            try:
+                await sftp.stat(directory)
+                break  # 目录已存在,跳出循环
+            except asyncssh.SFTPError as e:
+                if "No such file" in str(e):
+                    await create_remote_directory_recursive(directory, sftp)
+                    break  # 目录创建成功,跳出循环
+                elif "Failure" in str(e):  # 这里可以根据实际情况修改判断条件
+                    await asyncio.sleep(5)  # 等待5秒
+                    attempts -= 1
+                else:
+                    logger.error(f"检查远程目录失败: {directory}. 错误信息: {e}")
+                    return
+    finally:
+        lock.release()
+
+    # 上传文件
+    for attempt in range(3):  # 最多尝试3次上传操作
+        try:
+            await upload_file(local_path, local_md5, remote_path, sftp)
+            break  # 上传成功,退出循环
+        except Exception as e:
+            logger.error(f"上传文件失败: {os.path.basename(local_path)}。错误信息: {e}")
+            if attempt < 2:
+                logger.info("重新尝试上传文件...")
+            else:
+                logger.info("多次上传文件失败.")
+                return
+
+
+async def compare_md5(local_path, remote_path, sftp):
+    logger.info(f"正在比对文件:{os.path.basename(local_path)}中")
+    try:
+        local_md5 = calculate_md5(local_path)
+        # 获取远程meta文件中的md5
+        remote_meta_md5 = await get_remote_meta_md5(remote_path, sftp)
+        if remote_meta_md5 and local_md5 == remote_meta_md5:
+            logger.info("Meta文件校验成功 - 跳过远程md5计算。")
+            return
+
+        # 计算远程MD5码,并设置较长的超时时间
+        remote_md5 = await asyncio.wait_for(calculate_remote_md5(sftp, remote_path), timeout=calculate_timeout)  # 设置超时时间为5秒
+
+        if remote_md5 is None:
+            # 执行上传操作
+            logger.info(f'准备文件上传 - 文件名为:{os.path.basename(local_path)}')
+            await create_directory_and_upload(local_path, local_md5, remote_path, sftp)
+
+        elif local_md5 != remote_md5:
+            logger.info(f"MD5校验失败 - 校验结果为:{local_md5} - {remote_md5}")
+            # 删除远程文件
+            for attempt in range(3):  # 最多尝试3次删除操作
+                try:
+                    await sftp.remove(remote_path)
+                    logger.info(f'删除远程文件成功:{remote_path}')
+                    break  # 删除成功,退出循环
+                except Exception as e:
+                    logger.error(f"删除远程文件失败: {remote_path}. 错误信息: {e}")
+                    if attempt < 2:
+                        logger.info("重新尝试删除远程文件...")
+                    else:
+                        return
+
+            # 执行上传操作
+            logger.info(f'准备文件上传 - 文件名为:{os.path.basename(local_path)}')
+            await create_directory_and_upload(local_path, local_md5, remote_path, sftp)
+
+        else:
+            logger.info(f"MD5码校验成功 - 校验结果为:{local_md5} - {remote_md5}")
+            await write_remote_meta_md5(remote_path, remote_md5, sftp)
+
+    except asyncio.TimeoutError:
+        logger.error(f"Timeout occurred while comparing MD5 for {local_path} and {remote_path}")
+        return
+
+
+async def get_remote_meta_md5(remote_path, sftp):
+    meta_path = f"{remote_path}.meta"
+
+    try:
+        async with await sftp.open(meta_path, 'r') as meta_file:
+            meta_content = await meta_file.read()
+
+            if 'MD5:' in meta_content:
+                remote_md5 = meta_content.split('MD5:')[1].strip()
+                return remote_md5
+
+    except asyncssh.SFTPNoSuchFile:
+        logger.warning(f"Meta文件不存在: {os.path.basename(meta_path)}")
+        pass
+
+    return None
+
+
+async def write_remote_meta_md5(remote_path, remote_md5, sftp):
+    try:
+        meta_content = f"MD5: {remote_md5}"
+        meta_path = f"{remote_path}.meta"
+
+        async with await sftp.open(meta_path, 'w') as meta_file:
+            await meta_file.write(meta_content)
+            logger.info(f'Meta文件创建成功:{os.path.basename(meta_path)}')
+
+    except Exception as e:
+        logger.error(f"写入远程.meta文件失败:{remote_path}. 错误信息:{e}")
+
+
+async def calculate_remote_md5(sftp, remote_path):
+    hasher = hashlib.md5()
+    try:
+        async with sftp.open(remote_path, 'rb') as file:
+            while True:
+                data = await file.read(65536)  # Read the file in chunks of 64KB
+                if not data:
+                    break
+                hasher.update(data)
+    except asyncssh.SFTPNoSuchFile:
+        logger.warning(f"远程文件不存在: {os.path.basename(remote_path)}")
+        return None
+    return hasher.hexdigest()
+
+
+def calculate_md5(file_path):
+    hasher = hashlib.md5()
+    with open(file_path, 'rb') as file:
+        while True:
+            data = file.read(65536)  # Read the file in chunks of 64KB
+            if not data:
+                break
+            hasher.update(data)
+    return hasher.hexdigest()
+
+
+# Example usage:
+hostname = '192.168.188.252'
+username = 'neozhang'
+password = '58920912a'
+local_directory = 'E:\\AI\\stable-diffusion-webui-master\\models\\Stable-diffusion'
+remote_directory = '/home/neozhang/work/project/stable-diffusion-webui/models/Stable-diffusion'
+file_extension = '*'
+depth = -1  # 包括多少级子目录下的文件(-1代表包括所有子目录)
+calculate_timeout = 1000 # 计算远程MD5的超时时间
+coroutine_count = 5
+
+
+files_to_upload = get_files_to_upload(local_directory, remote_directory, depth, file_extension)
+
+logger.info(files_to_upload)
+
+
+async def main():
+    logger.info('======================================================================')
+    async with asyncssh.connect(hostname, username=username, password=password, known_hosts=None) as conn:
+        sftp = await conn.start_sftp_client()
+
+        sem = asyncio.Semaphore(coroutine_count)  # 设置最大并发协程数量为5
+
+        async def limited_compare_md5(local_path, remote_path):
+            async with sem:
+                await compare_md5(local_path, remote_path, sftp)
+
+        coroutines = [limited_compare_md5(file_info['local_path'], file_info['remote_path'])
+                      for file_info in files_to_upload]
+        await asyncio.gather(*coroutines)
+
+# Run the main coroutine
+loop = asyncio.get_event_loop()
+loop.run_until_complete(main())