ui.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. import os
  2. import json
  3. import gradio as gr
  4. from collections import OrderedDict
  5. from pathlib import Path
  6. from glob import glob
  7. from PIL import Image, UnidentifiedImageError
  8. from webui import wrap_gradio_gpu_call
  9. from modules import ui
  10. from modules import generation_parameters_copypaste as parameters_copypaste
  11. from tagger import format, utils
  12. from tagger.utils import split_str
  13. from tagger.interrogator import Interrogator
  14. def unload_interrogators():
  15. unloaded_models = 0
  16. for i in utils.interrogators.values():
  17. if i.unload():
  18. unloaded_models = unloaded_models + 1
  19. return [f'Successfully unload {unloaded_models} model(s)']
  20. def on_interrogate(
  21. image: Image,
  22. batch_input_glob: str,
  23. batch_input_recursive: bool,
  24. batch_output_dir: str,
  25. batch_output_filename_format: str,
  26. batch_output_action_on_conflict: str,
  27. batch_remove_duplicated_tag: bool,
  28. batch_output_save_json: bool,
  29. interrogator: str,
  30. threshold: float,
  31. additional_tags: str,
  32. exclude_tags: str,
  33. sort_by_alphabetical_order: bool,
  34. add_confident_as_weight: bool,
  35. replace_underscore: bool,
  36. replace_underscore_excludes: str,
  37. escape_tag: bool,
  38. unload_model_after_running: bool
  39. ):
  40. if interrogator not in utils.interrogators:
  41. return ['', None, None, f"'{interrogator}' is not a valid interrogator"]
  42. interrogator: Interrogator = utils.interrogators[interrogator]
  43. postprocess_opts = (
  44. threshold,
  45. split_str(additional_tags),
  46. split_str(exclude_tags),
  47. sort_by_alphabetical_order,
  48. add_confident_as_weight,
  49. replace_underscore,
  50. split_str(replace_underscore_excludes),
  51. escape_tag
  52. )
  53. # single process
  54. if image is not None:
  55. ratings, tags = interrogator.interrogate(image)
  56. processed_tags = Interrogator.postprocess_tags(
  57. tags,
  58. *postprocess_opts
  59. )
  60. if unload_model_after_running:
  61. interrogator.unload()
  62. return [
  63. ', '.join(processed_tags),
  64. ratings,
  65. tags,
  66. ''
  67. ]
  68. # batch process
  69. batch_input_glob = batch_input_glob.strip()
  70. batch_output_dir = batch_output_dir.strip()
  71. batch_output_filename_format = batch_output_filename_format.strip()
  72. if batch_input_glob != '':
  73. # if there is no glob pattern, insert it automatically
  74. if not batch_input_glob.endswith('*'):
  75. if not batch_input_glob.endswith(os.sep):
  76. batch_input_glob += os.sep
  77. batch_input_glob += '*'
  78. # get root directory of input glob pattern
  79. base_dir = batch_input_glob.replace('?', '*')
  80. base_dir = base_dir.split(os.sep + '*').pop(0)
  81. # check the input directory path
  82. if not os.path.isdir(base_dir):
  83. return ['', None, None, 'input path is not a directory']
  84. # this line is moved here because some reason
  85. # PIL.Image.registered_extensions() returns only PNG if you call too early
  86. supported_extensions = [
  87. e
  88. for e, f in Image.registered_extensions().items()
  89. if f in Image.OPEN
  90. ]
  91. paths = [
  92. Path(p)
  93. for p in glob(batch_input_glob, recursive=batch_input_recursive)
  94. if '.' + p.split('.').pop().lower() in supported_extensions
  95. ]
  96. print(f'found {len(paths)} image(s)')
  97. for path in paths:
  98. try:
  99. image = Image.open(path)
  100. except UnidentifiedImageError:
  101. # just in case, user has mysterious file...
  102. print(f'${path} is not supported image type')
  103. continue
  104. # guess the output path
  105. base_dir_last = Path(base_dir).parts[-1]
  106. base_dir_last_idx = path.parts.index(base_dir_last)
  107. output_dir = Path(
  108. batch_output_dir) if batch_output_dir else Path(base_dir)
  109. output_dir = output_dir.joinpath(
  110. *path.parts[base_dir_last_idx + 1:]).parent
  111. output_dir.mkdir(0o777, True, True)
  112. # format output filename
  113. format_info = format.Info(path, 'txt')
  114. try:
  115. formatted_output_filename = format.pattern.sub(
  116. lambda m: format.format(m, format_info),
  117. batch_output_filename_format
  118. )
  119. except (TypeError, ValueError) as error:
  120. return ['', None, None, str(error)]
  121. output_path = output_dir.joinpath(
  122. formatted_output_filename
  123. )
  124. output = []
  125. if output_path.is_file():
  126. output.append(output_path.read_text(errors='ignore').strip())
  127. if batch_output_action_on_conflict == 'ignore':
  128. print(f'skipping {path}')
  129. continue
  130. ratings, tags = interrogator.interrogate(image)
  131. processed_tags = Interrogator.postprocess_tags(
  132. tags,
  133. *postprocess_opts
  134. )
  135. # TODO: switch for less print
  136. print(
  137. f'found {len(processed_tags)} tags out of {len(tags)} from {path}'
  138. )
  139. plain_tags = ', '.join(processed_tags)
  140. if batch_output_action_on_conflict == 'copy':
  141. output = [plain_tags]
  142. elif batch_output_action_on_conflict == 'prepend':
  143. output.insert(0, plain_tags)
  144. else:
  145. output.append(plain_tags)
  146. if batch_remove_duplicated_tag:
  147. output_path.write_text(
  148. ', '.join(
  149. OrderedDict.fromkeys(
  150. map(str.strip, ','.join(output).split(','))
  151. )
  152. ),
  153. encoding='utf-8'
  154. )
  155. else:
  156. output_path.write_text(
  157. ', '.join(output),
  158. encoding='utf-8'
  159. )
  160. if batch_output_save_json:
  161. output_path.with_suffix('.json').write_text(
  162. json.dumps([ratings, tags])
  163. )
  164. print('all done :)')
  165. if unload_model_after_running:
  166. interrogator.unload()
  167. return ['', None, None, '']
  168. def on_ui_tabs():
  169. with gr.Blocks(analytics_enabled=False) as tagger_interface:
  170. with gr.Row().style(equal_height=False):
  171. with gr.Column(variant='panel'):
  172. # input components
  173. with gr.Tabs():
  174. with gr.TabItem(label='Single process'):
  175. image = gr.Image(
  176. label='Source',
  177. source='upload',
  178. interactive=True,
  179. type="pil"
  180. )
  181. with gr.TabItem(label='Batch from directory'):
  182. batch_input_glob = utils.preset.component(
  183. gr.Textbox,
  184. label='Input directory',
  185. placeholder='/path/to/images or /path/to/images/**/*'
  186. )
  187. batch_input_recursive = utils.preset.component(
  188. gr.Checkbox,
  189. label='Use recursive with glob pattern'
  190. )
  191. batch_output_dir = utils.preset.component(
  192. gr.Textbox,
  193. label='Output directory',
  194. placeholder='Leave blank to save images to the same path.'
  195. )
  196. batch_output_filename_format = utils.preset.component(
  197. gr.Textbox,
  198. label='Output filename format',
  199. placeholder='Leave blank to use same filename as original.',
  200. value='[name].[output_extension]'
  201. )
  202. import hashlib
  203. with gr.Accordion(
  204. label='Output filename formats',
  205. open=False
  206. ):
  207. gr.Markdown(
  208. value=f'''
  209. ### Related to original file
  210. - `[name]`: Original filename without extension
  211. - `[extension]`: Original extension
  212. - `[hash:<algorithms>]`: Original extension
  213. Available algorithms: `{', '.join(hashlib.algorithms_available)}`
  214. ### Related to output file
  215. - `[output_extension]`: Output extension (has no dot)
  216. ## Examples
  217. ### Original filename without extension
  218. `[name].[output_extension]`
  219. ### Original file's hash (good for deleting duplication)
  220. `[hash:sha1].[output_extension]`
  221. '''
  222. )
  223. batch_output_action_on_conflict = utils.preset.component(
  224. gr.Dropdown,
  225. label='Action on existing caption',
  226. value='ignore',
  227. choices=[
  228. 'ignore',
  229. 'copy',
  230. 'append',
  231. 'prepend'
  232. ]
  233. )
  234. batch_remove_duplicated_tag = utils.preset.component(
  235. gr.Checkbox,
  236. label='Remove duplicated tag'
  237. )
  238. batch_output_save_json = utils.preset.component(
  239. gr.Checkbox,
  240. label='Save with JSON'
  241. )
  242. submit = gr.Button(
  243. value='Interrogate',
  244. variant='primary'
  245. )
  246. info = gr.HTML()
  247. # preset selector
  248. with gr.Row(variant='compact'):
  249. available_presets = utils.preset.list()
  250. selected_preset = gr.Dropdown(
  251. label='Preset',
  252. choices=available_presets,
  253. value=available_presets[0]
  254. )
  255. save_preset_button = gr.Button(
  256. value=ui.save_style_symbol
  257. )
  258. ui.create_refresh_button(
  259. selected_preset,
  260. lambda: None,
  261. lambda: {'choices': utils.preset.list()},
  262. 'refresh_preset'
  263. )
  264. # option components
  265. # interrogator selector
  266. with gr.Column():
  267. with gr.Row(variant='compact'):
  268. interrogator_names = utils.refresh_interrogators()
  269. interrogator = utils.preset.component(
  270. gr.Dropdown,
  271. label='Interrogator',
  272. choices=interrogator_names,
  273. value=(
  274. None
  275. if len(interrogator_names) < 1 else
  276. interrogator_names[-1]
  277. )
  278. )
  279. ui.create_refresh_button(
  280. interrogator,
  281. lambda: None,
  282. lambda: {'choices': utils.refresh_interrogators()},
  283. 'refresh_interrogator'
  284. )
  285. unload_all_models = gr.Button(
  286. value='Unload all interrogate models'
  287. )
  288. threshold = utils.preset.component(
  289. gr.Slider,
  290. label='Threshold',
  291. minimum=0,
  292. maximum=1,
  293. value=0.35
  294. )
  295. additional_tags = utils.preset.component(
  296. gr.Textbox,
  297. label='Additional tags (split by comma)',
  298. elem_id='additioanl-tags'
  299. )
  300. exclude_tags = utils.preset.component(
  301. gr.Textbox,
  302. label='Exclude tags (split by comma)',
  303. elem_id='exclude-tags'
  304. )
  305. sort_by_alphabetical_order = utils.preset.component(
  306. gr.Checkbox,
  307. label='Sort by alphabetical order',
  308. )
  309. add_confident_as_weight = utils.preset.component(
  310. gr.Checkbox,
  311. label='Include confident of tags matches in results'
  312. )
  313. replace_underscore = utils.preset.component(
  314. gr.Checkbox,
  315. label='Use spaces instead of underscore',
  316. value=True
  317. )
  318. replace_underscore_excludes = utils.preset.component(
  319. gr.Textbox,
  320. label='Excudes (split by comma)',
  321. # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400!
  322. value='0_0, (o)_(o), +_+, +_-, ._., <o>_<o>, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||'
  323. )
  324. escape_tag = utils.preset.component(
  325. gr.Checkbox,
  326. label='Escape brackets',
  327. )
  328. unload_model_after_running = utils.preset.component(
  329. gr.Checkbox,
  330. label='Unload model after running',
  331. )
  332. # output components
  333. with gr.Column(variant='panel'):
  334. tags = gr.Textbox(
  335. label='Tags',
  336. placeholder='Found tags',
  337. interactive=False
  338. )
  339. with gr.Row():
  340. parameters_copypaste.bind_buttons(
  341. parameters_copypaste.create_buttons(
  342. ["txt2img", "img2img"],
  343. ),
  344. None,
  345. tags
  346. )
  347. rating_confidents = gr.Label(
  348. label='Rating confidents',
  349. elem_id='rating-confidents'
  350. )
  351. tag_confidents = gr.Label(
  352. label='Tag confidents',
  353. elem_id='tag-confidents'
  354. )
  355. # register events
  356. selected_preset.change(
  357. fn=utils.preset.apply,
  358. inputs=[selected_preset],
  359. outputs=[*utils.preset.components, info]
  360. )
  361. save_preset_button.click(
  362. fn=utils.preset.save,
  363. inputs=[selected_preset, *utils.preset.components], # values only
  364. outputs=[info]
  365. )
  366. unload_all_models.click(
  367. fn=unload_interrogators,
  368. outputs=[info]
  369. )
  370. for func in [image.change, submit.click]:
  371. func(
  372. fn=wrap_gradio_gpu_call(on_interrogate),
  373. inputs=[
  374. # single process
  375. image,
  376. # batch process
  377. batch_input_glob,
  378. batch_input_recursive,
  379. batch_output_dir,
  380. batch_output_filename_format,
  381. batch_output_action_on_conflict,
  382. batch_remove_duplicated_tag,
  383. batch_output_save_json,
  384. # options
  385. interrogator,
  386. threshold,
  387. additional_tags,
  388. exclude_tags,
  389. sort_by_alphabetical_order,
  390. add_confident_as_weight,
  391. replace_underscore,
  392. replace_underscore_excludes,
  393. escape_tag,
  394. unload_model_after_running
  395. ],
  396. outputs=[
  397. tags,
  398. rating_confidents,
  399. tag_confidents,
  400. # contains execution time, memory usage and other stuffs...
  401. # generated from modules.ui.wrap_gradio_call
  402. info
  403. ]
  404. )
  405. return [(tagger_interface, "Tagger", "tagger")]