backend_registration.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
  5. from torch.overrides import handle_torch_function, has_torch_function_unary
# Public API of this module; every other definition here is underscore-prefixed (private).
__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
# renamed-backend name for `privateuse1`, but the func will cause an
# error with torch.jit.script, so we use the global variable named
# `_privateuse1_backend_name`.
# Python-side mirror of the PrivateUse1 backend name. It holds "privateuseone"
# until rename_privateuse1_backend() overwrites it with the user-chosen name.
_privateuse1_backend_name = "privateuseone"
  12. def rename_privateuse1_backend(backend_name: str) -> None:
  13. r"""
  14. Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.
  15. The steps are:
  16. (1) (In C++) implement kernels for various torch operations, and register them
  17. to the PrivateUse1 dispatch key.
  18. (2) (In python) call torch.utils.rename_privateuse1_backend("foo")
  19. You can now use "foo" as an ordinary device string in python.
  20. Note: this API can only be called once per process. Attempting to change
  21. the external backend after it's already been set will result in an error.
  22. Note(AMP): If you want to support AMP on your device, you can register a custom backend module.
  23. The backend must register a custom backend module with ``torch._register_device_module("foo", BackendModule)``.
  24. BackendModule needs to have the following API's:
  25. (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
  26. get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.
  27. Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:
  28. (1) ``_is_in_bad_fork() -> bool``
  29. Return ``True`` if now it is in bad_fork, else return ``False``.
  30. (2) ``manual_seed_all(seed int) -> None``
  31. Sets the seed for generating random numbers for your devices.
  32. (3) ``device_count() -> int``
  33. Returns the number of "foo"s available.
  34. (4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor``
  35. Returns a list of ByteTensor representing the random number states of all devices.
  36. (5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None``
  37. Sets the random number generator state of the specified "foo" device.
  38. And there are some common funcs:
  39. (1) ``is_available() -> bool``
  40. Returns a bool indicating if "foo" is currently available.
  41. (2) ``current_device() -> int``
  42. Returns the index of a currently selected device.
  43. For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
  44. For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
  45. Example::
  46. >>> # xdoctest: +SKIP("failing")
  47. >>> torch.utils.rename_privateuse1_backend("foo")
  48. # This will work, assuming that you've implemented the right C++ kernels
  49. # to implement torch.ones.
  50. >>> a = torch.ones(2, device="foo")
  51. """
  52. _rename_privateuse1_backend(backend_name)
  53. global _privateuse1_backend_name
  54. _privateuse1_backend_name = backend_name
  55. def _check_register_once(module, attr):
  56. if hasattr(module, attr):
  57. raise RuntimeError(
  58. f"The custom device module of {module} has already been registered with {attr}"
  59. )
  60. def _normalization_device(
  61. custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None
  62. ) -> int:
  63. def _get_current_device_index():
  64. _get_device_index = "current_device"
  65. if hasattr(torch, custom_backend_name) and hasattr(
  66. getattr(torch, custom_backend_name), _get_device_index
  67. ):
  68. return getattr(getattr(torch, custom_backend_name), _get_device_index)()
  69. else:
  70. # The default device index is 0.
  71. return 0
  72. if device is None:
  73. return _get_current_device_index()
  74. # if isinstance(device, str), this means that the parameter passed in is in the string format "foo:0"
  75. # convert str object to torch.device object, and then process it uniformly
  76. elif isinstance(device, str):
  77. device = torch.device(device)
  78. # variable device can only be torch.device type or int type
  79. if isinstance(device, torch.device):
  80. if device.type != custom_backend_name:
  81. raise RuntimeError(f"Invalid device, must be {custom_backend_name} device")
  82. elif device.index is None:
  83. device_idx = _get_current_device_index()
  84. else:
  85. device_idx = device.index
  86. # if isinstance(device, int), we can take the index number directly
  87. else:
  88. device_idx = device
  89. return device_idx
  90. def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
  91. @property # type: ignore[misc]
  92. def wrap_tensor_backend(self: torch.Tensor) -> bool:
  93. if has_torch_function_unary(self):
  94. # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
  95. return handle_torch_function(wrap_tensor_backend.__get__, (self,), self) # type: ignore[attr-defined]
  96. return self.device.type == custom_backend_name
  97. _check_register_once(torch.Tensor, f"is_{custom_backend_name}")
  98. wrap_tensor_backend.fget.__name__ = f"is_{custom_backend_name}" # type: ignore[attr-defined]
  99. setattr(torch.Tensor, f"is_{custom_backend_name}", wrap_tensor_backend)
  100. def wrap_tensor_to(
  101. self: torch.Tensor,
  102. device: Optional[Union[int, torch.device]] = None,
  103. non_blocking=False,
  104. **kwargs,
  105. ) -> torch.Tensor:
  106. r"""Perform Tensor device conversion. Call the to operator implementation.
  107. .. note::
  108. If the ``self`` Tensor already
  109. has the correct :class:`torch.device`, then ``self`` is returned.
  110. Otherwise, the returned tensor is a copy of ``self`` with the desired :class:`torch.device`.
  111. Args:
  112. device (int, optional): if specified, all parameters will be copied to that device
  113. non_blocking (bool): If ``True`` and the source is in pinned memory,
  114. the copy will be asynchronous with respect to the host. Otherwise,
  115. the argument has no effect.
  116. **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
  117. """
  118. if has_torch_function_unary(self):
  119. return handle_torch_function(
  120. wrap_tensor_to,
  121. (self,),
  122. self,
  123. device=device,
  124. non_blocking=False,
  125. **kwargs,
  126. )
  127. device_idx = _normalization_device(custom_backend_name, device)
  128. return self.to(
  129. device=torch.device(f"{custom_backend_name}:{device_idx}"),
  130. non_blocking=non_blocking,
  131. **kwargs,
  132. )
  133. _check_register_once(torch.Tensor, custom_backend_name)
  134. wrap_tensor_to.__name__ = custom_backend_name
  135. setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)
  136. def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
  137. # Generate Module attributes and methods depends on Tensor methods,
  138. # so we need to check whether Tensor methods is already registered.
  139. if not hasattr(torch.Tensor, custom_backend_name):
  140. raise RuntimeError(
  141. f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module."
  142. f"Because torch.Tensor doesn't has the method {custom_backend_name}()."
  143. f"For this error, you can try setting for_tensor=True."
  144. )
  145. def wrap_module_to(
  146. self: torch.nn.modules.module.T,
  147. device: Optional[Union[int, torch.device]] = None,
  148. ) -> torch.nn.modules.module.T:
  149. r"""Move all model parameters and buffers to the custom device.
  150. This also makes associated parameters and buffers different objects. So
  151. it should be called before constructing optimizer if the module will
  152. live on device while being optimized.
  153. .. note::
  154. This method modifies the module in-place.
  155. Args:
  156. device (int, optional): if specified, all parameters will be copied to that device
  157. """
  158. return self._apply(lambda t: getattr(t, custom_backend_name)(device))
  159. _check_register_once(torch.nn.Module, custom_backend_name)
  160. setattr(torch.nn.Module, custom_backend_name, wrap_module_to)
  161. def _generate_packed_sequence_methods_for_privateuse1_backend(
  162. custom_backend_name: str,
  163. ) -> None:
  164. # Generate PackedSequence Module attributes and methods depends on Tensor methods,
  165. # so we need to check whether Tensor methods is already registered.
  166. if not hasattr(torch.Tensor, f"is_{custom_backend_name}") or not hasattr(
  167. torch.Tensor, custom_backend_name
  168. ):
  169. raise RuntimeError(
  170. f"Can not automatically generate is_{custom_backend_name}() or "
  171. f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence."
  172. f"Because torch.Tensor doesn't has the method is_{custom_backend_name}()"
  173. f"or {custom_backend_name}()."
  174. f"For this error, you can try setting for_tensor=True."
  175. )
  176. @property # type: ignore[misc]
  177. def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool:
  178. return self.data.device.type == custom_backend_name
  179. _check_register_once(torch.nn.utils.rnn.PackedSequence, f"is_{custom_backend_name}")
  180. setattr(
  181. torch.nn.utils.rnn.PackedSequence,
  182. f"is_{custom_backend_name}",
  183. wrap_tensor_backend,
  184. )
  185. def wrap_module_to(
  186. self: torch.nn.utils.rnn.PackedSequence, *args, **kwargs
  187. ) -> torch.nn.utils.rnn.PackedSequence:
  188. r"""Move all model parameters and buffers to the custom device.
  189. This also makes associated parameters and buffers different objects. So
  190. it should be called before constructing optimizer if the module will
  191. live on device while being optimized.
  192. .. note::
  193. This method modifies the module in-place.
  194. Args:
  195. device (int, optional): if specified, all parameters will be copied to that device
  196. """
  197. ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
  198. *args, **kwargs
  199. )
  200. if ex.device.type == custom_backend_name:
  201. return self.to(*args, **kwargs)
  202. kwargs.update({"device": custom_backend_name})
  203. return self.to(*args, **kwargs)
  204. _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)
  205. setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to)
  206. def _generate_storage_methods_for_privateuse1_backend(
  207. custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None
  208. ) -> None:
  209. # Attribute is registered in the _StorageBase class
  210. # and UntypedStorage obtains through inheritance.
  211. @property # type: ignore[misc]
  212. def wrap_storage_backend(self: torch.storage._StorageBase) -> bool:
  213. r"""Return the internal :class:`torch.UntypedStorage`."""
  214. return self.device.type == custom_backend_name
  215. _check_register_once(torch.storage._StorageBase, f"is_{custom_backend_name}")
  216. setattr(
  217. torch.storage._StorageBase, f"is_{custom_backend_name}", wrap_storage_backend
  218. )
  219. def wrap_storage_to(self, device=None, non_blocking=False):
  220. r"""Return a copy of this object in custom device memory.
  221. If this object is already in device memory and on the correct device, then
  222. no copy is performed and the original object is returned.
  223. Args:
  224. device (int): The destination device id. Defaults to the current device.
  225. non_blocking (bool): If ``True`` and the source is in pinned memory,
  226. the copy will be asynchronous with respect to the host. Otherwise,
  227. the argument has no effect.
  228. """
  229. # There should be a judgment related to storage device and a judgment related to storage type,
  230. # but it depends on the extended function, so this part is temporarily omitted in the automatic generation.
  231. device_idx = _normalization_device(custom_backend_name, device)
  232. if getattr(self, f"is_{custom_backend_name}"):
  233. # storage has already on expected device.
  234. if self.get_device() == device_idx:
  235. return self
  236. # For sparse storage, custom need to extend the implementation by themselves.
  237. if self.is_sparse:
  238. raise RuntimeError(
  239. f"Can not support a sparse storage move to {custom_backend_name} backend"
  240. )
  241. # create untyped_storage and copy data
  242. untyped_storage = torch.UntypedStorage(
  243. self.size(), device=torch.device(f"{custom_backend_name}:{device_idx}")
  244. )
  245. untyped_storage.copy_(self, non_blocking)
  246. return untyped_storage
  247. _check_register_once(torch.storage._StorageBase, custom_backend_name)
  248. setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)
  249. # Register the corresponding attribute for the TypedStorage class.
  250. # When the TypedStorage class is removed, the registration is also removed.
  251. @property # type: ignore[misc]
  252. def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool:
  253. torch.storage._warn_typed_storage_removal()
  254. return self._untyped_storage.device.type == custom_backend_name
  255. _check_register_once(torch.TypedStorage, f"is_{custom_backend_name}")
  256. setattr(
  257. torch.storage.TypedStorage,
  258. f"is_{custom_backend_name}",
  259. wrap_typed_storage_backend,
  260. )
  261. def wrap_typed_storage_to(
  262. self: torch.storage.TypedStorage, device=None, non_blocking=False, **kwargs
  263. ) -> torch.storage.TypedStorage:
  264. torch.storage._warn_typed_storage_removal()
  265. if unsupported_dtype and self.dtype in unsupported_dtype:
  266. raise RuntimeError(
  267. f"Cannot create {custom_backend_name} storage "
  268. f"as {self.dtype} dtype is not supported by this backend"
  269. )
  270. custom_backend_storage: torch.UntypedStorage = getattr(
  271. self._untyped_storage, custom_backend_name
  272. )(device, non_blocking, **kwargs)
  273. return self._new_wrapped_storage(custom_backend_storage)
  274. _check_register_once(torch.TypedStorage, custom_backend_name)
  275. setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)
  276. def generate_methods_for_privateuse1_backend(
  277. for_tensor: bool = True,
  278. for_module: bool = True,
  279. for_packed_sequence: bool = True,
  280. for_storage: bool = False,
  281. unsupported_dtype: Optional[list[torch.dtype]] = None,
  282. ) -> None:
  283. r"""
  284. Automatically generate attributes and methods for the custom backend after rename privateuse1 backend.
  285. In the default scenario, storage-related methods will not be generated automatically.
  286. When you implement kernels for various torch operations, and register them to the PrivateUse1 dispatch key.
  287. And call the function torch.rename_privateuse1_backend("foo") to rename your backend name.
  288. At this point, you can easily register specific methods and attributes by calling this function.
  289. Just like torch.Tensor.foo(), torch.Tensor.is_foo, torch.Storage.foo(), torch.Storage.is_foo.
  290. Note: We recommend you use generic functions (check devices are equal or to(device=)).
  291. We provide these methods for convenience only and they will be "monkey patched" onto the objects
  292. and so will not be properly typed. For Storage methods generate, if you need to support sparse data storage,
  293. you need to extend the implementation yourself.
  294. Args:
  295. for_tensor (bool): whether register related methods for torch.Tensor class.
  296. for_module (bool): whether register related methods for torch.nn.Module class.
  297. for_storage (bool): whether register related methods for torch.Storage class.
  298. unsupported_dtype (List[torch.dtype]): takes effect only when the storage method needs to be generated,
  299. indicating that the storage does not support the torch.dtype type.
  300. Example::
  301. >>> # xdoctest: +SKIP("failing")
  302. >>> torch.utils.rename_privateuse1_backend("foo")
  303. >>> torch.utils.generate_methods_for_privateuse1_backend()
  304. # Then automatically generate backend-related attributes and methods.
  305. >>> a = torch.tensor(2).foo()
  306. >>> a.is_foo
  307. >>> hasattr(torch.nn.Module, 'foo')
  308. """
  309. custom_backend_name = _get_privateuse1_backend_name()
  310. if for_tensor:
  311. _generate_tensor_methods_for_privateuse1_backend(custom_backend_name)
  312. if for_module:
  313. _generate_module_methods_for_privateuse1_backend(custom_backend_name)
  314. if for_storage:
  315. _generate_storage_methods_for_privateuse1_backend(
  316. custom_backend_name, unsupported_dtype
  317. )
  318. if for_packed_sequence:
  319. _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name)
  320. def _get_custom_mod_func(func_name: str):
  321. r"""
  322. Return the func named `func_name` defined in custom device module. If not defined,
  323. return `None`. And the func is registered with `torch.utils.rename_privateuse1_backend('foo')`
  324. and `torch._register_device_module('foo', BackendModule)`.
  325. If the custom device module or the func is not defined, it will give warning or error message.
  326. Args:
  327. func_name (str): return the callable func named func_name defined in custom device module.
  328. Example::
  329. class DummyfooModule:
  330. @staticmethod
  331. def is_available():
  332. return True
  333. @staticmethod
  334. def func_name(*args, **kwargs):
  335. ....
  336. torch.utils.rename_privateuse1_backend("foo")
  337. torch._register_device_module("foo", DummyfooModule)
  338. foo_is_available_func = torch.utils.backend_registration._get_custom_mod_func("is_available")
  339. if foo_is_available_func:
  340. foo_is_available = foo_is_available_func()
  341. func_ = torch.utils.backend_registration._get_custom_mod_func("func_name")
  342. if func_:
  343. result = func_(*args, **kwargs)
  344. Attention: This function is not meant to be used directly by users, which is why
  345. it is marked as private. It is a convenience function for backend implementers to
  346. more easily call the hooks into their backend extensions.
  347. """
  348. assert isinstance(func_name, str), (
  349. f"func_name must be `str`, but got `{type(func_name)}`."
  350. )
  351. backend_name = _get_privateuse1_backend_name()
  352. custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type]
  353. function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type]
  354. if custom_device_mod is None or function is None:
  355. message = f"Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend "
  356. message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And "
  357. message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
  358. raise RuntimeError(message)
  359. return function