# lora_block_weight.py (32 KB)

  1. import cv2
  2. import os
  3. import gc
  4. import re
  5. import sys
  6. import torch
  7. import shutil
  8. import math
  9. import numpy as np
  10. import gradio as gr
  11. import os.path
  12. import random
  13. from pprint import pprint
  14. import modules.ui
  15. import modules.scripts as scripts
  16. from PIL import Image, ImageFont, ImageDraw
  17. from fonts.ttf import Roboto
  18. import modules.shared as shared
  19. from modules import devices, sd_models, images,extra_networks
  20. from modules.shared import opts, state
  21. from modules.processing import process_images, Processed
  22. lxyz = ""
  23. lzyx = ""
  24. prompts = ""
  25. xyelem = ""
  26. princ = False
  27. BLOCKID=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
  28. BLOCKS=["encoder",
  29. "diffusion_model_input_blocks_0_",
  30. "diffusion_model_input_blocks_1_",
  31. "diffusion_model_input_blocks_2_",
  32. "diffusion_model_input_blocks_3_",
  33. "diffusion_model_input_blocks_4_",
  34. "diffusion_model_input_blocks_5_",
  35. "diffusion_model_input_blocks_6_",
  36. "diffusion_model_input_blocks_7_",
  37. "diffusion_model_input_blocks_8_",
  38. "diffusion_model_input_blocks_9_",
  39. "diffusion_model_input_blocks_10_",
  40. "diffusion_model_input_blocks_11_",
  41. "diffusion_model_middle_block_",
  42. "diffusion_model_output_blocks_0_",
  43. "diffusion_model_output_blocks_1_",
  44. "diffusion_model_output_blocks_2_",
  45. "diffusion_model_output_blocks_3_",
  46. "diffusion_model_output_blocks_4_",
  47. "diffusion_model_output_blocks_5_",
  48. "diffusion_model_output_blocks_6_",
  49. "diffusion_model_output_blocks_7_",
  50. "diffusion_model_output_blocks_8_",
  51. "diffusion_model_output_blocks_9_",
  52. "diffusion_model_output_blocks_10_",
  53. "diffusion_model_output_blocks_11_"]
  54. loopstopper = True
  55. ATYPES =["none","Block ID","values","seed","Original Weights","elements"]
  56. DEF_WEIGHT_PRESET = "\
  57. NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
  58. ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
  59. INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
  60. IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
  61. INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
  62. MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
  63. OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
  64. OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
  65. OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
  66. ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
  67. class Script(modules.scripts.Script):
  68. def title(self):
  69. return "LoRA Block Weight"
  70. def show(self, is_img2img):
  71. return modules.scripts.AlwaysVisible
  72. def ui(self, is_img2img):
  73. import lora
  74. LWEIGHTSPRESETS = DEF_WEIGHT_PRESET
  75. runorigin = scripts.scripts_txt2img.run
  76. runorigini = scripts.scripts_img2img.run
  77. path_root = scripts.basedir()
  78. extpath = os.path.join(path_root,"extensions","sd-webui-lora-block-weight","scripts", "lbwpresets.txt")
  79. filepath = os.path.join(path_root,"scripts", "lbwpresets.txt")
  80. filepathe = os.path.join(path_root,"scripts", "elempresets.txt")
  81. if os.path.isfile(extpath) and not os.path.isfile(filepath):
  82. shutil.move(extpath,filepath)
  83. lbwpresets=""
  84. try:
  85. with open(filepath,encoding="utf-8") as f:
  86. lbwpresets = f.read()
  87. except OSError as e:
  88. lbwpresets=LWEIGHTSPRESETS
  89. if not os.path.isfile(filepath):
  90. try:
  91. with open(filepath,mode = 'w',encoding="utf-8") as f:
  92. f.write(lbwpresets)
  93. except:
  94. pass
  95. try:
  96. with open(filepathe,encoding="utf-8") as f:
  97. elempresets = f.read()
  98. except OSError as e:
  99. elempresets=ELEMPRESETS
  100. if not os.path.isfile(filepathe):
  101. try:
  102. with open(filepathe,mode = 'w',encoding="utf-8") as f:
  103. f.write(elempresets)
  104. except:
  105. pass
  106. loraratios=lbwpresets.splitlines()
  107. lratios={}
  108. for i,l in enumerate(loraratios):
  109. if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25) : continue
  110. lratios[l.split(":")[0]]=l.split(":")[1]
  111. ratiostags = [k for k in lratios.keys()]
  112. ratiostags = ",".join(ratiostags)
  113. with gr.Accordion("LoRA Block Weight",open = False):
  114. with gr.Row():
  115. with gr.Column(min_width = 50, scale=1):
  116. lbw_useblocks = gr.Checkbox(value = True,label="Active",interactive =True,elem_id="lbw_active")
  117. with gr.Column(scale=5):
  118. bw_ratiotags= gr.TextArea(label="",lines=2,value=ratiostags,visible =True,interactive =True,elem_id="lbw_ratios")
  119. with gr.Accordion("XYZ plot",open = False):
  120. gr.HTML(value="<p>changeable blocks : BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>")
  121. xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index")
  122. with gr.Row(visible = False) as esets:
  123. diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True)
  124. revxy = gr.Checkbox(value = False,label="change X-Y",interactive =True,elem_id="lbw_changexy")
  125. thresh = gr.Textbox(label="difference threshold",lines=1,value="20",interactive =True,elem_id="diff_thr")
  126. xtype = gr.Dropdown(label="X Types ", choices=[x for x in ATYPES], value=ATYPES [2],interactive =True,elem_id="lbw_xtype")
  127. xmen = gr.Textbox(label="X Values ",lines=1,value="0,0.25,0.5,0.75,1",interactive =True,elem_id="lbw_xmen")
  128. ytype = gr.Dropdown(label="Y Types ", choices=[y for y in ATYPES], value=ATYPES [1],interactive =True,elem_id="lbw_ytype")
  129. ymen = gr.Textbox(label="Y Values " ,lines=1,value="IN05-OUT05",interactive =True,elem_id="lbw_ymen")
  130. ztype = gr.Dropdown(label="Z type ", choices=[z for z in ATYPES], value=ATYPES[0],interactive =True,elem_id="lbw_ztype")
  131. zmen = gr.Textbox(label="Z values ",lines=1,value="",interactive =True,elem_id="lbw_zmen")
  132. exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False)
  133. eymen = gr.Textbox(label="Blocks" ,lines=1,value="BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)
  134. ecount = gr.Number(value=1, label="number of seed", interactive=True, visible = True)
  135. with gr.Accordion("Weights setting",open = True):
  136. with gr.Row():
  137. reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload")
  138. reloadtags = gr.Button(value="Reload Tags",variant='primary',elem_id="lbw_reload")
  139. savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext")
  140. openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor")
  141. lbw_loraratios = gr.TextArea(label="",value=lbwpresets,visible =True,interactive = True,elem_id="lbw_ratiospreset")
  142. with gr.Accordion("Elemental",open = False):
  143. with gr.Row():
  144. e_reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload")
  145. e_savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext")
  146. e_openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor")
  147. elemsets = gr.Checkbox(value = False,label="print change",interactive =True,elem_id="lbw_print_change")
  148. elemental = gr.TextArea(label="Identifer:BlockID:Elements:Ratio,...,separated by empty line ",value = elempresets,interactive =True,elem_id="element")
  149. d_true = gr.Checkbox(value = True,visible = False)
  150. d_false = gr.Checkbox(value = False,visible = False)
  151. import subprocess
  152. def openeditors(b):
  153. path = filepath if b else filepathe
  154. subprocess.Popen(['start', path], shell=True)
  155. def reloadpresets(isweight):
  156. if isweight:
  157. try:
  158. with open(filepath,encoding="utf-8") as f:
  159. return f.read()
  160. except OSError as e:
  161. pass
  162. else:
  163. try:
  164. with open(filepathe,encoding="utf-8") as f:
  165. return f.read()
  166. except OSError as e:
  167. pass
  168. def tagdicter(presets):
  169. presets=presets.splitlines()
  170. wdict={}
  171. for l in presets:
  172. if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25) : continue
  173. w=[]
  174. if ":" in l :
  175. key = l.split(":",1)[0]
  176. w = l.split(":",1)[1]
  177. if len([w for w in w.split(",")]) == 17 or len([w for w in w.split(",")]) ==26:
  178. wdict[key.strip()]=w
  179. return ",".join(list(wdict.keys()))
  180. def savepresets(text,isweight):
  181. if isweight:
  182. with open(filepath,mode = 'w',encoding="utf-8") as f:
  183. f.write(text)
  184. else:
  185. with open(filepathe,mode = 'w',encoding="utf-8") as f:
  186. f.write(text)
  187. reloadtext.click(fn=reloadpresets,inputs=[d_true],outputs=[lbw_loraratios])
  188. reloadtags.click(fn=tagdicter,inputs=[lbw_loraratios],outputs=[bw_ratiotags])
  189. savetext.click(fn=savepresets,inputs=[lbw_loraratios,d_true],outputs=[])
  190. openeditor.click(fn=openeditors,inputs=[d_true],outputs=[])
  191. e_reloadtext.click(fn=reloadpresets,inputs=[d_false],outputs=[elemental])
  192. e_savetext.click(fn=savepresets,inputs=[elemental,d_false],outputs=[])
  193. e_openeditor.click(fn=openeditors,inputs=[d_false],outputs=[])
  194. def urawaza(active):
  195. if active > 0:
  196. for obj in scripts.scripts_txt2img.alwayson_scripts:
  197. if "lora_block_weight" in obj.filename:
  198. scripts.scripts_txt2img.selectable_scripts.append(obj)
  199. scripts.scripts_txt2img.titles.append("LoRA Block Weight")
  200. for obj in scripts.scripts_img2img.alwayson_scripts:
  201. if "lora_block_weight" in obj.filename:
  202. scripts.scripts_img2img.selectable_scripts.append(obj)
  203. scripts.scripts_img2img.titles.append("LoRA Block Weight")
  204. scripts.scripts_txt2img.run = newrun
  205. scripts.scripts_img2img.run = newrun
  206. if active == 1:return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]]
  207. else:return [*[gr.update(visible = False) for x in range(6)],*[gr.update(visible = True) for x in range(4)]]
  208. else:
  209. scripts.scripts_txt2img.run = runorigin
  210. scripts.scripts_img2img.run = runorigini
  211. return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]]
  212. xyzsetting.change(fn=urawaza,inputs=[xyzsetting],outputs =[xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,esets])
  213. return lbw_loraratios,lbw_useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets
  214. def process(self, p, loraratios,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets):
  215. #print("self =",self,"p =",p,"presets =",loraratios,"useblocks =",useblocks,"xyzsettings =",xyzsetting,"xtype =",xtype,"xmen =",xmen,"ytype =",ytype,"ymen =",ymen,"ztype =",ztype,"zmen =",zmen)
  216. #Note that this does not use the default arg syntax because the default args are supposed to be at the end of the function
  217. if(loraratios == None):
  218. loraratios = DEF_WEIGHT_PRESET
  219. if(useblocks == None):
  220. useblocks = True
  221. if useblocks:
  222. loraratios=loraratios.splitlines()
  223. elemental = elemental.split("\n\n")
  224. lratios={}
  225. elementals={}
  226. for l in loraratios:
  227. if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25) : continue
  228. l0=l.split(":",1)[0]
  229. lratios[l0.strip()]=l.split(":",1)[1]
  230. for e in elemental:
  231. e0=e.split(":",1)[0]
  232. elementals[e0.strip()]=e.split(":",1)[1]
  233. if elemsets : print(xyelem)
  234. if xyzsetting and "XYZ" in p.prompt:
  235. lratios["XYZ"] = lxyz
  236. lratios["ZYX"] = lzyx
  237. if xyelem != "":
  238. if "XYZ" in elementals.keys():
  239. elementals["XYZ"] = elementals["XYZ"] + ","+ xyelem
  240. else:
  241. elementals["XYZ"] = xyelem
  242. self.lratios = lratios
  243. self.elementals = elementals
  244. global princ
  245. princ = elemsets
  246. return
  247. def before_process_batch(self, p, loraratios,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,**kwargs):
  248. if useblocks:
  249. global prompts
  250. prompts = kwargs["prompts"].copy()
  251. def process_batch(self, p, loraratios,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,**kwargs):
  252. if useblocks:
  253. o_prompts = [p.prompt]
  254. for prompt in prompts:
  255. if "<lora" in prompt or "<lyco" in prompt:
  256. o_prompts = prompts.copy()
  257. loradealer(o_prompts ,self.lratios,self.elementals)
  258. def postprocess(self, p, processed, *args):
  259. import lora
  260. lora.loaded_loras.clear()
  261. global lxyz,lzyx,xyelem
  262. lxyz = lzyx = xyelem = ""
  263. gc.collect()
  264. def run(self,p,presets,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets):
  265. if xyzsetting >0:
  266. import lora
  267. loraratios=presets.splitlines()
  268. lratios={}
  269. for l in loraratios:
  270. if ":" not in l or not (l.count(",") == 16 or l.count(",") == 25) : continue
  271. l0=l.split(":",1)[0]
  272. lratios[l0.strip()]=l.split(":",1)[1]
  273. if "XYZ" in p.prompt:
  274. base = lratios["XYZ"] if "XYZ" in lratios.keys() else "1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1"
  275. else: return
  276. if xyzsetting > 1:
  277. xmen,ymen = exmen,eymen
  278. xtype,ytype = "values","ID"
  279. ebase = xmen.split(",")[1]
  280. ebase = [ebase.strip()]*26
  281. base = ",".join(ebase)
  282. ztype = ""
  283. if ecount > 1:
  284. ztype = "seed"
  285. zmen = ",".join([str(random.randrange(4294967294)) for x in range(int(ecount))])
  286. #ATYPES =["none","Block ID","values","seed","Base Weights"]
  287. def dicedealer(am):
  288. for i,a in enumerate(am):
  289. if a =="-1": am[i] = str(random.randrange(4294967294))
  290. print(f"the die was thrown : {am}")
  291. if p.seed == -1: p.seed = str(random.randrange(4294967294))
  292. #print(f"xs:{xmen},ys:{ymen},zs:{zmen}")
  293. def adjuster(a,at):
  294. if "none" in at:a = ""
  295. a = [a.strip() for a in a.split(',')]
  296. if "seed" in at:dicedealer(a)
  297. return a
  298. xs = adjuster(xmen,xtype)
  299. ys = adjuster(ymen,ytype)
  300. zs = adjuster(zmen,ztype)
  301. ids = alpha =seed = ""
  302. p.batch_size = 1
  303. print(f"xs:{xs},ys:{ys},zs:{zs}")
  304. images = []
  305. def weightsdealer(alpha,ids,base):
  306. blockid17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
  307. blockid26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
  308. #print(f"weights from : {base}")
  309. ids = [z.strip() for z in ids.split(' ')]
  310. weights_t = [w.strip() for w in base.split(',')]
  311. blockid = blockid17 if len(weights_t) ==17 else blockid26
  312. if ids[0]!="NOT":
  313. flagger=[False]*len(weights_t)
  314. changer = True
  315. else:
  316. flagger=[True]*len(weights_t)
  317. changer = False
  318. for id in ids:
  319. if id =="NOT":continue
  320. if "-" in id:
  321. it = [it.strip() for it in id.split('-')]
  322. if blockid.index(it[1]) > blockid.index(it[0]):
  323. flagger[blockid.index(it[0]):blockid.index(it[1])+1] = [changer]*(blockid.index(it[1])-blockid.index(it[0])+1)
  324. else:
  325. flagger[blockid.index(it[1]):blockid.index(it[0])+1] = [changer]*(blockid.index(it[0])-blockid.index(it[1])+1)
  326. else:
  327. flagger[blockid.index(id)] =changer
  328. for i,f in enumerate(flagger):
  329. if f:weights_t[i]=alpha
  330. outext = ",".join(weights_t)
  331. #print(f"weights changed: {outext}")
  332. return outext
  333. def xyzdealer(a,at):
  334. nonlocal ids,alpha,p,base,c_base
  335. if "ID" in at:return
  336. if "values" in at:alpha = a
  337. if "seed" in at:
  338. p.seed = int(a)
  339. if "Weights" in at:base =c_base = lratios[a]
  340. if "elements" in at:
  341. global xyelem
  342. xyelem = a
  343. grids = []
  344. images =[]
  345. totalcount = len(xs)*len(ys)*len(zs) if xyzsetting < 2 else len(xs)*len(ys)*len(zs) //2 +1
  346. shared.total_tqdm.updateTotal(totalcount)
  347. xc = yc =zc = 0
  348. state.job_count = totalcount
  349. totalcount = len(xs)*len(ys)*len(zs)
  350. c_base = base
  351. for z in zs:
  352. images = []
  353. yc = 0
  354. xyzdealer(z,ztype)
  355. for y in ys:
  356. xc = 0
  357. xyzdealer(y,ytype)
  358. for x in xs:
  359. xyzdealer(x,xtype)
  360. if "ID" in xtype:
  361. if "values" in ytype:c_base = weightsdealer(y,x,base)
  362. if "values" in ztype:c_base = weightsdealer(z,x,base)
  363. if "ID" in ytype:
  364. if "values" in xtype:c_base = weightsdealer(x,y,base)
  365. if "values" in ztype:c_base = weightsdealer(z,y,base)
  366. if "ID" in ztype:
  367. if "values" in xtype:c_base = weightsdealer(x,z,base)
  368. if "values" in ytype:c_base = weightsdealer(y,z,base)
  369. print(f"X:{xtype}, {x},Y: {ytype},{y}, Z:{ztype},{z}, base:{c_base} ({len(xs)*len(ys)*zc + yc*len(xs) +xc +1}/{totalcount})")
  370. global lxyz,lzyx
  371. lxyz = c_base
  372. cr_base = c_base.split(",")
  373. cr_base_t=[]
  374. for x in cr_base:
  375. if not identifier(x):
  376. cr_base_t.append(str(1-float(x)))
  377. else:
  378. cr_base_t.append(x)
  379. lzyx = ",".join(cr_base_t)
  380. if not(xc == 1 and not (yc ==0 ) and xyzsetting >1):
  381. lora.loaded_loras.clear()
  382. processed:Processed = process_images(p)
  383. images.append(processed.images[0])
  384. xc += 1
  385. yc += 1
  386. zc += 1
  387. origin = loranames(processed.all_prompts) + ", "+ znamer(ztype,z,base)
  388. images,xst,yst = effectivechecker(images,xs.copy(),ys.copy(),diffcol,thresh,revxy) if xyzsetting >1 else (images,xs.copy(),ys.copy())
  389. grids.append(smakegrid(images,xst,yst,origin,p))
  390. processed.images= grids
  391. lora.loaded_loras.clear()
  392. return processed
  393. def identifier(char):
  394. return char[0] in ["R", "U", "X"]
  395. def znamer(at,a,base):
  396. if "ID" in at:return f"Block : {a}"
  397. if "values" in at:return f"value : {a}"
  398. if "seed" in at:return f"seed : {a}"
  399. if "Weights" in at:return f"original weights :\n {base}"
  400. else: return ""
  401. def loranames(all_prompts):
  402. _, extra_network_data = extra_networks.parse_prompts(all_prompts[0:1])
  403. calledloras = extra_network_data["lora"] if "lyco" not in extra_network_data.keys() else extra_network_data["lyco"]
  404. names = ""
  405. for called in calledloras:
  406. if len(called.items) <3:continue
  407. names += called.items[0]
  408. return names
  409. def loradealer(prompts,lratios,elementals):
  410. _, extra_network_data = extra_networks.parse_prompts(prompts)
  411. moduletypes = extra_network_data.keys()
  412. for ltype in moduletypes:
  413. lorans = []
  414. lorars = []
  415. multipliers = []
  416. elements = []
  417. if not (ltype == "lora" or ltype == "lyco") : continue
  418. for called in extra_network_data[ltype]:
  419. if ltype == "lyco":
  420. if len(called.items) > 4 : called.items[2] = called.items[4]
  421. if len(called.items) > 5 : called.items[3] = called.items[5]
  422. if len(called.items) > 4 : called.items = called.items[0:4]
  423. multiple = float(called.items[1])
  424. multipliers.append(multiple)
  425. if len(called.items) <3:
  426. continue
  427. lorans.append(called.items[0])
  428. if called.items[2] in lratios or called.items[2].count(",") ==16 or called.items[2].count(",") ==25:
  429. wei = lratios[called.items[2]] if called.items[2] in lratios else called.items[2]
  430. ratios = [w.strip() for w in wei.split(",")]
  431. for i,r in enumerate(ratios):
  432. if r =="R":
  433. ratios[i] = round(random.random(),3)
  434. elif r == "U":
  435. ratios[i] = round(random.uniform(-0.5,1.5),3)
  436. elif r[0] == "X":
  437. base = called.items[3] if len(called.items) >= 4 else 1
  438. ratios[i] = getinheritedweight(base, r)
  439. else:
  440. ratios[i] = float(r)
  441. print(f"LoRA Block weight ({ltype}): {called.items[0]}: {multiple} x {[x for x in ratios]}")
  442. if len(ratios)==17:
  443. ratios = [ratios[0]] + [1] + ratios[1:3]+ [1] + ratios[3:5]+[1] + ratios[5:7]+[1,1,1] + [ratios[7]] + [1,1,1] + ratios[8:]
  444. lorars.append(ratios)
  445. if len(called.items) > 3:
  446. if called.items[3] in elementals:
  447. elements.append(elementals[called.items[3]])
  448. else:
  449. elements.append(called.items[3])
  450. else:
  451. elements.append("")
  452. if len(lorars) > 0: load_loras_blocks(lorans,lorars,multipliers,elements,ltype)
  453. def isfloat(t):
  454. try:
  455. float(t)
  456. return True
  457. except:
  458. return False
  459. re_inherited_weight = re.compile(r"X([+-])?([\d.]+)?")
  460. def getinheritedweight(weight, offset):
  461. match = re_inherited_weight.search(offset)
  462. if match.group(1) == "+":
  463. return float(weight) + float(match.group(2))
  464. elif match.group(1) == "-":
  465. return float(weight) - float(match.group(2))
  466. else:
  467. return float(weight)
  468. def load_loras_blocks(names, lwei,multipliers,elements = [],ltype = "lora"):
  469. if "lora" == ltype:
  470. print(names,lwei,elements)
  471. import lora
  472. for l, loaded in enumerate(lora.loaded_loras):
  473. for n, name in enumerate(names):
  474. if name == loaded.name:
  475. lbw(lora.loaded_loras[l],lwei[n],elements[n])
  476. lora.loaded_loras[l].name = lora.loaded_loras[l].name +"added_by_lora_block_weight"+ str(random.random())
  477. elif "lyco" == ltype:
  478. import lycoris as lycomo
  479. for l, loaded in enumerate(lycomo.loaded_lycos):
  480. for n, name in enumerate(names):
  481. if name == loaded.name:
  482. lbw(lycomo.loaded_lycos[l],lwei[n],elements[n])
  483. lycomo.loaded_lycos[l].name = lycomo.loaded_lycos[l].name +"added_by_lora_block_weight"+ str(random.random())
  484. def smakegrid(imgs,xs,ys,currentmodel,p):
  485. ver_texts = [[images.GridAnnotation(y)] for y in ys]
  486. hor_texts = [[images.GridAnnotation(x)] for x in xs]
  487. w, h = imgs[0].size
  488. grid = Image.new('RGB', size=(len(xs) * w, len(ys) * h), color='black')
  489. for i, img in enumerate(imgs):
  490. grid.paste(img, box=(i % len(xs) * w, i // len(xs) * h))
  491. grid = images.draw_grid_annotations(grid,w, h, hor_texts, ver_texts)
  492. grid = draw_origin(grid, currentmodel,w*len(xs),h*len(ys),w)
  493. if opts.grid_save:
  494. images.save_image(grid, opts.outdir_txt2img_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=p.seed, grid=True, p=p)
  495. return grid
  496. def draw_origin(grid, text,width,height,width_one):
  497. grid_d= Image.new("RGB", (grid.width,grid.height), "white")
  498. grid_d.paste(grid,(0,0))
  499. def get_font(fontsize):
  500. try:
  501. return ImageFont.truetype(opts.font or Roboto, fontsize)
  502. except Exception:
  503. return ImageFont.truetype(Roboto, fontsize)
  504. d= ImageDraw.Draw(grid_d)
  505. color_active = (0, 0, 0)
  506. fontsize = (width+height)//25
  507. fnt = get_font(fontsize)
  508. if grid.width != width_one:
  509. while d.multiline_textsize(text, font=fnt)[0] > width_one*0.75 and fontsize > 0:
  510. fontsize -=1
  511. fnt = get_font(fontsize)
  512. d.multiline_text((0,0), text, font=fnt, fill=color_active,align="center")
  513. return grid_d
  514. def newrun(p, *args):
  515. script_index = args[0]
  516. if args[0] ==0:
  517. script = None
  518. for obj in scripts.scripts_txt2img.alwayson_scripts:
  519. if "lora_block_weight" in obj.filename:
  520. script = obj
  521. script_args = args[script.args_from:script.args_to]
  522. else:
  523. script = scripts.scripts_txt2img.selectable_scripts[script_index-1]
  524. if script is None:
  525. return None
  526. script_args = args[script.args_from:script.args_to]
  527. processed = script.run(p, *script_args)
  528. shared.total_tqdm.clear()
  529. return processed
  530. def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
  531. diffs = []
  532. outnum =[]
  533. imgs[0],imgs[1] = imgs[1],imgs[0]
  534. im1 = np.array(imgs[0])
  535. for i in range(len(imgs)-1):
  536. im2 = np.array(imgs[i+1])
  537. abs_diff = cv2.absdiff(im2 , im1)
  538. abs_diff_t = cv2.threshold(abs_diff, int(thresh), 255, cv2.THRESH_BINARY)[1]
  539. res = abs_diff_t.astype(np.uint8)
  540. percentage = (np.count_nonzero(res) * 100)/ res.size
  541. if "white" in diffcol: abs_diff = cv2.bitwise_not(abs_diff)
  542. outnum.append(percentage)
  543. abs_diff = Image.fromarray(abs_diff)
  544. diffs.append(abs_diff)
  545. outs = []
  546. for i in range(len(ls)):
  547. ls[i] = ls[i] + "\n Diff : " + str(round(outnum[i],3)) + "%"
  548. if not revxy:
  549. for diff,img in zip(diffs,imgs[1:]):
  550. outs.append(diff)
  551. outs.append(img)
  552. outs.append(imgs[0])
  553. ss = ["diff",ss[0],"source"]
  554. return outs,ss,ls
  555. else:
  556. outs = [imgs[0]]*len(diffs) + imgs[1:]+ diffs
  557. ss = ["source",ss[0],"diff"]
  558. return outs,ls,ss
  559. def lbw(lora,lwei,elemental):
  560. elemental = elemental.split(",")
  561. for key in lora.modules.keys():
  562. ratio = 1
  563. picked = False
  564. errormodules = []
  565. for i,block in enumerate(BLOCKS):
  566. if block in key:
  567. ratio = lwei[i]
  568. picked = True
  569. currentblock = i
  570. if not picked:
  571. errormodules.append(key)
  572. if len(elemental) > 0:
  573. skey = key + BLOCKID[currentblock]
  574. for d in elemental:
  575. if d.count(":") != 2 :continue
  576. dbs,dws,dr = (hyphener(d.split(":")[0]),d.split(":")[1],d.split(":")[2])
  577. dbs,dws = (dbs.split(" "), dws.split(" "))
  578. dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs)
  579. dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws)
  580. flag = dbn
  581. for db in dbs:
  582. if db in skey:
  583. flag = not dbn
  584. if flag:flag = dwn
  585. else:continue
  586. for dw in dws:
  587. if dw in skey:
  588. flag = not dwn
  589. if flag:
  590. dr = float(dr)
  591. if princ :print(dbs,dws,key,dr)
  592. ratio = dr
  593. ltype = type(lora.modules[key]).__name__
  594. set = False
  595. if ltype in LORAANDSOON.keys():
  596. setattr(lora.modules[key],LORAANDSOON[ltype],torch.nn.Parameter(getattr(lora.modules[key],LORAANDSOON[ltype]) * ratio))
  597. #print(ltype)
  598. set = True
  599. else:
  600. if hasattr(lora.modules[key],"up_model"):
  601. lora.modules[key].up_model.weight= torch.nn.Parameter(lora.modules[key].up_model.weight *ratio)
  602. #print("LoRA using LoCON")
  603. set = True
  604. else:
  605. lora.modules[key].up.weight= torch.nn.Parameter(lora.modules[key].up.weight *ratio)
  606. #print("LoRA")
  607. set = True
  608. if not set :
  609. print("unkwon LoRA")
  610. lora.name = lora.name +"added_by_lora_block_weight"+ str(random.random())
  611. if len(errormodules) > 0:
  612. print(errormodules)
  613. return lora
  614. LORAANDSOON = {
  615. "LoraHadaModule" : "w1a",
  616. "LycoHadaModule" : "w1a",
  617. "FullModule" : "weight",
  618. "IA3Module" : "w",
  619. "LoraKronModule" : "w1",
  620. "LycoKronModule" : "w1",
  621. }
  622. def hyphener(t):
  623. t = t.split(" ")
  624. for i,e in enumerate(t):
  625. if "-" in e:
  626. e = e.split("-")
  627. if BLOCKID.index(e[1]) > BLOCKID.index(e[0]):
  628. t[i] = " ".join(BLOCKID[BLOCKID.index(e[0]):BLOCKID.index(e[1])+1])
  629. else:
  630. t[i] = " ".join(BLOCKID[BLOCKID.index(e[1]):BLOCKID.index(e[0])+1])
  631. return " ".join(t)
  632. ELEMPRESETS="\
  633. ATTNDEEPON:IN05-OUT05:attn:1\n\n\
  634. ATTNDEEPOFF:IN05-OUT05:attn:0\n\n\
  635. PROJDEEPOFF:IN05-OUT05:proj:0\n\n\
  636. XYZ:::1"