# sd_hijack_clip_old.py

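"""Old prompt-processing implementation for the CLIP text encoder.

Kept for backward compatibility (a sketch of its role, judging from the
surrounding stable-diffusion-webui codebase): process_text_old walks each
tokenized prompt, applies parenthesis-based emphasis multipliers, splices in
textual-inversion embeddings, and pads everything to CLIP's fixed 77-token
context; forward_old feeds the result to process_tokens.
"""
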
from modules import sd_hijack_clip
from modules import shared


def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    id_start = self.id_start
    id_end = self.id_end
    maxlen = self.wrapped.max_length  # you get to stay at 77
    used_custom_terms = []
    remade_batch_tokens = []
    hijack_comments = []
    hijack_fixes = []
    token_count = 0

    cache = {}
    batch_tokens = self.tokenize(texts)
    batch_multipliers = []
    for tokens in batch_tokens:
        tuple_tokens = tuple(tokens)

        if tuple_tokens in cache:
            remade_tokens, fixes, multipliers = cache[tuple_tokens]
        else:
            fixes = []
            remade_tokens = []
            multipliers = []
            mult = 1.0

            # walk the token stream, tracking the running emphasis multiplier
            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
                if mult_change is not None:
                    # emphasis marker such as "(" or ")": adjust the multiplier
                    mult *= mult_change
                    i += 1
                elif embedding is None:
                    # ordinary token: keep it with the current emphasis weight
                    remade_tokens.append(token)
                    multipliers.append(mult)
                    i += 1
                else:
                    # textual inversion embedding: reserve one slot per vector
                    # and record where the vectors must be patched in later
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [mult] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

            if len(remade_tokens) > maxlen - 2:
                # prompt exceeds the 75 content slots: report what gets cut
                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                ovf = remade_tokens[maxlen - 2:]
                overflowing_words = [vocab.get(int(x), "") for x in ovf]
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

            # pad/truncate to the fixed CLIP context: BOS + 75 tokens + EOS
            token_count = len(remade_tokens)
            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

        # multipliers are cached unpadded, so pad them on every pass
        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        remade_batch_tokens.append(remade_tokens)
        hijack_fixes.append(fixes)
        batch_multipliers.append(multipliers)

    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

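
# The loop above is easiest to see on toy data. The sketch below is not part
# of the original module and has no webui dependencies; the token ids and
# multipliers are made up. It reproduces just the emphasis accumulation:
# "(" and ")" tokens never reach the output, they only scale the weight of
# the tokens that follow them.
def _emphasis_walk_sketch(tokens, token_mults):
    kept, weights, mult = [], [], 1.0
    for token in tokens:
        change = token_mults.get(token)
        if change is not None:
            mult *= change  # "(" multiplies by 1.1, ")" by 1/1.1 in webui
        else:
            kept.append(token)
            weights.append(mult)
    return kept, weights

# Example, with hypothetical ids 10 = "(" and 11 = ")":
#   _emphasis_walk_sketch([10, 5, 6, 11, 7], {10: 1.1, 11: 1 / 1.1})
# returns ([5, 6, 7], [1.1, 1.1, ~1.0]): tokens 5 and 6 sit inside one level
# of parentheses, token 7 is back at (floating-point) weight 1.0.
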
def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)

    # surface warnings and used-embedding info so the UI can display them
    self.hijack.comments += hijack_comments
    if len(used_custom_terms) > 0:
        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

    self.hijack.fixes = hijack_fixes
    return self.process_tokens(remade_batch_tokens, batch_multipliers)
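
# How this module gets used (a hedged note, based on the webui code this file
# ships with): sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase.forward
# falls back to forward_old when the "Use old emphasis implementation" option
# is enabled, along the lines of:
#
#     if opts.use_old_emphasis_implementation:
#         import modules.sd_hijack_clip_old
#         return modules.sd_hijack_clip_old.forward_old(self, texts)
#
# The exact option name and dispatch site may differ between webui versions.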