setup.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #!/usr/bin/env python
  2. from setuptools import find_packages, setup
  3. import os
  4. import subprocess
  5. import sys
  6. import time
  7. from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
  8. from utils.misc import gpu_is_available
  9. version_file = './basicsr/version.py'
  10. def readme():
  11. with open('README.md', encoding='utf-8') as f:
  12. content = f.read()
  13. return content
  14. def get_git_hash():
  15. def _minimal_ext_cmd(cmd):
  16. # construct minimal environment
  17. env = {}
  18. for k in ['SYSTEMROOT', 'PATH', 'HOME']:
  19. v = os.environ.get(k)
  20. if v is not None:
  21. env[k] = v
  22. # LANGUAGE is used on win32
  23. env['LANGUAGE'] = 'C'
  24. env['LANG'] = 'C'
  25. env['LC_ALL'] = 'C'
  26. out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
  27. return out
  28. try:
  29. out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
  30. sha = out.strip().decode('ascii')
  31. except OSError:
  32. sha = 'unknown'
  33. return sha
  34. def get_hash():
  35. if os.path.exists('.git'):
  36. sha = get_git_hash()[:7]
  37. elif os.path.exists(version_file):
  38. try:
  39. from version import __version__
  40. sha = __version__.split('+')[-1]
  41. except ImportError:
  42. raise ImportError('Unable to get git version')
  43. else:
  44. sha = 'unknown'
  45. return sha
  46. def write_version_py():
  47. content = """# GENERATED VERSION FILE
  48. # TIME: {}
  49. __version__ = '{}'
  50. __gitsha__ = '{}'
  51. version_info = ({})
  52. """
  53. sha = get_hash()
  54. with open('./basicsr/VERSION', 'r') as f:
  55. SHORT_VERSION = f.read().strip()
  56. VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
  57. version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
  58. with open(version_file, 'w') as f:
  59. f.write(version_file_str)
  60. def get_version():
  61. with open(version_file, 'r') as f:
  62. exec(compile(f.read(), version_file, 'exec'))
  63. return locals()['__version__']
  64. def make_cuda_ext(name, module, sources, sources_cuda=None):
  65. if sources_cuda is None:
  66. sources_cuda = []
  67. define_macros = []
  68. extra_compile_args = {'cxx': []}
  69. # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
  70. if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1':
  71. define_macros += [('WITH_CUDA', None)]
  72. extension = CUDAExtension
  73. extra_compile_args['nvcc'] = [
  74. '-D__CUDA_NO_HALF_OPERATORS__',
  75. '-D__CUDA_NO_HALF_CONVERSIONS__',
  76. '-D__CUDA_NO_HALF2_OPERATORS__',
  77. ]
  78. sources += sources_cuda
  79. else:
  80. print(f'Compiling {name} without CUDA')
  81. extension = CppExtension
  82. return extension(
  83. name=f'{module}.{name}',
  84. sources=[os.path.join(*module.split('.'), p) for p in sources],
  85. define_macros=define_macros,
  86. extra_compile_args=extra_compile_args)
  87. def get_requirements(filename='requirements.txt'):
  88. with open(os.path.join('.', filename), 'r') as f:
  89. requires = [line.replace('\n', '') for line in f.readlines()]
  90. return requires
  91. if __name__ == '__main__':
  92. if '--cuda_ext' in sys.argv:
  93. ext_modules = [
  94. make_cuda_ext(
  95. name='deform_conv_ext',
  96. module='ops.dcn',
  97. sources=['src/deform_conv_ext.cpp'],
  98. sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
  99. make_cuda_ext(
  100. name='fused_act_ext',
  101. module='ops.fused_act',
  102. sources=['src/fused_bias_act.cpp'],
  103. sources_cuda=['src/fused_bias_act_kernel.cu']),
  104. make_cuda_ext(
  105. name='upfirdn2d_ext',
  106. module='ops.upfirdn2d',
  107. sources=['src/upfirdn2d.cpp'],
  108. sources_cuda=['src/upfirdn2d_kernel.cu']),
  109. ]
  110. sys.argv.remove('--cuda_ext')
  111. else:
  112. ext_modules = []
  113. write_version_py()
  114. setup(
  115. name='basicsr',
  116. version=get_version(),
  117. description='Open Source Image and Video Super-Resolution Toolbox',
  118. long_description=readme(),
  119. long_description_content_type='text/markdown',
  120. author='Xintao Wang',
  121. author_email='xintao.wang@outlook.com',
  122. keywords='computer vision, restoration, super resolution',
  123. url='https://github.com/xinntao/BasicSR',
  124. include_package_data=True,
  125. packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
  126. classifiers=[
  127. 'Development Status :: 4 - Beta',
  128. 'License :: OSI Approved :: Apache Software License',
  129. 'Operating System :: OS Independent',
  130. 'Programming Language :: Python :: 3',
  131. 'Programming Language :: Python :: 3.7',
  132. 'Programming Language :: Python :: 3.8',
  133. ],
  134. license='Apache License 2.0',
  135. setup_requires=['cython', 'numpy'],
  136. install_requires=get_requirements(),
  137. ext_modules=ext_modules,
  138. cmdclass={'build_ext': BuildExtension},
  139. zip_safe=False)