train_aw.sh 956 B

12345678910111213141516171819202122
  1. #!/bin/bash
  2. multi_gpu=0 # multi gpu | 多显卡训练 该参数仅限在显卡数 >= 2 使用
  3. config_file="./toml/aw.toml" # config_file | 使用toml文件指定训练参数
  4. sample_prompts="./toml/sample_prompts.txt" # sample_prompts | 采样prompts文件,留空则不启用采样功能
  5. utf8=1 # utf8 | 使用utf-8编码读取toml;以utf-8编码编写的、含中文的toml必须开启
  6. # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 =====================
  7. export HF_HOME="huggingface"
  8. export TF_CPP_MIN_LOG_LEVEL=3
  9. extArgs=()
  10. launchArgs=()
  11. if [[ $multi_gpu == 1 ]]; then launchArgs+=("--multi_gpu"); fi
  12. if [[ $utf8 == 1 ]]; then export PYTHONUTF8=1; fi
  13. # run train
  14. accelerate launch ${launchArgs[@]} --num_cpu_threads_per_process=8 "./sd-scripts/train_network.py" \
  15. --config_file=$config_file \
  16. --sample_prompts=$sample_prompts \
  17. ${extArgs[@]}