extract_segmentation.py

import sys, os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.hub
import torchvision
from tqdm import tqdm
from PIL import Image

# download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
# https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
# and put the path here
CKPT_PATH = "TODO"
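
# A sketch of fetching the checkpoint programmatically instead of by hand,
# using torch.hub.download_url_to_file; the destination filename below is
# illustrative, point CKPT_PATH at wherever you save it:
#
#   url = ("https://github.com/kazuto1011/deeplab-pytorch/releases/download/"
#          "v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth")
#   torch.hub.download_url_to_file(url, "deeplabv2_resnet101_msc-cocostuff164k-100000.pth")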

rescale = lambda x: (x + 1.) / 2.


def rescale_bgr(x):
    # map [-1, 1] to [0, 255] and flip channel-first RGB -> BGR
    x = (x + 1) * 127.5
    x = torch.flip(x, dims=[0])
    return x

class COCOStuffSegmenter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_labels = 182  # number of COCO-Stuff classes
        model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101",
                               n_classes=self.n_labels)
        ckpt_path = CKPT_PATH
        model.load_state_dict(torch.load(ckpt_path))
        self.model = model

        normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
        self.image_transform = torchvision.transforms.Compose([
            torchvision.transforms.Lambda(lambda image: torch.stack(
                [normalize(rescale_bgr(x)) for x in image]))
        ])

    def forward(self, x, upsample=None):
        x = self._pre_process(x)
        x = self.model(x)
        if upsample is not None:
            # F.upsample_bilinear is deprecated; interpolate with
            # align_corners=True is its documented equivalent
            x = F.interpolate(x, size=upsample, mode="bilinear", align_corners=True)
        return x

    def _pre_process(self, x):
        x = self.image_transform(x)
        return x

    @property
    def mean(self):
        # bgr
        return [104.008, 116.669, 122.675]

    @property
    def std(self):
        return [1.0, 1.0, 1.0]

    @property
    def input_size(self):
        return [3, 224, 224]
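
# Minimal sketch of using the segmenter on its own (assumes CKPT_PATH is set
# and a CUDA device is available; the random batch is just a stand-in for
# real images scaled to [-1, 1]):
#
#   model = COCOStuffSegmenter({}).cuda().eval()
#   x = torch.rand(1, 3, 256, 256).cuda() * 2 - 1   # fake [-1, 1] image batch
#   with torch.no_grad():
#       logits = model(x, upsample=(256, 256))      # (1, 182, 256, 256)
#       labels = torch.argmax(logits, dim=1)        # (1, 256, 256) class ids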

def run_model(img, model):
    model = model.eval()
    with torch.no_grad():
        segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
        segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
    return segmentation.detach().cpu()

def get_input(batch, k):
    x = batch[k]
    if len(x.shape) == 3:
        x = x[..., None]
    # (b, h, w, c) -> (b, c, h, w)
    x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
    return x.float()

def save_segmentation(segmentation, path):
    # --> class label to uint8, save as png
    os.makedirs(os.path.dirname(path), exist_ok=True)
    assert len(segmentation.shape) == 4
    assert segmentation.shape[0] == 1
    for seg in segmentation:
        seg = seg.permute(1, 2, 0).numpy().squeeze().astype(np.uint8)
        seg = Image.fromarray(seg)
        seg.save(path)
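
# The saved PNGs hold one class id per pixel; a sketch of reading one back
# (the path is illustrative):
#
#   labels = np.array(Image.open("segmentations/example.png"))  # (h, w) uint8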

def iterate_dataset(dataloader, destpath, model):
    os.makedirs(destpath, exist_ok=True)
    num_processed = 0
    for i, batch in tqdm(enumerate(dataloader), desc="Data"):
        try:
            img = get_input(batch, "image")
            img = img.cuda()
            seg = run_model(img, model)
            path = batch["relative_file_path_"][0]
            path = os.path.splitext(path)[0]
            path = os.path.join(destpath, path + ".png")
            save_segmentation(seg, path)
            num_processed += 1
        except Exception as e:
            print(e)
            print("but anyhow..")
    print("Processed {} files. Bye.".format(num_processed))

# dataset-specific imports, needed only for the script entry point below
from taming.data.sflckr import Examples
from torch.utils.data import DataLoader

if __name__ == "__main__":
    dest = sys.argv[1]
    batchsize = 1
    print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
    model = COCOStuffSegmenter({}).cuda()
    print("Instantiated model.")
    dataset = Examples()
    dloader = DataLoader(dataset, batch_size=batchsize)
    iterate_dataset(dataloader=dloader, destpath=dest, model=model)
    print("done.")
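
# Usage: the single command-line argument is the output directory for the
# segmentation masks, e.g. (the directory name is illustrative):
#
#   python extract_segmentation.py segmentation_data/sflckr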