# gen_img_diffusers.py
  1. """
  2. VGG(
  3. (features): Sequential(
  4. (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  5. (1): ReLU(inplace=True)
  6. (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  7. (3): ReLU(inplace=True)
  8. (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  9. (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  10. (6): ReLU(inplace=True)
  11. (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  12. (8): ReLU(inplace=True)
  13. (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  14. (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  15. (11): ReLU(inplace=True)
  16. (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  17. (13): ReLU(inplace=True)
  18. (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  19. (15): ReLU(inplace=True)
  20. (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  21. (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  22. (18): ReLU(inplace=True)
  23. (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  24. (20): ReLU(inplace=True)
  25. (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  26. (22): ReLU(inplace=True)
  27. (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  28. (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  29. (25): ReLU(inplace=True)
  30. (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  31. (27): ReLU(inplace=True)
  32. (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  33. (29): ReLU(inplace=True)
  34. (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  35. )
  36. (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  37. (classifier): Sequential(
  38. (0): Linear(in_features=25088, out_features=4096, bias=True)
  39. (1): ReLU(inplace=True)
  40. (2): Dropout(p=0.5, inplace=False)
  41. (3): Linear(in_features=4096, out_features=4096, bias=True)
  42. (4): ReLU(inplace=True)
  43. (5): Dropout(p=0.5, inplace=False)
  44. (6): Linear(in_features=4096, out_features=1000, bias=True)
  45. )
  46. )
  47. """
import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import inspect
import time
import zipfile
from diffusers.utils import deprecate
from diffusers.configuration_utils import FrozenDict
import argparse
import math
import os
import random
import re

import diffusers
import numpy as np
import torch
import torchvision
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDIMScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    UNet2DConditionModel,
    StableDiffusionPipeline,
)
from einops import rearrange
from torch import einsum
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo

import library.model_util as model_util
import library.train_util as train_util
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this repo

DEFAULT_TOKEN_LENGTH = 75

# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"

# other settings
LATENT_CHANNELS = 4
DOWNSAMPLING_FACTOR = 8

# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336"

# CLIP guided SD related
CLIP_MODEL_PATH = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
FEATURE_EXTRACTOR_SIZE = (224, 224)
FEATURE_EXTRACTOR_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
FEATURE_EXTRACTOR_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]

VGG16_IMAGE_MEAN = [0.485, 0.456, 0.406]
VGG16_IMAGE_STD = [0.229, 0.224, 0.225]
VGG16_INPUT_RESIZE_DIV = 4

# whether to use cutouts when extracting CLIP features: edit the source if you want to use them
NUM_CUTOUTS = 4
USE_CUTOUTS = False
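
# --- illustrative sketch (not part of the original script) ---
# A minimal example, assuming the scheduler constants above are meant to mirror the
# Stable Diffusion training schedule, of how they would map onto a diffusers scheduler.
# The helper name `_example_build_ddim_scheduler` is hypothetical and is never called here.
def _example_build_ddim_scheduler() -> DDIMScheduler:
    return DDIMScheduler(
        num_train_timesteps=SCHEDULER_TIMESTEPS,  # 1000 training timesteps
        beta_start=SCHEDULER_LINEAR_START,  # 0.00085
        beta_end=SCHEDULER_LINEAR_END,  # 0.0120
        beta_schedule=SCHEDLER_SCHEDULE,  # "scaled_linear"
        clip_sample=False,
    )
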
# region module replacement

"""
Module replacement for speed-up
"""

# CrossAttention that uses FlashAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE

# constants
EPSILON = 1e-6


# helper functions
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""
        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""
        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
                dp = einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
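
# --- illustrative sketch (not part of the original script) ---
# A minimal, self-contained example of calling FlashAttentionFunction directly on dummy
# (batch, heads, seq, dim) tensors, the same layout forward_flash_attn below produces with
# rearrange. The function name `_example_flash_attention` is hypothetical and never called.
def _example_flash_attention():
    b, h, n, d = 1, 8, 64, 40
    q = torch.randn(b, h, n, d)
    k = torch.randn(b, h, n, d)
    v = torch.randn(b, h, n, d)
    mask = None  # or a (b, n) bool tensor marking valid key positions
    # non-causal attention with the same bucket sizes used in forward_flash_attn
    out = FlashAttentionFunction.apply(q, k, v, mask, False, 512, 1024)
    return out  # same shape as q: (b, h, n, d)
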
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
    if mem_eff_attn:
        replace_unet_cross_attn_to_memory_efficient()
    elif xformers:
        replace_unet_cross_attn_to_xformers()


def replace_unet_cross_attn_to_memory_efficient():
    print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
    flash_func = FlashAttentionFunction

    def forward_flash_attn(self, x, context=None, mask=None):
        q_bucket_size = 512
        k_bucket_size = 1024

        h = self.heads
        q = self.to_q(x)

        context = context if context is not None else x
        context = context.to(x.dtype)

        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
            context_k, context_v = self.hypernetwork.forward(x, context)
            context_k = context_k.to(x.dtype)
            context_v = context_v.to(x.dtype)
        else:
            context_k = context
            context_v = context

        k = self.to_k(context_k)
        v = self.to_v(context_v)
        del context, x

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, "b h n d -> b n (h d)")

        # diffusers 0.7.0~
        out = self.to_out[0](out)
        out = self.to_out[1](out)
        return out

    diffusers.models.attention.CrossAttention.forward = forward_flash_attn


def replace_unet_cross_attn_to_xformers():
    print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
    try:
        import xformers.ops
    except ImportError:
        raise ImportError("No xformers / xformers does not seem to be installed")

    def forward_xformers(self, x, context=None, mask=None):
        h = self.heads
        q_in = self.to_q(x)

        context = default(context, x)
        context = context.to(x.dtype)

        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
            context_k, context_v = self.hypernetwork.forward(x, context)
            context_k = context_k.to(x.dtype)
            context_v = context_v.to(x.dtype)
        else:
            context_k = context
            context_v = context

        k_in = self.to_k(context_k)
        v_in = self.to_v(context_v)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best available kernel

        out = rearrange(out, "b n h d -> b n (h d)", h=h)

        # diffusers 0.7.0~
        out = self.to_out[0](out)
        out = self.to_out[1](out)
        return out

    diffusers.models.attention.CrossAttention.forward = forward_xformers


# endregion
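
# --- illustrative note (not part of the original script) ---
# Typical usage of the replacements above, assuming an already-loaded diffusers UNet and
# command-line flags similar to the rest of this script (the names below are placeholders):
#
#   unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
#   replace_unet_modules(unet, mem_eff_attn=args.mem_eff_attn, xformers=args.xformers)
#
# Note that the patch is global: it swaps CrossAttention.forward for every diffusers model
# in the process, not just the UNet instance passed in.
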
# region main image generation: copied from lpw_stable_diffusion.py (ASL) and modified
# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# copied and modified because the pipeline cannot easily be used standalone and we add features


class PipelineLike:
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
    weighting in prompt.
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        device,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        clip_skip: int,
        clip_model: CLIPModel,
        clip_guidance_scale: float,
        clip_image_guidance_scale: float,
        vgg16_model: torchvision.models.VGG,
        vgg16_guidance_scale: float,
        vgg16_layer_no: int,
        # safety_checker: StableDiffusionSafetyChecker,
        # feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        self.device = device
        self.clip_skip = clip_skip

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.unet = unet
        self.scheduler = scheduler
        self.safety_checker = None

        # Textual Inversion
        self.token_replacements = {}

        # XTI
        self.token_replacements_XTI = {}

        # CLIP guidance
        self.clip_guidance_scale = clip_guidance_scale
        self.clip_image_guidance_scale = clip_image_guidance_scale
        self.clip_model = clip_model
        self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD)
        self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE)

        # VGG16 guidance
        self.vgg16_guidance_scale = vgg16_guidance_scale
        if self.vgg16_guidance_scale > 0.0:
            return_layers = {f"{vgg16_layer_no}": "feat"}
            self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(
                vgg16_model.features, return_layers=return_layers
            )
            self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)

        # ControlNet
        self.control_nets: List[ControlNetInfo] = []

    # Textual Inversion
    def add_token_replacement(self, target_token_id, rep_token_ids):
        self.token_replacements[target_token_id] = rep_token_ids

    def replace_token(self, tokens, layer=None):
        new_tokens = []
        for token in tokens:
            if token in self.token_replacements:
                replacer_ = self.token_replacements[token]
                if layer:
                    replacer = []
                    for r in replacer_:
                        if r in self.token_replacements_XTI:
                            replacer.append(self.token_replacements_XTI[r][layer])
                else:
                    replacer = replacer_
                new_tokens.extend(replacer)
            else:
                new_tokens.append(token)
        return new_tokens
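
    # --- illustrative note (not part of the original script) ---
    # Example of how the Textual Inversion replacement table is used, with hypothetical ids:
    # after `add_token_replacement(49408, [49408, 49409, 49410])`, a prompt token sequence
    # [320, 49408, 525] is expanded by `replace_token` to [320, 49408, 49409, 49410, 525],
    # i.e. one placeholder token becomes the token ids of a multi-vector embedding.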
    def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
        self.token_replacements_XTI[target_token_id] = rep_token_ids

    def set_control_nets(self, ctrl_nets):
        self.control_nets = ctrl_nets

    # region xformers-related methods: not relevant here because we replace the modules ourselves
    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.
        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.
        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.
        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        # omitted for now because it would require accelerate
        raise NotImplementedError("cpu_offload is omitted.")
        # if is_accelerate_available():
        #     from accelerate import cpu_offload
        # else:
        #     raise ImportError("Please install accelerate via `pip install accelerate`")
        # device = self.device
        # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
        #     if cpu_offloaded_model is not None:
        #         cpu_offload(cpu_offloaded_model, device)

    # endregion
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_scale: float = None,
        strength: float = 0.8,
        # num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        vae_batch_size: float = None,
        return_latents: bool = False,
        # return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: Optional[int] = 1,
        img2img_noise=None,
        clip_prompts=None,
        clip_guide_images=None,
        networks: Optional[List[LoRANetwork]] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            `None` if cancelled by `is_cancelled_callback`,
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        num_images_per_prompt = 1  # fixed

        if isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        regional_network = " AND " in prompt[0]

        vae_batch_size = (
            batch_size
            if vae_batch_size is None
            else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
        )

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
            )

        # get prompt text embeddings

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if not do_classifier_free_guidance and negative_scale is not None:
            print("negative_scale is ignored if guidance scale <= 1.0")
            negative_scale = None

        # get unconditional embeddings for classifier free guidance
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        if not self.token_replacements_XTI:
            text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
                pipe=self,
                prompt=prompt,
                uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                max_embeddings_multiples=max_embeddings_multiples,
                clip_skip=self.clip_skip,
                **kwargs,
            )

            if negative_scale is not None:
                _, real_uncond_embeddings, _ = get_weighted_text_embeddings(
                    pipe=self,
                    prompt=prompt,  # uncond is built to match this prompt's token length, required for prompts over 75 tokens
                    uncond_prompt=[""] * batch_size,
                    max_embeddings_multiples=max_embeddings_multiples,
                    clip_skip=self.clip_skip,
                    **kwargs,
                )

        if self.token_replacements_XTI:
            text_embeddings_concat = []
            for layer in [
                "IN01",
                "IN02",
                "IN04",
                "IN05",
                "IN07",
                "IN08",
                "MID",
                "OUT03",
                "OUT04",
                "OUT05",
                "OUT06",
                "OUT07",
                "OUT08",
                "OUT09",
                "OUT10",
                "OUT11",
            ]:
                text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
                    pipe=self,
                    prompt=prompt,
                    uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                    max_embeddings_multiples=max_embeddings_multiples,
                    clip_skip=self.clip_skip,
                    layer=layer,
                    **kwargs,
                )

                if do_classifier_free_guidance:
                    if negative_scale is None:
                        text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
                    else:
                        text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))
            text_embeddings = torch.stack(text_embeddings_concat)
        else:
            if do_classifier_free_guidance:
                if negative_scale is None:
                    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
                else:
                    text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
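
        # --- illustrative note (not part of the original script) ---
        # At this point text_embeddings is laid out for classifier-free guidance:
        #   [uncond, cond]              when negative_scale is None
        #   [uncond, cond, real_uncond] when negative_scale is set
        # which matches num_latent_input (2 or 3) used below when the latents are repeated
        # and the noise prediction is chunked back apart.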
        # get the embeddings used for CLIP guidance
        if self.clip_guidance_scale > 0:
            clip_text_input = prompt_tokens
            if clip_text_input.shape[1] > self.tokenizer.model_max_length:
                # TODO warn when the prompt exceeds 75 tokens?
                print("trim text input", clip_text_input.shape)
                clip_text_input = torch.cat(
                    [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1
                )
                print("trimmed", clip_text_input.shape)

            for i, clip_prompt in enumerate(clip_prompts):
                if clip_prompt is not None:  # override with clip_prompt if one is given
                    clip_text_input[i] = self.tokenizer(
                        clip_prompt,
                        padding="max_length",
                        max_length=self.tokenizer.model_max_length,
                        truncation=True,
                        return_tensors="pt",
                    ).input_ids.to(self.device)

            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)  # also works with multiple prompts

        if (
            self.clip_image_guidance_scale > 0
            or self.vgg16_guidance_scale > 0
            and clip_guide_images is not None
            or self.control_nets
        ):
            if isinstance(clip_guide_images, PIL.Image.Image):
                clip_guide_images = [clip_guide_images]

            if self.clip_image_guidance_scale > 0:
                clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images]
                clip_guide_images = torch.cat(clip_guide_images, dim=0)

                clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype)
                image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images)
                image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
                if len(image_embeddings_clip) == 1:
                    image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
            elif self.vgg16_guidance_scale > 0:
                size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV)  # resize to 1/4 for now (too small?)
                clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
                clip_guide_images = torch.cat(clip_guide_images, dim=0)

                clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype)
                image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)["feat"]
                if len(image_embeddings_vgg16) == 1:
                    image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
            else:
                # reuse the guide images as ControlNet hints
                # preprocessing is done on the ControlNet side
                pass
        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps, self.device)

        latents_dtype = text_embeddings.dtype
        init_latents_orig = None
        mask = None

        if init_image is None:
            # get the initial random noise unless the user supplied it

            # Unlike in other pipelines, latents need to be generated in the target device
            # for 1-to-1 results reproducibility with the CompVis implementation.
            # However this currently doesn't work in `mps`.
            latents_shape = (
                batch_size * num_images_per_prompt,
                self.unet.in_channels,
                height // 8,
                width // 8,
            )

            if latents is None:
                if self.device.type == "mps":
                    # randn does not exist on mps
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device="cpu",
                        dtype=latents_dtype,
                    ).to(self.device)
                else:
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device=self.device,
                        dtype=latents_dtype,
                    )
            else:
                if latents.shape != latents_shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
                latents = latents.to(self.device)

            timesteps = self.scheduler.timesteps.to(self.device)

            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
        else:
            # image to tensor
            if isinstance(init_image, PIL.Image.Image):
                init_image = [init_image]
            if isinstance(init_image[0], PIL.Image.Image):
                init_image = [preprocess_image(im) for im in init_image]
                init_image = torch.cat(init_image)
            if isinstance(init_image, list):
                init_image = torch.stack(init_image)

            # mask image to tensor
            if mask_image is not None:
                if isinstance(mask_image, PIL.Image.Image):
                    mask_image = [mask_image]
                if isinstance(mask_image[0], PIL.Image.Image):
                    mask_image = torch.cat([preprocess_mask(im) for im in mask_image])  # H*W, 0 for repaint

            # encode the init image into latents and scale the latents
            init_image = init_image.to(device=self.device, dtype=latents_dtype)
            if init_image.size()[-2:] == (height // 8, width // 8):
                init_latents = init_image
            else:
                if vae_batch_size >= batch_size:
                    init_latent_dist = self.vae.encode(init_image).latent_dist
                    init_latents = init_latent_dist.sample(generator=generator)
                else:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    init_latents = []
                    for i in tqdm(range(0, batch_size, vae_batch_size)):
                        init_latent_dist = self.vae.encode(
                            init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)
                        ).latent_dist
                        init_latents.append(init_latent_dist.sample(generator=generator))
                    init_latents = torch.cat(init_latents)

                init_latents = 0.18215 * init_latents

            if len(init_latents) == 1:
                init_latents = init_latents.repeat((batch_size, 1, 1, 1))
            init_latents_orig = init_latents

            # preprocess mask
            if mask_image is not None:
                mask = mask_image.to(device=self.device, dtype=latents_dtype)
                if len(mask) == 1:
                    mask = mask.repeat((batch_size, 1, 1, 1))

                # check sizes
                if not mask.shape == init_latents.shape:
                    raise ValueError("The mask and init_image should be the same size!")

            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            timesteps = self.scheduler.timesteps[-init_timestep]
            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

            # add noise to latents using the timesteps
            latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:].to(self.device)
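
        # --- illustrative note (not part of the original script) ---
        # Worked example for the img2img branch above: with num_inference_steps=50,
        # strength=0.8 and offset=1, init_timestep = 41, so noise for timestep
        # self.scheduler.timesteps[-41] is added and t_start = 10, i.e. the loop below
        # runs only the last 40 of the 50 scheduler steps.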
  838. # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
  839. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
  840. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
  841. # and should be between [0, 1]
  842. accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
  843. extra_step_kwargs = {}
  844. if accepts_eta:
  845. extra_step_kwargs["eta"] = eta
  846. num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
  847. if self.control_nets:
  848. guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
  849. for i, t in enumerate(tqdm(timesteps)):
  850. # expand the latents if we are doing classifier free guidance
  851. latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
  852. latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
  853. # predict the noise residual
  854. if self.control_nets:
  855. if reginonal_network:
  856. num_sub_and_neg_prompts = len(text_embeddings) // batch_size
  857. text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
  858. else:
  859. text_emb_last = text_embeddings
  860. noise_pred = original_control_net.call_unet_and_control_net(
  861. i,
  862. num_latent_input,
  863. self.unet,
  864. self.control_nets,
  865. guided_hints,
  866. i / len(timesteps),
  867. latent_model_input,
  868. t,
  869. text_emb_last,
  870. ).sample
  871. else:
  872. noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
  873. # perform guidance
  874. if do_classifier_free_guidance:
  875. if negative_scale is None:
  876. noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
  877. noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
  878. else:
  879. noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(
  880. num_latent_input
  881. ) # uncond is real uncond
  882. noise_pred = (
  883. noise_pred_uncond
  884. + guidance_scale * (noise_pred_text - noise_pred_uncond)
  885. - negative_scale * (noise_pred_negative - noise_pred_uncond)
  886. )
  887. # perform clip guidance
  888. if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
  889. text_embeddings_for_guidance = (
  890. text_embeddings.chunk(num_latent_input)[1] if do_classifier_free_guidance else text_embeddings
  891. )
  892. if self.clip_guidance_scale > 0:
  893. noise_pred, latents = self.cond_fn(
  894. latents,
  895. t,
  896. i,
  897. text_embeddings_for_guidance,
  898. noise_pred,
  899. text_embeddings_clip,
  900. self.clip_guidance_scale,
  901. NUM_CUTOUTS,
  902. USE_CUTOUTS,
  903. )
  904. if self.clip_image_guidance_scale > 0 and clip_guide_images is not None:
  905. noise_pred, latents = self.cond_fn(
  906. latents,
  907. t,
  908. i,
  909. text_embeddings_for_guidance,
  910. noise_pred,
  911. image_embeddings_clip,
  912. self.clip_image_guidance_scale,
  913. NUM_CUTOUTS,
  914. USE_CUTOUTS,
  915. )
  916. if self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
  917. noise_pred, latents = self.cond_fn_vgg16(
  918. latents, t, i, text_embeddings_for_guidance, noise_pred, image_embeddings_vgg16, self.vgg16_guidance_scale
  919. )
  920. # compute the previous noisy sample x_t -> x_t-1
  921. latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
  922. if mask is not None:
  923. # masking
  924. init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t]))
  925. latents = (init_latents_proper * mask) + (latents * (1 - mask))
  926. # call the callback, if provided
  927. if i % callback_steps == 0:
  928. if callback is not None:
  929. callback(i, t, latents)
  930. if is_cancelled_callback is not None and is_cancelled_callback():
  931. return None
  932. if return_latents:
  933. return (latents, False)
  934. latents = 1 / 0.18215 * latents
  935. if vae_batch_size >= batch_size:
  936. image = self.vae.decode(latents).sample
  937. else:
  938. if torch.cuda.is_available():
  939. torch.cuda.empty_cache()
  940. images = []
  941. for i in tqdm(range(0, batch_size, vae_batch_size)):
  942. images.append(
  943. self.vae.decode(latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample
  944. )
  945. image = torch.cat(images)
  946. image = (image / 2 + 0.5).clamp(0, 1)
947. # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
  948. image = image.cpu().permute(0, 2, 3, 1).float().numpy()
  949. if self.safety_checker is not None:
  950. safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
  951. image, has_nsfw_concept = self.safety_checker(
  952. images=image,
  953. clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
  954. )
  955. else:
  956. has_nsfw_concept = None
  957. if output_type == "pil":
  958. # image = self.numpy_to_pil(image)
  959. image = (image * 255).round().astype("uint8")
  960. image = [Image.fromarray(im) for im in image]
  961. # if not return_dict:
  962. return (image, has_nsfw_concept)
  963. # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
  964. def text2img(
  965. self,
  966. prompt: Union[str, List[str]],
  967. negative_prompt: Optional[Union[str, List[str]]] = None,
  968. height: int = 512,
  969. width: int = 512,
  970. num_inference_steps: int = 50,
  971. guidance_scale: float = 7.5,
  972. num_images_per_prompt: Optional[int] = 1,
  973. eta: float = 0.0,
  974. generator: Optional[torch.Generator] = None,
  975. latents: Optional[torch.FloatTensor] = None,
  976. max_embeddings_multiples: Optional[int] = 3,
  977. output_type: Optional[str] = "pil",
  978. return_dict: bool = True,
  979. callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
  980. callback_steps: Optional[int] = 1,
  981. **kwargs,
  982. ):
  983. r"""
  984. Function for text-to-image generation.
  985. Args:
  986. prompt (`str` or `List[str]`):
  987. The prompt or prompts to guide the image generation.
  988. negative_prompt (`str` or `List[str]`, *optional*):
  989. The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
  990. if `guidance_scale` is less than `1`).
  991. height (`int`, *optional*, defaults to 512):
  992. The height in pixels of the generated image.
  993. width (`int`, *optional*, defaults to 512):
  994. The width in pixels of the generated image.
  995. num_inference_steps (`int`, *optional*, defaults to 50):
  996. The number of denoising steps. More denoising steps usually lead to a higher quality image at the
  997. expense of slower inference.
  998. guidance_scale (`float`, *optional*, defaults to 7.5):
  999. Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
  1000. `guidance_scale` is defined as `w` of equation 2. of [Imagen
  1001. Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1002. 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
  1003. usually at the expense of lower image quality.
  1004. num_images_per_prompt (`int`, *optional*, defaults to 1):
  1005. The number of images to generate per prompt.
  1006. eta (`float`, *optional*, defaults to 0.0):
  1007. Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
  1008. [`schedulers.DDIMScheduler`], will be ignored for others.
  1009. generator (`torch.Generator`, *optional*):
  1010. A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
  1011. deterministic.
  1012. latents (`torch.FloatTensor`, *optional*):
  1013. Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
  1014. generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1015. tensor will be generated by sampling using the supplied random `generator`.
  1016. max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1017. The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
  1018. output_type (`str`, *optional*, defaults to `"pil"`):
1019. The output format of the generated image. Choose between
  1020. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
  1021. return_dict (`bool`, *optional*, defaults to `True`):
  1022. Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
  1023. plain tuple.
  1024. callback (`Callable`, *optional*):
  1025. A function that will be called every `callback_steps` steps during inference. The function will be
  1026. called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
  1027. callback_steps (`int`, *optional*, defaults to 1):
  1028. The frequency at which the `callback` function will be called. If not specified, the callback will be
  1029. called at every step.
  1030. Returns:
  1031. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1032. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
  1033. When returning a tuple, the first element is a list with the generated images, and the second element is a
  1034. list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
  1035. (nsfw) content, according to the `safety_checker`.
  1036. """
  1037. return self.__call__(
  1038. prompt=prompt,
  1039. negative_prompt=negative_prompt,
  1040. height=height,
  1041. width=width,
  1042. num_inference_steps=num_inference_steps,
  1043. guidance_scale=guidance_scale,
  1044. num_images_per_prompt=num_images_per_prompt,
  1045. eta=eta,
  1046. generator=generator,
  1047. latents=latents,
  1048. max_embeddings_multiples=max_embeddings_multiples,
  1049. output_type=output_type,
  1050. return_dict=return_dict,
  1051. callback=callback,
  1052. callback_steps=callback_steps,
  1053. **kwargs,
  1054. )
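# Hypothetical usage sketch (not part of the original script): assumes `pipe` is an already
# constructed PipelineLike, as built in main() below.
#   images, nsfw = pipe.text2img("a painting of a castle", width=512, height=512,
#                                num_inference_steps=30, guidance_scale=7.5)
#   images[0].save("castle.png")
# Note: in this implementation __call__ returns a plain (images, has_nsfw_concept) tuple; the
# StableDiffusionPipelineOutput return path mentioned in the docstring is commented out above.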
  1055. def img2img(
  1056. self,
  1057. init_image: Union[torch.FloatTensor, PIL.Image.Image],
  1058. prompt: Union[str, List[str]],
  1059. negative_prompt: Optional[Union[str, List[str]]] = None,
  1060. strength: float = 0.8,
  1061. num_inference_steps: Optional[int] = 50,
  1062. guidance_scale: Optional[float] = 7.5,
  1063. num_images_per_prompt: Optional[int] = 1,
  1064. eta: Optional[float] = 0.0,
  1065. generator: Optional[torch.Generator] = None,
  1066. max_embeddings_multiples: Optional[int] = 3,
  1067. output_type: Optional[str] = "pil",
  1068. return_dict: bool = True,
  1069. callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
  1070. callback_steps: Optional[int] = 1,
  1071. **kwargs,
  1072. ):
  1073. r"""
  1074. Function for image-to-image generation.
  1075. Args:
  1076. init_image (`torch.FloatTensor` or `PIL.Image.Image`):
  1077. `Image`, or tensor representing an image batch, that will be used as the starting point for the
  1078. process.
  1079. prompt (`str` or `List[str]`):
  1080. The prompt or prompts to guide the image generation.
  1081. negative_prompt (`str` or `List[str]`, *optional*):
  1082. The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
  1083. if `guidance_scale` is less than `1`).
  1084. strength (`float`, *optional*, defaults to 0.8):
  1085. Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
  1086. `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
  1087. number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
  1088. noise will be maximum and the denoising process will run for the full number of iterations specified in
  1089. `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
  1090. num_inference_steps (`int`, *optional*, defaults to 50):
  1091. The number of denoising steps. More denoising steps usually lead to a higher quality image at the
  1092. expense of slower inference. This parameter will be modulated by `strength`.
  1093. guidance_scale (`float`, *optional*, defaults to 7.5):
  1094. Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
  1095. `guidance_scale` is defined as `w` of equation 2. of [Imagen
  1096. Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1097. 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
  1098. usually at the expense of lower image quality.
  1099. num_images_per_prompt (`int`, *optional*, defaults to 1):
  1100. The number of images to generate per prompt.
  1101. eta (`float`, *optional*, defaults to 0.0):
  1102. Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
  1103. [`schedulers.DDIMScheduler`], will be ignored for others.
  1104. generator (`torch.Generator`, *optional*):
  1105. A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
  1106. deterministic.
  1107. max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1108. The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
  1109. output_type (`str`, *optional*, defaults to `"pil"`):
1110. The output format of the generated image. Choose between
  1111. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
  1112. return_dict (`bool`, *optional*, defaults to `True`):
  1113. Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
  1114. plain tuple.
  1115. callback (`Callable`, *optional*):
  1116. A function that will be called every `callback_steps` steps during inference. The function will be
  1117. called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
  1118. callback_steps (`int`, *optional*, defaults to 1):
  1119. The frequency at which the `callback` function will be called. If not specified, the callback will be
  1120. called at every step.
  1121. Returns:
  1122. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1123. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
  1124. When returning a tuple, the first element is a list with the generated images, and the second element is a
  1125. list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
  1126. (nsfw) content, according to the `safety_checker`.
  1127. """
  1128. return self.__call__(
  1129. prompt=prompt,
  1130. negative_prompt=negative_prompt,
  1131. init_image=init_image,
  1132. num_inference_steps=num_inference_steps,
  1133. guidance_scale=guidance_scale,
  1134. strength=strength,
  1135. num_images_per_prompt=num_images_per_prompt,
  1136. eta=eta,
  1137. generator=generator,
  1138. max_embeddings_multiples=max_embeddings_multiples,
  1139. output_type=output_type,
  1140. return_dict=return_dict,
  1141. callback=callback,
  1142. callback_steps=callback_steps,
  1143. **kwargs,
  1144. )
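# Hypothetical usage sketch (assumes `pipe` is a PipelineLike and `init` is a PIL.Image):
#   images, _ = pipe.img2img(init, "same scene, oil painting style", strength=0.6,
#                            num_inference_steps=50, guidance_scale=7.5)
# Lower `strength` stays closer to `init`; strength=1.0 effectively ignores it.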
  1145. def inpaint(
  1146. self,
  1147. init_image: Union[torch.FloatTensor, PIL.Image.Image],
  1148. mask_image: Union[torch.FloatTensor, PIL.Image.Image],
  1149. prompt: Union[str, List[str]],
  1150. negative_prompt: Optional[Union[str, List[str]]] = None,
  1151. strength: float = 0.8,
  1152. num_inference_steps: Optional[int] = 50,
  1153. guidance_scale: Optional[float] = 7.5,
  1154. num_images_per_prompt: Optional[int] = 1,
  1155. eta: Optional[float] = 0.0,
  1156. generator: Optional[torch.Generator] = None,
  1157. max_embeddings_multiples: Optional[int] = 3,
  1158. output_type: Optional[str] = "pil",
  1159. return_dict: bool = True,
  1160. callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
  1161. callback_steps: Optional[int] = 1,
  1162. **kwargs,
  1163. ):
  1164. r"""
1165. Function for inpainting.
  1166. Args:
  1167. init_image (`torch.FloatTensor` or `PIL.Image.Image`):
  1168. `Image`, or tensor representing an image batch, that will be used as the starting point for the
  1169. process. This is the image whose masked region will be inpainted.
  1170. mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
  1171. `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
  1172. replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
  1173. PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
  1174. contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
  1175. prompt (`str` or `List[str]`):
  1176. The prompt or prompts to guide the image generation.
  1177. negative_prompt (`str` or `List[str]`, *optional*):
  1178. The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
  1179. if `guidance_scale` is less than `1`).
  1180. strength (`float`, *optional*, defaults to 0.8):
  1181. Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
  1182. is 1, the denoising process will be run on the masked area for the full number of iterations specified
  1183. in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
  1184. noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
  1185. num_inference_steps (`int`, *optional*, defaults to 50):
  1186. The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
  1187. the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
  1188. guidance_scale (`float`, *optional*, defaults to 7.5):
  1189. Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
  1190. `guidance_scale` is defined as `w` of equation 2. of [Imagen
  1191. Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1192. 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
  1193. usually at the expense of lower image quality.
  1194. num_images_per_prompt (`int`, *optional*, defaults to 1):
  1195. The number of images to generate per prompt.
  1196. eta (`float`, *optional*, defaults to 0.0):
  1197. Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
  1198. [`schedulers.DDIMScheduler`], will be ignored for others.
  1199. generator (`torch.Generator`, *optional*):
  1200. A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
  1201. deterministic.
  1202. max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1203. The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
  1204. output_type (`str`, *optional*, defaults to `"pil"`):
1205. The output format of the generated image. Choose between
  1206. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
  1207. return_dict (`bool`, *optional*, defaults to `True`):
  1208. Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
  1209. plain tuple.
  1210. callback (`Callable`, *optional*):
  1211. A function that will be called every `callback_steps` steps during inference. The function will be
  1212. called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
  1213. callback_steps (`int`, *optional*, defaults to 1):
  1214. The frequency at which the `callback` function will be called. If not specified, the callback will be
  1215. called at every step.
  1216. Returns:
  1217. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1218. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
  1219. When returning a tuple, the first element is a list with the generated images, and the second element is a
  1220. list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
  1221. (nsfw) content, according to the `safety_checker`.
  1222. """
  1223. return self.__call__(
  1224. prompt=prompt,
  1225. negative_prompt=negative_prompt,
  1226. init_image=init_image,
  1227. mask_image=mask_image,
  1228. num_inference_steps=num_inference_steps,
  1229. guidance_scale=guidance_scale,
  1230. strength=strength,
  1231. num_images_per_prompt=num_images_per_prompt,
  1232. eta=eta,
  1233. generator=generator,
  1234. max_embeddings_multiples=max_embeddings_multiples,
  1235. output_type=output_type,
  1236. return_dict=return_dict,
  1237. callback=callback,
  1238. callback_steps=callback_steps,
  1239. **kwargs,
  1240. )
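# Hypothetical usage sketch (assumes `pipe` is a PipelineLike, `init` a PIL image and `mask` a
# same-sized PIL mask where white marks the region to repaint):
#   images, _ = pipe.inpaint(init, mask, "a red sports car", strength=0.75)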
  1241. # CLIP guidance StableDiffusion
  1242. # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py
1243. # split the batch and process the items one at a time
  1244. def cond_fn(
  1245. self,
  1246. latents,
  1247. timestep,
  1248. index,
  1249. text_embeddings,
  1250. noise_pred_original,
  1251. guide_embeddings_clip,
  1252. clip_guidance_scale,
  1253. num_cutouts,
  1254. use_cutouts=True,
  1255. ):
  1256. if len(latents) == 1:
  1257. return self.cond_fn1(
  1258. latents,
  1259. timestep,
  1260. index,
  1261. text_embeddings,
  1262. noise_pred_original,
  1263. guide_embeddings_clip,
  1264. clip_guidance_scale,
  1265. num_cutouts,
  1266. use_cutouts,
  1267. )
  1268. noise_pred = []
  1269. cond_latents = []
  1270. for i in range(len(latents)):
  1271. lat1 = latents[i].unsqueeze(0)
  1272. tem1 = text_embeddings[i].unsqueeze(0)
  1273. npo1 = noise_pred_original[i].unsqueeze(0)
  1274. gem1 = guide_embeddings_clip[i].unsqueeze(0)
  1275. npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts)
  1276. noise_pred.append(npr1)
  1277. cond_latents.append(cla1)
  1278. noise_pred = torch.cat(noise_pred)
  1279. cond_latents = torch.cat(cond_latents)
  1280. return noise_pred, cond_latents
  1281. @torch.enable_grad()
  1282. def cond_fn1(
  1283. self,
  1284. latents,
  1285. timestep,
  1286. index,
  1287. text_embeddings,
  1288. noise_pred_original,
  1289. guide_embeddings_clip,
  1290. clip_guidance_scale,
  1291. num_cutouts,
  1292. use_cutouts=True,
  1293. ):
  1294. latents = latents.detach().requires_grad_()
  1295. if isinstance(self.scheduler, LMSDiscreteScheduler):
  1296. sigma = self.scheduler.sigmas[index]
  1297. # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
  1298. latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
  1299. else:
  1300. latent_model_input = latents
  1301. # predict the noise residual
  1302. noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
  1303. if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
  1304. alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
  1305. beta_prod_t = 1 - alpha_prod_t
  1306. # compute predicted original sample from predicted noise also called
  1307. # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
  1308. pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
  1309. fac = torch.sqrt(beta_prod_t)
  1310. sample = pred_original_sample * (fac) + latents * (1 - fac)
  1311. elif isinstance(self.scheduler, LMSDiscreteScheduler):
  1312. sigma = self.scheduler.sigmas[index]
  1313. sample = latents - sigma * noise_pred
  1314. else:
  1315. raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
  1316. sample = 1 / 0.18215 * sample
  1317. image = self.vae.decode(sample).sample
  1318. image = (image / 2 + 0.5).clamp(0, 1)
  1319. if use_cutouts:
  1320. image = self.make_cutouts(image, num_cutouts)
  1321. else:
  1322. image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image)
  1323. image = self.normalize(image).to(latents.dtype)
  1324. image_embeddings_clip = self.clip_model.get_image_features(image)
  1325. image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
  1326. if use_cutouts:
  1327. dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip)
  1328. dists = dists.view([num_cutouts, sample.shape[0], -1])
  1329. loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
  1330. else:
1331. # not sure this works correctly when the batch size is larger than 1
  1332. loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale
  1333. grads = -torch.autograd.grad(loss, latents)[0]
  1334. if isinstance(self.scheduler, LMSDiscreteScheduler):
  1335. latents = latents.detach() + grads * (sigma**2)
  1336. noise_pred = noise_pred_original
  1337. else:
  1338. noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
  1339. return noise_pred, latents
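# Comment-only sketch of the CLIP guidance step above: the predicted x_0 is decoded with the VAE,
# embedded by the CLIP image encoder, and the spherical distance to the guide embedding forms the
# loss. Its gradient w.r.t. the latents is then either added to the latents directly (K-LMS, scaled
# by sigma**2) or subtracted from the noise prediction (PNDM/DDIM, scaled by sqrt(beta_prod_t)).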
1340. # split the batch and process the items one at a time
  1341. def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale):
  1342. if len(latents) == 1:
  1343. return self.cond_fn_vgg16_b1(
  1344. latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale
  1345. )
  1346. noise_pred = []
  1347. cond_latents = []
  1348. for i in range(len(latents)):
  1349. lat1 = latents[i].unsqueeze(0)
  1350. tem1 = text_embeddings[i].unsqueeze(0)
  1351. npo1 = noise_pred_original[i].unsqueeze(0)
  1352. gem1 = guide_embeddings[i].unsqueeze(0)
  1353. npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale)
  1354. noise_pred.append(npr1)
  1355. cond_latents.append(cla1)
  1356. noise_pred = torch.cat(noise_pred)
  1357. cond_latents = torch.cat(cond_latents)
  1358. return noise_pred, cond_latents
1359. # process a single item
  1360. @torch.enable_grad()
  1361. def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale):
  1362. latents = latents.detach().requires_grad_()
  1363. if isinstance(self.scheduler, LMSDiscreteScheduler):
  1364. sigma = self.scheduler.sigmas[index]
  1365. # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
  1366. latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
  1367. else:
  1368. latent_model_input = latents
  1369. # predict the noise residual
  1370. noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
  1371. if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
  1372. alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
  1373. beta_prod_t = 1 - alpha_prod_t
  1374. # compute predicted original sample from predicted noise also called
  1375. # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
  1376. pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
  1377. fac = torch.sqrt(beta_prod_t)
  1378. sample = pred_original_sample * (fac) + latents * (1 - fac)
  1379. elif isinstance(self.scheduler, LMSDiscreteScheduler):
  1380. sigma = self.scheduler.sigmas[index]
  1381. sample = latents - sigma * noise_pred
  1382. else:
  1383. raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
  1384. sample = 1 / 0.18215 * sample
  1385. image = self.vae.decode(sample).sample
  1386. image = (image / 2 + 0.5).clamp(0, 1)
  1387. image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image)
  1388. image = self.vgg16_normalize(image).to(latents.dtype)
  1389. image_embeddings = self.vgg16_feat_model(image)["feat"]
1390. # not sure this works correctly when the batch size is larger than 1
1391. loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE, since the content loss in style transfer is also MSE
  1392. grads = -torch.autograd.grad(loss, latents)[0]
  1393. if isinstance(self.scheduler, LMSDiscreteScheduler):
  1394. latents = latents.detach() + grads * (sigma**2)
  1395. noise_pred = noise_pred_original
  1396. else:
  1397. noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
  1398. return noise_pred, latents
  1399. class MakeCutouts(torch.nn.Module):
  1400. def __init__(self, cut_size, cut_power=1.0):
  1401. super().__init__()
  1402. self.cut_size = cut_size
  1403. self.cut_power = cut_power
  1404. def forward(self, pixel_values, num_cutouts):
  1405. sideY, sideX = pixel_values.shape[2:4]
  1406. max_size = min(sideX, sideY)
  1407. min_size = min(sideX, sideY, self.cut_size)
  1408. cutouts = []
  1409. for _ in range(num_cutouts):
  1410. size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
  1411. offsetx = torch.randint(0, sideX - size + 1, ())
  1412. offsety = torch.randint(0, sideY - size + 1, ())
  1413. cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
  1414. cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size))
  1415. return torch.cat(cutouts)
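# Hypothetical usage sketch (assumes cut_size matches the CLIP input size, e.g. 224):
#   make_cutouts = MakeCutouts(224)
#   cuts = make_cutouts(torch.rand(1, 3, 512, 512), num_cutouts=4)  # -> shape (4, 3, 224, 224)
# Each cutout is a random square crop, resized to cut_size by adaptive average pooling.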
  1416. def spherical_dist_loss(x, y):
  1417. x = torch.nn.functional.normalize(x, dim=-1)
  1418. y = torch.nn.functional.normalize(y, dim=-1)
  1419. return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
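# Note: this equals half the squared angle between the normalized embeddings,
# 2 * arcsin(||x_hat - y_hat|| / 2) ** 2; e.g. orthogonal unit vectors give (pi/2)**2 / 2 ≈ 1.2337.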
  1420. re_attention = re.compile(
  1421. r"""
  1422. \\\(|
  1423. \\\)|
  1424. \\\[|
  1425. \\]|
  1426. \\\\|
  1427. \\|
  1428. \(|
  1429. \[|
  1430. :([+-]?[.\d]+)\)|
  1431. \)|
  1432. ]|
  1433. [^\\()\[\]:]+|
  1434. :
  1435. """,
  1436. re.X,
  1437. )
  1438. def parse_prompt_attention(text):
  1439. """
  1440. Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
  1441. Accepted tokens are:
  1442. (abc) - increases attention to abc by a multiplier of 1.1
  1443. (abc:3.12) - increases attention to abc by a multiplier of 3.12
  1444. [abc] - decreases attention to abc by a multiplier of 1.1
  1445. \( - literal character '('
  1446. \[ - literal character '['
  1447. \) - literal character ')'
  1448. \] - literal character ']'
  1449. \\ - literal character '\'
  1450. anything else - just text
  1451. >>> parse_prompt_attention('normal text')
  1452. [['normal text', 1.0]]
  1453. >>> parse_prompt_attention('an (important) word')
  1454. [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
  1455. >>> parse_prompt_attention('(unbalanced')
  1456. [['unbalanced', 1.1]]
  1457. >>> parse_prompt_attention('\(literal\]')
  1458. [['(literal]', 1.0]]
  1459. >>> parse_prompt_attention('(unnecessary)(parens)')
  1460. [['unnecessaryparens', 1.1]]
  1461. >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
  1462. [['a ', 1.0],
  1463. ['house', 1.5730000000000004],
  1464. [' ', 1.1],
  1465. ['on', 1.0],
  1466. [' a ', 1.1],
  1467. ['hill', 0.55],
  1468. [', sun, ', 1.1],
  1469. ['sky', 1.4641000000000006],
  1470. ['.', 1.1]]
  1471. """
  1472. res = []
  1473. round_brackets = []
  1474. square_brackets = []
  1475. round_bracket_multiplier = 1.1
  1476. square_bracket_multiplier = 1 / 1.1
  1477. def multiply_range(start_position, multiplier):
  1478. for p in range(start_position, len(res)):
  1479. res[p][1] *= multiplier
  1480. for m in re_attention.finditer(text):
  1481. text = m.group(0)
  1482. weight = m.group(1)
  1483. if text.startswith("\\"):
  1484. res.append([text[1:], 1.0])
  1485. elif text == "(":
  1486. round_brackets.append(len(res))
  1487. elif text == "[":
  1488. square_brackets.append(len(res))
  1489. elif weight is not None and len(round_brackets) > 0:
  1490. multiply_range(round_brackets.pop(), float(weight))
  1491. elif text == ")" and len(round_brackets) > 0:
  1492. multiply_range(round_brackets.pop(), round_bracket_multiplier)
  1493. elif text == "]" and len(square_brackets) > 0:
  1494. multiply_range(square_brackets.pop(), square_bracket_multiplier)
  1495. else:
  1496. res.append([text, 1.0])
  1497. for pos in round_brackets:
  1498. multiply_range(pos, round_bracket_multiplier)
  1499. for pos in square_brackets:
  1500. multiply_range(pos, square_bracket_multiplier)
  1501. if len(res) == 0:
  1502. res = [["", 1.0]]
  1503. # merge runs of identical weights
  1504. i = 0
  1505. while i + 1 < len(res):
  1506. if res[i][1] == res[i + 1][1]:
  1507. res[i][0] += res[i + 1][0]
  1508. res.pop(i + 1)
  1509. else:
  1510. i += 1
  1511. return res
  1512. def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None):
  1513. r"""
1514. Tokenize a list of prompts and return their tokens together with a weight for each token.
  1515. No padding, starting or ending token is included.
  1516. """
  1517. tokens = []
  1518. weights = []
  1519. truncated = False
  1520. for text in prompt:
  1521. texts_and_weights = parse_prompt_attention(text)
  1522. text_token = []
  1523. text_weight = []
  1524. for word, weight in texts_and_weights:
  1525. # tokenize and discard the starting and the ending token
  1526. token = pipe.tokenizer(word).input_ids[1:-1]
  1527. token = pipe.replace_token(token, layer=layer)
  1528. text_token += token
  1529. # copy the weight by length of token
  1530. text_weight += [weight] * len(token)
  1531. # stop if the text is too long (longer than truncation limit)
  1532. if len(text_token) > max_length:
  1533. truncated = True
  1534. break
  1535. # truncate
  1536. if len(text_token) > max_length:
  1537. truncated = True
  1538. text_token = text_token[:max_length]
  1539. text_weight = text_weight[:max_length]
  1540. tokens.append(text_token)
  1541. weights.append(text_weight)
  1542. if truncated:
  1543. print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
  1544. return tokens, weights
  1545. def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
  1546. r"""
  1547. Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
  1548. """
  1549. max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
  1550. weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
  1551. for i in range(len(tokens)):
  1552. tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
  1553. if no_boseos_middle:
  1554. weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
  1555. else:
  1556. w = []
  1557. if len(weights[i]) == 0:
  1558. w = [1.0] * weights_length
  1559. else:
  1560. for j in range(max_embeddings_multiples):
  1561. w.append(1.0) # weight for starting token in this chunk
  1562. w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
  1563. w.append(1.0) # weight for ending token in this chunk
  1564. w += [1.0] * (weights_length - len(w))
  1565. weights[i] = w[:]
  1566. return tokens, weights
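# Comment-only example of the layout produced above, assuming chunk_length=77 and
# max_embeddings_multiples=2 (so max_length = (77 - 2) * 2 + 2 = 152):
#   tokens[i]  -> [bos, t1, ..., tN, eos, pad, ..., pad]   (length 152)
#   weights[i] -> per-token weights padded with 1.0; when no_boseos_middle is False, a 1.0 slot is
#                 reserved for the bos/eos position of every 77-token chunk.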
  1567. def get_unweighted_text_embeddings(
  1568. pipe: PipelineLike,
  1569. text_input: torch.Tensor,
  1570. chunk_length: int,
  1571. clip_skip: int,
  1572. eos: int,
  1573. pad: int,
  1574. no_boseos_middle: Optional[bool] = True,
  1575. ):
  1576. """
  1577. When the length of tokens is a multiple of the capacity of the text encoder,
  1578. it should be split into chunks and sent to the text encoder individually.
  1579. """
  1580. max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
  1581. if max_embeddings_multiples > 1:
  1582. text_embeddings = []
  1583. for i in range(max_embeddings_multiples):
  1584. # extract the i-th chunk
  1585. text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
  1586. # cover the head and the tail by the starting and the ending tokens
  1587. text_input_chunk[:, 0] = text_input[0, 0]
  1588. if pad == eos: # v1
  1589. text_input_chunk[:, -1] = text_input[0, -1]
  1590. else: # v2
  1591. for j in range(len(text_input_chunk)):
1592. if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # the chunk ends with an ordinary token
1593. text_input_chunk[j, -1] = eos
1594. if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
  1595. text_input_chunk[j, 1] = eos
  1596. if clip_skip is None or clip_skip == 1:
  1597. text_embedding = pipe.text_encoder(text_input_chunk)[0]
  1598. else:
  1599. enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
  1600. text_embedding = enc_out["hidden_states"][-clip_skip]
  1601. text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
  1602. if no_boseos_middle:
  1603. if i == 0:
  1604. # discard the ending token
  1605. text_embedding = text_embedding[:, :-1]
  1606. elif i == max_embeddings_multiples - 1:
  1607. # discard the starting token
  1608. text_embedding = text_embedding[:, 1:]
  1609. else:
  1610. # discard both starting and ending tokens
  1611. text_embedding = text_embedding[:, 1:-1]
  1612. text_embeddings.append(text_embedding)
  1613. text_embeddings = torch.concat(text_embeddings, axis=1)
  1614. else:
  1615. if clip_skip is None or clip_skip == 1:
  1616. text_embeddings = pipe.text_encoder(text_input)[0]
  1617. else:
  1618. enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
  1619. text_embeddings = enc_out["hidden_states"][-clip_skip]
  1620. text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
  1621. return text_embeddings
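# Comment-only sketch: with chunk_length=77 and a 152-token input, the loop above encodes the token
# slices [0:77] and [75:152] separately (re-inserting bos/eos per chunk), then concatenates the chunk
# embeddings along the sequence axis, dropping the duplicated bos/eos when no_boseos_middle is True.
# A clip_skip > 1 takes the hidden state clip_skip layers from the end and re-applies the final layer norm.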
  1622. def get_weighted_text_embeddings(
  1623. pipe: PipelineLike,
  1624. prompt: Union[str, List[str]],
  1625. uncond_prompt: Optional[Union[str, List[str]]] = None,
  1626. max_embeddings_multiples: Optional[int] = 1,
  1627. no_boseos_middle: Optional[bool] = False,
  1628. skip_parsing: Optional[bool] = False,
  1629. skip_weighting: Optional[bool] = False,
  1630. clip_skip=None,
  1631. layer=None,
  1632. **kwargs,
  1633. ):
  1634. r"""
1635. Prompts can be assigned local weights using brackets. For example, the
1636. prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
1637. and the embedding tokens corresponding to those words get multiplied by a constant, 1.1.
1638. Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.
  1639. Args:
  1640. pipe (`DiffusionPipeline`):
  1641. Pipe to provide access to the tokenizer and the text encoder.
  1642. prompt (`str` or `List[str]`):
  1643. The prompt or prompts to guide the image generation.
  1644. uncond_prompt (`str` or `List[str]`):
1645. The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
  1646. is provided, the embeddings of prompt and uncond_prompt are concatenated.
  1647. max_embeddings_multiples (`int`, *optional*, defaults to `1`):
1648. The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
  1649. no_boseos_middle (`bool`, *optional*, defaults to `False`):
1650. If the token length is a multiple of the text encoder's capacity, whether to keep the starting and
1651. ending tokens in each of the middle chunks.
  1652. skip_parsing (`bool`, *optional*, defaults to `False`):
  1653. Skip the parsing of brackets.
  1654. skip_weighting (`bool`, *optional*, defaults to `False`):
1655. Skip the weighting. When parsing is skipped, this is forced to True.
  1656. """
  1657. max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
  1658. if isinstance(prompt, str):
  1659. prompt = [prompt]
  1660. # split the prompts with "AND". each prompt must have the same number of splits
  1661. new_prompts = []
  1662. for p in prompt:
  1663. new_prompts.extend(p.split(" AND "))
  1664. prompt = new_prompts
  1665. if not skip_parsing:
  1666. prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
  1667. if uncond_prompt is not None:
  1668. if isinstance(uncond_prompt, str):
  1669. uncond_prompt = [uncond_prompt]
  1670. uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer)
  1671. else:
  1672. prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
  1673. prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
  1674. if uncond_prompt is not None:
  1675. if isinstance(uncond_prompt, str):
  1676. uncond_prompt = [uncond_prompt]
  1677. uncond_tokens = [
  1678. token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
  1679. ]
  1680. uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
  1681. # round up the longest length of tokens to a multiple of (model_max_length - 2)
  1682. max_length = max([len(token) for token in prompt_tokens])
  1683. if uncond_prompt is not None:
  1684. max_length = max(max_length, max([len(token) for token in uncond_tokens]))
  1685. max_embeddings_multiples = min(
  1686. max_embeddings_multiples,
  1687. (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
  1688. )
  1689. max_embeddings_multiples = max(1, max_embeddings_multiples)
  1690. max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
  1691. # pad the length of tokens and weights
  1692. bos = pipe.tokenizer.bos_token_id
  1693. eos = pipe.tokenizer.eos_token_id
  1694. pad = pipe.tokenizer.pad_token_id
  1695. prompt_tokens, prompt_weights = pad_tokens_and_weights(
  1696. prompt_tokens,
  1697. prompt_weights,
  1698. max_length,
  1699. bos,
  1700. eos,
  1701. pad,
  1702. no_boseos_middle=no_boseos_middle,
  1703. chunk_length=pipe.tokenizer.model_max_length,
  1704. )
  1705. prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
  1706. if uncond_prompt is not None:
  1707. uncond_tokens, uncond_weights = pad_tokens_and_weights(
  1708. uncond_tokens,
  1709. uncond_weights,
  1710. max_length,
  1711. bos,
  1712. eos,
  1713. pad,
  1714. no_boseos_middle=no_boseos_middle,
  1715. chunk_length=pipe.tokenizer.model_max_length,
  1716. )
  1717. uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
  1718. # get the embeddings
  1719. text_embeddings = get_unweighted_text_embeddings(
  1720. pipe,
  1721. prompt_tokens,
  1722. pipe.tokenizer.model_max_length,
  1723. clip_skip,
  1724. eos,
  1725. pad,
  1726. no_boseos_middle=no_boseos_middle,
  1727. )
  1728. prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
  1729. if uncond_prompt is not None:
  1730. uncond_embeddings = get_unweighted_text_embeddings(
  1731. pipe,
  1732. uncond_tokens,
  1733. pipe.tokenizer.model_max_length,
  1734. clip_skip,
  1735. eos,
  1736. pad,
  1737. no_boseos_middle=no_boseos_middle,
  1738. )
  1739. uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
1740. # apply the weights to the prompt embeddings and renormalize so the original mean is preserved
1741. # TODO: should we normalize per chunk or over the whole embedding (current implementation)?
1742. # → over the whole embedding is probably fine
  1743. if (not skip_parsing) and (not skip_weighting):
  1744. previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
  1745. text_embeddings *= prompt_weights.unsqueeze(-1)
  1746. current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
  1747. text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
  1748. if uncond_prompt is not None:
  1749. previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
  1750. uncond_embeddings *= uncond_weights.unsqueeze(-1)
  1751. current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
  1752. uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
  1753. if uncond_prompt is not None:
  1754. return text_embeddings, uncond_embeddings, prompt_tokens
  1755. return text_embeddings, None, prompt_tokens
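# Hypothetical usage sketch (assumes `pipe` is a PipelineLike with tokenizer/text_encoder on pipe.device):
#   cond, uncond, tokens = get_weighted_text_embeddings(
#       pipe, "a (very beautiful:1.2) landscape", uncond_prompt="", max_embeddings_multiples=3
#   )
# cond/uncond have shape (batch, sequence, hidden_dim) and can be fed to the UNet as
# encoder_hidden_states; the bracket weights scale the matching token embeddings and the result is
# rescaled to keep the original mean.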
  1756. def preprocess_guide_image(image):
1757. image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # match cond_fn
  1758. image = np.array(image).astype(np.float32) / 255.0
  1759. image = image[None].transpose(0, 3, 1, 2) # nchw
  1760. image = torch.from_numpy(image)
  1761. return image # 0 to 1
1762. # VGG16 accepts inputs of any size, so resize the guide image as appropriate
1763. def preprocess_vgg16_guide_image(image, size):
1764. image = image.resize(size, resample=Image.NEAREST) # match cond_fn
  1765. image = np.array(image).astype(np.float32) / 255.0
  1766. image = image[None].transpose(0, 3, 1, 2) # nchw
  1767. image = torch.from_numpy(image)
  1768. return image # 0 to 1
  1769. def preprocess_image(image):
  1770. w, h = image.size
  1771. w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
  1772. image = image.resize((w, h), resample=PIL.Image.LANCZOS)
  1773. image = np.array(image).astype(np.float32) / 255.0
  1774. image = image[None].transpose(0, 3, 1, 2)
  1775. image = torch.from_numpy(image)
  1776. return 2.0 * image - 1.0
  1777. def preprocess_mask(mask):
  1778. mask = mask.convert("L")
  1779. w, h = mask.size
  1780. w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
1781. mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # was LANCZOS
  1782. mask = np.array(mask).astype(np.float32) / 255.0
  1783. mask = np.tile(mask, (4, 1, 1))
1784. mask = mask[None].transpose(0, 1, 2, 3) # adds the batch dimension via [None]; the transpose itself is a no-op kept from the original code
  1785. mask = 1 - mask # repaint white, keep black
  1786. mask = torch.from_numpy(mask)
  1787. return mask
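# Comment-only note on the shapes above (assuming a 512x512 input): preprocess_image returns a
# (1, 3, 512, 512) tensor scaled to [-1, 1]; preprocess_mask returns a (1, 4, 64, 64) latent-space
# mask where 1 marks pixels to keep and 0 marks pixels to repaint (white in the input mask).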
  1788. # endregion
  1789. # def load_clip_l14_336(dtype):
  1790. # print(f"loading CLIP: {CLIP_ID_L14_336}")
  1791. # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
  1792. # return text_encoder
  1793. class BatchDataBase(NamedTuple):
1794. # data that does not need to be split across batches
  1795. step: int
  1796. prompt: str
  1797. negative_prompt: str
  1798. seed: int
  1799. init_image: Any
  1800. mask_image: Any
  1801. clip_prompt: str
  1802. guide_image: Any
  1803. class BatchDataExt(NamedTuple):
1804. # data that must be split across batches
  1805. width: int
  1806. height: int
  1807. steps: int
  1808. scale: float
  1809. negative_scale: float
  1810. strength: float
  1811. network_muls: Tuple[float]
  1812. num_sub_prompts: int
  1813. class BatchData(NamedTuple):
  1814. return_latents: bool
  1815. base: BatchDataBase
  1816. ext: BatchDataExt
  1817. def main(args):
  1818. if args.fp16:
  1819. dtype = torch.float16
  1820. elif args.bf16:
  1821. dtype = torch.bfloat16
  1822. else:
  1823. dtype = torch.float32
  1824. highres_fix = args.highres_fix_scale is not None
  1825. assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
  1826. if args.v_parameterization and not args.v2:
  1827. print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
  1828. if args.v2 and args.clip_skip is not None:
  1829. print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
1830. # load the model
1831. if not os.path.isfile(args.ckpt): # if the file does not exist, search with a glob pattern and use the match if there is exactly one
  1832. files = glob.glob(args.ckpt)
  1833. if len(files) == 1:
  1834. args.ckpt = files[0]
  1835. use_stable_diffusion_format = os.path.isfile(args.ckpt)
  1836. if use_stable_diffusion_format:
  1837. print("load StableDiffusion checkpoint")
  1838. text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
  1839. else:
  1840. print("load Diffusers pretrained models")
  1841. loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
  1842. text_encoder = loading_pipe.text_encoder
  1843. vae = loading_pipe.vae
  1844. unet = loading_pipe.unet
  1845. tokenizer = loading_pipe.tokenizer
  1846. del loading_pipe
1847. # load the additional VAE
  1848. if args.vae is not None:
  1849. vae = model_util.load_vae(args.vae, dtype)
  1850. print("additional VAE loaded")
1851. # # load the replacement CLIP
  1852. # if args.replace_clip_l14_336:
  1853. # text_encoder = load_clip_l14_336(dtype)
  1854. # print(f"large clip {CLIP_ID_L14_336} is loaded")
1855. if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale > 0.0:
  1856. print("prepare clip model")
  1857. clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype)
  1858. else:
  1859. clip_model = None
  1860. if args.vgg16_guidance_scale > 0.0:
  1861. print("prepare resnet model")
  1862. vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1)
  1863. else:
  1864. vgg16_model = None
1865. # xformers and Hypernetwork support
  1866. if not args.diffusers_xformers:
  1867. replace_unet_modules(unet, not args.xformers, args.xformers)
1868. # load the tokenizer
  1869. print("loading tokenizer")
  1870. if use_stable_diffusion_format:
  1871. tokenizer = train_util.load_tokenizer(args)
1872. # prepare the scheduler
  1873. sched_init_args = {}
  1874. scheduler_num_noises_per_step = 1
  1875. if args.sampler == "ddim":
  1876. scheduler_cls = DDIMScheduler
  1877. scheduler_module = diffusers.schedulers.scheduling_ddim
  1878. elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
  1879. scheduler_cls = DDPMScheduler
  1880. scheduler_module = diffusers.schedulers.scheduling_ddpm
  1881. elif args.sampler == "pndm":
  1882. scheduler_cls = PNDMScheduler
  1883. scheduler_module = diffusers.schedulers.scheduling_pndm
  1884. elif args.sampler == "lms" or args.sampler == "k_lms":
  1885. scheduler_cls = LMSDiscreteScheduler
  1886. scheduler_module = diffusers.schedulers.scheduling_lms_discrete
  1887. elif args.sampler == "euler" or args.sampler == "k_euler":
  1888. scheduler_cls = EulerDiscreteScheduler
  1889. scheduler_module = diffusers.schedulers.scheduling_euler_discrete
  1890. elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
  1891. scheduler_cls = EulerAncestralDiscreteScheduler
  1892. scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
  1893. elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
  1894. scheduler_cls = DPMSolverMultistepScheduler
  1895. sched_init_args["algorithm_type"] = args.sampler
  1896. scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
  1897. elif args.sampler == "dpmsingle":
  1898. scheduler_cls = DPMSolverSinglestepScheduler
  1899. scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
  1900. elif args.sampler == "heun":
  1901. scheduler_cls = HeunDiscreteScheduler
  1902. scheduler_module = diffusers.schedulers.scheduling_heun_discrete
  1903. elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
  1904. scheduler_cls = KDPM2DiscreteScheduler
  1905. scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
  1906. elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
  1907. scheduler_cls = KDPM2AncestralDiscreteScheduler
  1908. scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
  1909. scheduler_num_noises_per_step = 2
  1910. if args.v_parameterization:
  1911. sched_init_args["prediction_type"] = "v_prediction"
1912. # machinery for supplying the sampler's random noise in advance
  1913. # replace randn
  1914. class NoiseManager:
  1915. def __init__(self):
  1916. self.sampler_noises = None
  1917. self.sampler_noise_index = 0
  1918. def reset_sampler_noises(self, noises):
  1919. self.sampler_noise_index = 0
  1920. self.sampler_noises = noises
  1921. def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
  1922. # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
  1923. if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
  1924. noise = self.sampler_noises[self.sampler_noise_index]
  1925. if shape != noise.shape:
  1926. noise = None
  1927. else:
  1928. noise = None
1929. if noise is None:
  1930. print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
  1931. noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
  1932. self.sampler_noise_index += 1
  1933. return noise
  1934. class TorchRandReplacer:
  1935. def __init__(self, noise_manager):
  1936. self.noise_manager = noise_manager
  1937. def __getattr__(self, item):
  1938. if item == "randn":
  1939. return self.noise_manager.randn
  1940. if hasattr(torch, item):
  1941. return getattr(torch, item)
  1942. raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
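# Comment-only note: scheduler_module.torch is monkey-patched below with TorchRandReplacer, so any
# torch.randn call made inside the scheduler is redirected to NoiseManager.randn and, when
# sampler_noises has been set, returns the pre-generated noise for reproducible per-step sampling.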
  1943. noise_manager = NoiseManager()
  1944. if scheduler_module is not None:
  1945. scheduler_module.torch = TorchRandReplacer(noise_manager)
  1946. scheduler = scheduler_cls(
  1947. num_train_timesteps=SCHEDULER_TIMESTEPS,
  1948. beta_start=SCHEDULER_LINEAR_START,
  1949. beta_end=SCHEDULER_LINEAR_END,
  1950. beta_schedule=SCHEDLER_SCHEDULE,
  1951. **sched_init_args,
  1952. )
1953. # force clip_sample to True
  1954. if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
  1955. print("set clip_sample to True")
  1956. scheduler.config.clip_sample = True
1957. # decide the device
1958. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps" is not considered
1959. # build the pipeline (a copy adapted from the community custom pipeline)
  1960. vae.to(dtype).to(device)
  1961. text_encoder.to(dtype).to(device)
  1962. unet.to(dtype).to(device)
  1963. if clip_model is not None:
  1964. clip_model.to(dtype).to(device)
  1965. if vgg16_model is not None:
  1966. vgg16_model.to(dtype).to(device)
1967. # attach the additional networks
  1968. if args.network_module:
  1969. networks = []
  1970. network_default_muls = []
  1971. for i, network_module in enumerate(args.network_module):
  1972. print("import network module:", network_module)
  1973. imported_module = importlib.import_module(network_module)
  1974. network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
  1975. network_default_muls.append(network_mul)
  1976. net_kwargs = {}
  1977. if args.network_args and i < len(args.network_args):
  1978. network_args = args.network_args[i]
  1979. # TODO escape special chars
  1980. network_args = network_args.split(";")
  1981. for net_arg in network_args:
  1982. key, value = net_arg.split("=")
  1983. net_kwargs[key] = value
  1984. if args.network_weights and i < len(args.network_weights):
  1985. network_weight = args.network_weights[i]
  1986. print("load network weights from:", network_weight)
  1987. if model_util.is_safetensors(network_weight) and args.network_show_meta:
  1988. from safetensors.torch import safe_open
  1989. with safe_open(network_weight, framework="pt") as f:
  1990. metadata = f.metadata()
  1991. if metadata is not None:
  1992. print(f"metadata for: {network_weight}: {metadata}")
  1993. network, weights_sd = imported_module.create_network_from_weights(
  1994. network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
  1995. )
  1996. else:
  1997. raise ValueError("No weight. Weight is required.")
  1998. if network is None:
  1999. return
2000. mergeable = hasattr(network, "merge_to")
2001. if args.network_merge and not mergeable:
2002. print("network is not mergeable. ignoring merge option.")
2003. if not args.network_merge or not mergeable:
2004. network.apply_to(text_encoder, unet)
2005. info = network.load_state_dict(weights_sd, False) # it would be better to use network.load_weights
  2006. print(f"weights are loaded: {info}")
  2007. if args.opt_channels_last:
  2008. network.to(memory_format=torch.channels_last)
  2009. network.to(dtype).to(device)
  2010. networks.append(network)
  2011. else:
  2012. network.merge_to(text_encoder, unet, weights_sd, dtype, device)
  2013. else:
  2014. networks = []
2015. # load the upscaler if one is specified
  2016. upscaler = None
  2017. if args.highres_fix_upscaler:
  2018. print("import upscaler module:", args.highres_fix_upscaler)
  2019. imported_module = importlib.import_module(args.highres_fix_upscaler)
  2020. us_kwargs = {}
  2021. if args.highres_fix_upscaler_args:
  2022. for net_arg in args.highres_fix_upscaler_args.split(";"):
  2023. key, value = net_arg.split("=")
  2024. us_kwargs[key] = value
  2025. print("create upscaler")
  2026. upscaler = imported_module.create_upscaler(**us_kwargs)
  2027. upscaler.to(dtype).to(device)
2028. # ControlNet handling
  2029. control_nets: List[ControlNetInfo] = []
  2030. if args.control_net_models:
  2031. for i, model in enumerate(args.control_net_models):
  2032. prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
  2033. weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
  2034. ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
  2035. ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
  2036. prep = original_control_net.load_preprocess(prep_type)
  2037. control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
  2038. if args.opt_channels_last:
  2039. print(f"set optimizing: channels last")
  2040. text_encoder.to(memory_format=torch.channels_last)
  2041. vae.to(memory_format=torch.channels_last)
  2042. unet.to(memory_format=torch.channels_last)
  2043. if clip_model is not None:
  2044. clip_model.to(memory_format=torch.channels_last)
  2045. if networks:
  2046. for network in networks:
  2047. network.to(memory_format=torch.channels_last)
  2048. if vgg16_model is not None:
  2049. vgg16_model.to(memory_format=torch.channels_last)
  2050. for cn in control_nets:
  2051. cn.unet.to(memory_format=torch.channels_last)
  2052. cn.net.to(memory_format=torch.channels_last)
  2053. pipe = PipelineLike(
  2054. device,
  2055. vae,
  2056. text_encoder,
  2057. tokenizer,
  2058. unet,
  2059. scheduler,
  2060. args.clip_skip,
  2061. clip_model,
  2062. args.clip_guidance_scale,
  2063. args.clip_image_guidance_scale,
  2064. vgg16_model,
  2065. args.vgg16_guidance_scale,
  2066. args.vgg16_guidance_layer,
  2067. )
  2068. pipe.set_control_nets(control_nets)
  2069. print("pipeline is ready.")
  2070. if args.diffusers_xformers:
  2071. pipe.enable_xformers_memory_efficient_attention()
2072. # handle Extended Textual Inversion and Textual Inversion
  2073. if args.XTI_embeddings:
  2074. diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
  2075. diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
  2076. diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
  2077. if args.textual_inversion_embeddings:
  2078. token_ids_embeds = []
  2079. for embeds_file in args.textual_inversion_embeddings:
  2080. if model_util.is_safetensors(embeds_file):
  2081. from safetensors.torch import load_file
  2082. data = load_file(embeds_file)
  2083. else:
  2084. data = torch.load(embeds_file, map_location="cpu")
  2085. if "string_to_param" in data:
  2086. data = data["string_to_param"]
  2087. embeds = next(iter(data.values()))
  2088. if type(embeds) != torch.Tensor:
  2089. raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
  2090. num_vectors_per_token = embeds.size()[0]
  2091. token_string = os.path.splitext(os.path.basename(embeds_file))[0]
  2092. token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
  2093. # add new word to tokenizer, count is num_vectors_per_token
  2094. num_added_tokens = tokenizer.add_tokens(token_strings)
  2095. assert (
  2096. num_added_tokens == num_vectors_per_token
  2097. ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
  2098. token_ids = tokenizer.convert_tokens_to_ids(token_strings)
  2099. print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
  2100. assert (
  2101. min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
  2102. ), f"token ids is not ordered"
  2103. assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
  2104. if num_vectors_per_token > 1:
  2105. pipe.add_token_replacement(token_ids[0], token_ids)
  2106. token_ids_embeds.append((token_ids, embeds))
  2107. text_encoder.resize_token_embeddings(len(tokenizer))
  2108. token_embeds = text_encoder.get_input_embeddings().weight.data
  2109. for token_ids, embeds in token_ids_embeds:
  2110. for token_id, embed in zip(token_ids, embeds):
  2111. token_embeds[token_id] = embed
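    # Extended Textual Inversion (XTI) differs from plain Textual Inversion in that each learned
    # token has a separate embedding for every cross-attention block of the U-Net listed below.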
    if args.XTI_embeddings:
        XTI_layers = [
            "IN01",
            "IN02",
            "IN04",
            "IN05",
            "IN07",
            "IN08",
            "MID",
            "OUT03",
            "OUT04",
            "OUT05",
            "OUT06",
            "OUT07",
            "OUT08",
            "OUT09",
            "OUT10",
            "OUT11",
        ]
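        # each XTI weight file is expected to contain exactly one embedding tensor per layer name above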
        token_ids_embeds_XTI = []
        for embeds_file in args.XTI_embeddings:
            if model_util.is_safetensors(embeds_file):
                from safetensors.torch import load_file

                data = load_file(embeds_file)
            else:
                data = torch.load(embeds_file, map_location="cpu")
            if set(data.keys()) != set(XTI_layers):
                raise ValueError(f"weight file is not in XTI format (layer names mismatch): {embeds_file}")
            embeds = torch.concat(list(data.values()))
            num_vectors_per_token = data["MID"].size()[0]

            token_string = os.path.splitext(os.path.basename(embeds_file))[0]
            token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]

            # add new words to the tokenizer, one per vector (num_vectors_per_token in total)
            num_added_tokens = tokenizer.add_tokens(token_strings)
            assert (
                num_added_tokens == num_vectors_per_token
            ), f"tokenizer already has a token with the same name as the token string (the file name); please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"

            token_ids = tokenizer.convert_tokens_to_ids(token_strings)
            print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")

            # if num_vectors_per_token > 1:
            pipe.add_token_replacement(token_ids[0], token_ids)

            token_strings_XTI = []
            for layer_name in XTI_layers:
                token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
            tokenizer.add_tokens(token_strings_XTI)
            token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
            token_ids_embeds_XTI.append((token_ids_XTI, embeds))
            for t in token_ids:
                t_XTI_dic = {}
                for i, layer_name in enumerate(XTI_layers):
                    t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens
                pipe.add_token_replacement_XTI(t, t_XTI_dic)

        text_encoder.resize_token_embeddings(len(tokenizer))
        token_embeds = text_encoder.get_input_embeddings().weight.data
        for token_ids, embeds in token_ids_embeds_XTI:
            for token_id, embed in zip(token_ids, embeds):
                token_embeds[token_id] = embed
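    # Above, pipe.add_token_replacement_XTI maps each base token id to a dict of per-layer token
    # ids, so the prompt embedding can be swapped per U-Net layer during generation.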
    # read prompts
    if args.from_file is not None:
        print(f"reading prompts from {args.from_file}")
        with open(args.from_file, "r", encoding="utf-8") as f:
            prompt_list = f.read().splitlines()
            prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
    elif args.prompt is not None:
        prompt_list = [args.prompt]
    else:
        prompt_list = []

    if args.interactive:
        args.n_iter = 1
    # img2img preprocessing: load source images, masks, etc.
    def load_images(path):
        if os.path.isfile(path):
            paths = [path]
        else:
            paths = (
                glob.glob(os.path.join(path, "*.png"))
                + glob.glob(os.path.join(path, "*.jpg"))
                + glob.glob(os.path.join(path, "*.jpeg"))
                + glob.glob(os.path.join(path, "*.webp"))
            )
            paths.sort()

        images = []
        for p in paths:
            image = Image.open(p)
            if image.mode != "RGB":
                print(f"convert image to RGB from {image.mode}: {p}")
                image = image.convert("RGB")
            images.append(image)
        return images
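    # resize_images copies the PIL `filename` attribute onto the resized image so that
    # --use_original_file_name can still derive output names from the source files.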
    def resize_images(imgs, size):
        resized = []
        for img in imgs:
            r_img = img.resize(size, Image.Resampling.LANCZOS)
            if hasattr(img, "filename"):  # the filename attribute may be missing
                r_img.filename = img.filename
            resized.append(r_img)
        return resized
    if args.image_path is not None:
        print(f"load image for img2img: {args.image_path}")
        init_images = load_images(args.image_path)
        assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
        print(f"loaded {len(init_images)} images for img2img")
    else:
        init_images = None

    if args.mask_path is not None:
        print(f"load mask for inpainting: {args.mask_path}")
        mask_images = load_images(args.mask_path)
        assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.mask_path}"
        print(f"loaded {len(mask_images)} mask images for inpainting")
    else:
        mask_images = None

    # when no prompt is given, take prompts from the images' PngInfo metadata
    if init_images is not None and len(prompt_list) == 0 and not args.interactive:
        print("get prompts from images' meta data")
        for img in init_images:
            if "prompt" in img.text:
                prompt = img.text["prompt"]
                if "negative-prompt" in img.text:
                    prompt += " --n " + img.text["negative-prompt"]
                prompt_list.append(prompt)

        # repeat each image images_per_prompt times so that prompts and images line up
        l = []
        for im in init_images:
            l.extend([im] * args.images_per_prompt)
        init_images = l

        if mask_images is not None:
            l = []
            for im in mask_images:
                l.extend([im] * args.images_per_prompt)
            mask_images = l
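        # After amplification the metadata prompts and the init/mask images are aligned one-to-one:
        # each source image appears args.images_per_prompt times in a row.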
    # resize images when a size is given on the command line
    if args.W is not None and args.H is not None:
        if init_images is not None:
            print(f"resize img2img source images to {args.W}*{args.H}")
            init_images = resize_images(init_images, (args.W, args.H))
        if mask_images is not None:
            print(f"resize img2img mask images to {args.W}*{args.H}")
            mask_images = resize_images(mask_images, (args.W, args.H))

    regional_network = False
    if networks and mask_images:
        # reuse the mask as region information; currently only one mask per command invocation is supported
        regional_network = True
        print("use mask as region")

        size = None
        for i, network in enumerate(networks):
            if i < 3:
                np_mask = np.array(mask_images[0])
                np_mask = np_mask[:, :, i]
                size = np_mask.shape
            else:
                np_mask = np.full(size, 255, dtype=np.uint8)
            mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
            network.set_region(i, i == len(networks) - 1, mask)
        mask_images = None
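        # The first three networks take the R, G and B channels of the mask as their regions; any
        # further networks get a full (all-255) mask. The mask list is cleared afterwards so it is
        # not also used for inpainting.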
    prev_image = None  # for VGG16 guidance

    if args.guide_image_path is not None:
        print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
        guide_images = []
        for p in args.guide_image_path:
            guide_images.extend(load_images(p))

        print(f"loaded {len(guide_images)} guide images for guidance")
        if len(guide_images) == 0:
            print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
            guide_images = None
    else:
        guide_images = None

    # when a seed is specified, decide all seeds in advance
    if args.seed is not None:
        random.seed(args.seed)
        predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
        if len(predefined_seeds) == 1:
            predefined_seeds[0] = args.seed
    else:
        predefined_seeds = None
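    # predefined_seeds is consumed from the end of the list, images_per_prompt seeds at a time,
    # when building each batch below; for a single image the user-specified seed is used directly.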
    # set the default image size; for img2img these values are ignored (or the images were already resized to W*H)
    if args.W is None:
        args.W = 512
    if args.H is None:
        args.H = 512

    # image generation loop
    os.makedirs(args.outdir, exist_ok=True)
    max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples

    for gen_iter in range(args.n_iter):
        print(f"iteration {gen_iter+1}/{args.n_iter}")
        iter_seed = random.randint(0, 0x7FFFFFFF)

        # batch processing function
        def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
            batch_size = len(batch)

            # highres fix handling
            if highres_fix and not highres_1st:
                # build a 1st-stage batch at a reduced size and run it
                is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling

                print("process 1st stage")
                batch_1st = []
                for _, base, ext in batch:
                    width_1st = int(ext.width * args.highres_fix_scale + 0.5)
                    height_1st = int(ext.height * args.highres_fix_scale + 0.5)
                    width_1st = width_1st - width_1st % 32
                    height_1st = height_1st - height_1st % 32

                    ext_1st = BatchDataExt(
                        width_1st,
                        height_1st,
                        args.highres_fix_steps,
                        ext.scale,
                        ext.negative_scale,
                        ext.strength,
                        ext.network_muls,
                        ext.num_sub_prompts,
                    )
                    batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
                images_1st = process_batch(batch_1st, True, True)
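                # The recursive 1st-stage call returns either latents (when the upscaler works in
                # latent space or --highres_fix_latents_upscaling is set) or PIL images.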
                # build the 2nd-stage batch and process it below
                print("process 2nd stage")
                width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height

                if upscaler:
                    # enlarge the 1st-stage result with the upscaler
                    lowreso_imgs = None if is_1st_latent else images_1st
                    lowreso_latents = None if not is_1st_latent else images_1st

                    # the return value is a list of PIL.Image.Image or latents as a torch.Tensor
                    batch_size = len(images_1st)
                    vae_batch_size = (
                        batch_size
                        if args.vae_batch_size is None
                        else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
                    )
                    vae_batch_size = int(vae_batch_size)
                    images_1st = upscaler.upscale(
                        vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
                    )
                elif args.highres_fix_latents_upscaling:
                    # upscale the latents
                    org_dtype = images_1st.dtype
                    if images_1st.dtype == torch.bfloat16:
                        images_1st = images_1st.to(torch.float)  # interpolate does not support bf16
                    images_1st = torch.nn.functional.interpolate(
                        images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear"
                    )  # , antialias=True)
                    images_1st = images_1st.to(org_dtype)
                else:
                    # upscale the images with LANCZOS
                    images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]

                batch_2nd = []
                for i, (bd, image) in enumerate(zip(batch, images_1st)):
                    bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
                    batch_2nd.append(bd_2nd)
                batch = batch_2nd
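                # From here on the 2nd stage behaves like a normal img2img batch: each entry keeps
                # its prompt, uses the upscaled 1st-stage output as init image and seed + 1.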
            # extract this batch's settings
            (
                return_latents,
                (step_first, _, _, _, init_image, mask_image, _, guide_image),
                (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
            ) = batch[0]
            noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

            prompts = []
            negative_prompts = []
            start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
            noises = [
                torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
                for _ in range(steps * scheduler_num_noises_per_step)
            ]
            seeds = []
            clip_prompts = []

            if init_image is not None:  # img2img?
                i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
                init_images = []

                if mask_image is not None:
                    mask_images = []
                else:
                    mask_images = None
            else:
                i2i_noises = None
                init_images = None
                mask_images = None

            if guide_image is not None:  # CLIP image guided?
                guide_images = []
            else:
                guide_images = None

            # generate the noise here so the same noise is used regardless of the position within
            # the batch; also check whether image/mask are identical across the batch
            all_images_are_same = True
            all_masks_are_same = True
            all_guide_images_are_same = True
            for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
                prompts.append(prompt)
                negative_prompts.append(negative_prompt)
                seeds.append(seed)
                clip_prompts.append(clip_prompt)

                if init_image is not None:
                    init_images.append(init_image)
                    if i > 0 and all_images_are_same:
                        all_images_are_same = init_images[-2] is init_image

                if mask_image is not None:
                    mask_images.append(mask_image)
                    if i > 0 and all_masks_are_same:
                        all_masks_are_same = mask_images[-2] is mask_image

                if guide_image is not None:
                    if type(guide_image) is list:
                        guide_images.extend(guide_image)
                        all_guide_images_are_same = False
                    else:
                        guide_images.append(guide_image)
                        if i > 0 and all_guide_images_are_same:
                            all_guide_images_are_same = guide_images[-2] is guide_image

                # make start code
                torch.manual_seed(seed)
                start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype)

                # make noises for each step
                for j in range(steps * scheduler_num_noises_per_step):
                    noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype)

                if i2i_noises is not None:  # img2img noise
                    i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype)

            noise_manager.reset_sampler_noises(noises)

            # if all images in the batch are identical, pass only one to the pipeline so it can speed up processing
            if init_images is not None and all_images_are_same:
                init_images = init_images[0]
            if mask_images is not None and all_masks_are_same:
                mask_images = mask_images[0]
            if guide_images is not None and all_guide_images_are_same:
                guide_images = guide_images[0]
            # resize guide images when ControlNet is used
            if control_nets:
                # TODO: make the resampling method configurable
                guide_images = guide_images if type(guide_images) == list else [guide_images]
                guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
                if len(guide_images) == 1:
                    guide_images = guide_images[0]

            # generate
            if networks:
                shared = {}
                for n, m in zip(networks, network_muls if network_muls else network_default_muls):
                    n.set_multiplier(m)
                    if regional_network:
                        n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
            images = pipe(
                prompts,
                negative_prompts,
                init_images,
                mask_images,
                height,
                width,
                steps,
                scale,
                negative_scale,
                strength,
                latents=start_code,
                output_type="pil",
                max_embeddings_multiples=max_embeddings_multiples,
                img2img_noise=i2i_noises,
                vae_batch_size=args.vae_batch_size,
                return_latents=return_latents,
                clip_prompts=clip_prompts,
                clip_guide_images=guide_images,
            )[0]
            if highres_1st and not args.highres_fix_save_1st:  # return images or latents
                return images

            # save images
            highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
            ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
            for i, (image, prompt, negative_prompt, seed, clip_prompt) in enumerate(
                zip(images, prompts, negative_prompts, seeds, clip_prompts)
            ):
                metadata = PngInfo()
                metadata.add_text("prompt", prompt)
                metadata.add_text("seed", str(seed))
                metadata.add_text("sampler", args.sampler)
                metadata.add_text("steps", str(steps))
                metadata.add_text("scale", str(scale))
                if negative_prompt is not None:
                    metadata.add_text("negative-prompt", negative_prompt)
                if negative_scale is not None:
                    metadata.add_text("negative-scale", str(negative_scale))
                if clip_prompt is not None:
                    metadata.add_text("clip-prompt", clip_prompt)

                if args.use_original_file_name and init_images is not None:
                    if type(init_images) is list:
                        fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
                    else:
                        fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
                elif args.sequential_file_name:
                    fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
                else:
                    fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"

                image.save(os.path.join(args.outdir, fln), pnginfo=metadata)

            if not args.no_preview and not highres_1st and args.interactive:
                try:
                    import cv2

                    for prompt, image in zip(prompts, images):
                        cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1])  # crashes if the prompt is very long
                        cv2.waitKey()
                        cv2.destroyAllWindows()
                except ImportError:
                    print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")

            return images
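        # Main prompt loop: read prompts (from the list or interactively), parse per-prompt options,
        # build BatchData entries and flush them through process_batch whenever the batch is full or
        # the batch-level settings change.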
        # loop until every prompt has been processed once
        prompt_index = 0
        global_step = 0
        batch_data = []
        while args.interactive or prompt_index < len(prompt_list):
            if len(prompt_list) == 0:
                # interactive
                valid = False
                while not valid:
                    print("\nType prompt:")
                    try:
                        prompt = input()
                    except EOFError:
                        break

                    valid = len(prompt.strip().split(" --")[0].strip()) > 0
                if not valid:  # EOF, end app
                    break
            else:
                prompt = prompt_list[prompt_index]

            # parse prompt
            width = args.W
            height = args.H
            scale = args.scale
            negative_scale = args.negative_scale
            steps = args.steps
            seeds = None
            strength = 0.8 if args.strength is None else args.strength
            negative_prompt = ""
            clip_prompt = None
            network_muls = None

            prompt_args = prompt.strip().split(" --")
            prompt = prompt_args[0]
            print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
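            # Per-prompt options follow " --": w/h (size), s (steps), d (comma-separated seeds),
            # l (guidance scale), nl (negative scale or "none"), t (img2img strength), n (negative
            # prompt), c (CLIP prompt), am (network multipliers). Example:
            #   "a cat --n lowres --d 42 --s 30 --l 7.5"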
            for parg in prompt_args[1:]:
                try:
                    m = re.match(r"w (\d+)", parg, re.IGNORECASE)
                    if m:
                        width = int(m.group(1))
                        print(f"width: {width}")
                        continue

                    m = re.match(r"h (\d+)", parg, re.IGNORECASE)
                    if m:
                        height = int(m.group(1))
                        print(f"height: {height}")
                        continue

                    m = re.match(r"s (\d+)", parg, re.IGNORECASE)
                    if m:  # steps
                        steps = max(1, min(1000, int(m.group(1))))
                        print(f"steps: {steps}")
                        continue

                    m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
                    if m:  # seed
                        seeds = [int(d) for d in m.group(1).split(",")]
                        print(f"seeds: {seeds}")
                        continue

                    m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # scale
                        scale = float(m.group(1))
                        print(f"scale: {scale}")
                        continue

                    m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
                    if m:  # negative scale
                        if m.group(1).lower() == "none":
                            negative_scale = None
                        else:
                            negative_scale = float(m.group(1))
                        print(f"negative scale: {negative_scale}")
                        continue

                    m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # strength
                        strength = float(m.group(1))
                        print(f"strength: {strength}")
                        continue

                    m = re.match(r"n (.+)", parg, re.IGNORECASE)
                    if m:  # negative prompt
                        negative_prompt = m.group(1)
                        print(f"negative prompt: {negative_prompt}")
                        continue

                    m = re.match(r"c (.+)", parg, re.IGNORECASE)
                    if m:  # clip prompt
                        clip_prompt = m.group(1)
                        print(f"clip prompt: {clip_prompt}")
                        continue

                    m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
                    if m:  # network multipliers
                        network_muls = [float(v) for v in m.group(1).split(",")]
                        while len(network_muls) < len(networks):
                            network_muls.append(network_muls[-1])
                        print(f"network mul: {network_muls}")
                        continue

                except ValueError as ex:
                    print(f"Exception in parsing / 解析エラー: {parg}")
                    print(ex)
            if seeds is not None:
                # repeat the seeds if there are not enough of them
                if len(seeds) < args.images_per_prompt:
                    seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds)))
                seeds = seeds[: args.images_per_prompt]
            else:
                if predefined_seeds is not None:
                    seeds = predefined_seeds[-args.images_per_prompt :]
                    predefined_seeds = predefined_seeds[: -args.images_per_prompt]
                elif args.iter_same_seed:
                    seeds = [iter_seed] * args.images_per_prompt
                else:
                    seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)]

            if args.interactive:
                print(f"seed: {seeds}")

            init_image = mask_image = guide_image = None
            for seed in seeds:  # once per image in images_per_prompt
                # when the same image is reused it would be more efficient to convert it to latents
                # once, but for simplicity it is processed every time
                if init_images is not None:
                    init_image = init_images[global_step % len(init_images)]

                    # the image will be resized to a multiple of 32, so follow that here
                    width, height = init_image.size
                    width = width - width % 32
                    height = height - height % 32
                    if width != init_image.size[0] or height != init_image.size[1]:
                        print(
                            f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
                        )

                if mask_images is not None:
                    mask_image = mask_images[global_step % len(mask_images)]

                if guide_images is not None:
                    if control_nets:  # there may be multiple ControlNets
                        c = len(control_nets)
                        p = global_step % (len(guide_images) // c)
                        guide_image = guide_images[p * c : p * c + c]
                    else:
                        guide_image = guide_images[global_step % len(guide_images)]
                elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
                    if prev_image is None:
                        print("Generate 1st image without guide image.")
                    else:
                        print("Use previous image as guide image.")
                        guide_image = prev_image

                if regional_network:
                    num_sub_prompts = len(prompt.split(" AND "))
                    assert (
                        len(networks) <= num_sub_prompts
                    ), "Number of networks must be less than or equal to number of sub prompts."
                else:
                    num_sub_prompts = None
                b1 = BatchData(
                    False,
                    BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
                    BatchDataExt(
                        width,
                        height,
                        steps,
                        scale,
                        negative_scale,
                        strength,
                        tuple(network_muls) if network_muls else None,
                        num_sub_prompts,
                    ),
                )
                if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # does the batch need to be split?
                    process_batch(batch_data, highres_fix)
                    batch_data.clear()

                batch_data.append(b1)
                if len(batch_data) == args.batch_size:
                    prev_image = process_batch(batch_data, highres_fix)[0]
                    batch_data.clear()

                global_step += 1

            prompt_index += 1

        if len(batch_data) > 0:
            process_batch(batch_data, highres_fix)
            batch_data.clear()

    print("done!")
def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
    parser.add_argument(
        "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
    )
    parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
    parser.add_argument(
        "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む"
    )
    parser.add_argument(
        "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)"
    )
    parser.add_argument(
        "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない"
    )
    parser.add_argument(
        "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像"
    )
    parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク")
    parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength")
    parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数")
    parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先")
    parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする")
    parser.add_argument(
        "--use_original_file_name",
        action="store_true",
        help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける",
    )
    # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", )
    parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数")
    parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
    parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
    parser.add_argument(
        "--vae_batch_size",
        type=float,
        default=None,
        help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率",
    )
    parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
    parser.add_argument(
        "--sampler",
        type=str,
        default="ddim",
        choices=[
            "ddim",
            "pndm",
            "lms",
            "euler",
            "euler_a",
            "heun",
            "dpm_2",
            "dpm_2_a",
            "dpmsolver",
            "dpmsolver++",
            "dpmsingle",
            "k_lms",
            "k_euler",
            "k_euler_a",
            "k_dpm_2",
            "k_dpm_2_a",
        ],
        help="sampler (scheduler) type / サンプラー(スケジューラ)の種類",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale",
    )
    parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
    parser.add_argument(
        "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
    )
    parser.add_argument(
        "--tokenizer_cache_dir",
        type=str,
        default=None,
        help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
    )
    # parser.add_argument("--replace_clip_l14_336", action='store_true',
    #                     help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed",
    )
    parser.add_argument(
        "--iter_same_seed",
        action="store_true",
        help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)",
    )
    parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する")
    parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する")
    parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
    parser.add_argument(
        "--diffusers_xformers",
        action="store_true",
        help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)",
    )
    parser.add_argument(
        "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する"
    )
    parser.add_argument(
        "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名"
    )
    parser.add_argument(
        "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
    )
    parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
    parser.add_argument(
        "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
    )
    parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
    parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
    parser.add_argument(
        "--textual_inversion_embeddings",
        type=str,
        default=None,
        nargs="*",
        help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
    )
    parser.add_argument(
        "--XTI_embeddings",
        type=str,
        default=None,
        nargs="*",
        help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
    )
  2805. parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
  2806. parser.add_argument(
  2807. "--max_embeddings_multiples",
  2808. type=int,
  2809. default=None,
  2810. help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
  2811. )
  2812. parser.add_argument(
  2813. "--clip_guidance_scale",
  2814. type=float,
  2815. default=0.0,
  2816. help="enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only) / CLIP guided SDを有効にしてこのscaleを適用する(サンプラーはDDIM、PNDM、LMSのみ)",
  2817. )
  2818. parser.add_argument(
  2819. "--clip_image_guidance_scale",
  2820. type=float,
  2821. default=0.0,
  2822. help="enable CLIP guided SD by image, scale for guidance / 画像によるCLIP guided SDを有効にしてこのscaleを適用する",
  2823. )
  2824. parser.add_argument(
  2825. "--vgg16_guidance_scale",
  2826. type=float,
  2827. default=0.0,
  2828. help="enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する",
  2829. )
  2830. parser.add_argument(
  2831. "--vgg16_guidance_layer",
  2832. type=int,
  2833. default=20,
  2834. help="layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)",
  2835. )
  2836. parser.add_argument(
  2837. "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像"
  2838. )
  2839. parser.add_argument(
  2840. "--highres_fix_scale",
  2841. type=float,
  2842. default=None,
  2843. help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする",
  2844. )
  2845. parser.add_argument(
  2846. "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
  2847. )
  2848. parser.add_argument(
  2849. "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
  2850. )
  2851. parser.add_argument(
  2852. "--highres_fix_latents_upscaling",
  2853. action="store_true",
  2854. help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
  2855. )
  2856. parser.add_argument(
  2857. "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
  2858. )
  2859. parser.add_argument(
  2860. "--highres_fix_upscaler_args",
  2861. type=str,
  2862. default=None,
  2863. help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
  2864. )
  2865. parser.add_argument(
  2866. "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
  2867. )
  2868. parser.add_argument(
  2869. "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
  2870. )
  2871. parser.add_argument(
  2872. "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
  2873. )
  2874. parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み")
  2875. parser.add_argument(
  2876. "--control_net_ratios",
  2877. type=float,
  2878. default=None,
  2879. nargs="*",
  2880. help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
  2881. )
  2882. # parser.add_argument(
  2883. # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
  2884. # )
  2885. return parser
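# Example invocation (model path and prompt are placeholders; all flags are defined above):
#   python gen_img_diffusers.py --ckpt model.safetensors --outdir outputs \
#       --prompt "a cat --n lowres --d 42" --steps 30 --sampler euler_a --xformers --fp16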
if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    main(args)