reductions.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. # mypy: allow-untyped-defs
  2. import multiprocessing
  3. import os
  4. import threading
  5. from multiprocessing import reduction
  6. from multiprocessing.util import register_after_fork
  7. from typing import Union
  8. import torch
  9. from torch._namedtensor_internals import check_serializing_named_tensor
  10. try:
  11. # Early load resource_sharer to prevent a partially initialized instance
  12. # from being inherited in a forked child process. The reduce_storage method
  13. # requires this module indirectly through DupFd(). The built-in mp.Queue
  14. # class pickles arguments in a background thread which may overlap with the
  15. # fork.
  16. import multiprocessing.resource_sharer
  17. except ImportError:
  18. pass
  19. class StorageWeakRef:
  20. r"""A weak reference to a Storage.
  21. The cdata member is a Python number containing the integer representation of
  22. the Storage pointer.
  23. """
  24. __slots__ = ["cdata", "_free_weak_ref"]
  25. def __init__(self, storage):
  26. self.cdata = storage._weak_ref()
  27. # Save a direct reference to _free_weak_ref because the `torch` module
  28. # might be cleared during Python shutdown before this module is cleared.
  29. self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
  30. @classmethod
  31. def from_weakref(cls, cdata):
  32. instance = cls.__new__(cls)
  33. instance.cdata = cdata
  34. instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
  35. return instance
  36. def expired(self):
  37. return torch.Storage._expired(self.cdata) # type: ignore[attr-defined]
  38. def __del__(self):
  39. self._free_weak_ref(self.cdata)
  40. def __hash__(self):
  41. return self.cdata
  42. def __eq__(self, other):
  43. if id(self) == id(other):
  44. return True
  45. return self.cdata == other.cdata
  46. class SharedCache(dict):
  47. """Dictionary from multiprocessing handles to StorageWeakRef."""
  48. def __init__(self) -> None:
  49. # free_dead_references() is called if the len exceeds the current
  50. # limit. The limit scales with the number of remaining live objects.
  51. self.limit = 128
  52. # `fork` inherits lock state, so in case we fork when the lock is held,
  53. # we register a function to reset the lock to a new object to avoid
  54. # possible deadlocks, following python multiprocessing library design.
  55. self._after_fork()
  56. register_after_fork(self, SharedCache._after_fork)
  57. def _after_fork(self):
  58. self.lock = threading.Lock()
  59. def get(self, key): # type: ignore[override]
  60. with self.lock:
  61. return dict.get(self, key)
  62. def __setitem__(self, key, storage_ref):
  63. with self.lock:
  64. dict.__setitem__(self, key, storage_ref)
  65. if len(self) > self.limit:
  66. self.free_dead_references()
  67. def free_dead_references(self):
  68. live = 0
  69. for key, storage_ref in list(self.items()):
  70. if storage_ref.expired():
  71. del self[key]
  72. else:
  73. live += 1
  74. self.limit = max(128, live * 2)
# Module-level cache instance: mapping from handles to StorageWeakRef objects.
shared_cache = SharedCache()
  77. def rebuild_event(device, handle):
  78. return torch.cuda.Event.from_ipc_handle(device, handle)
  79. def reduce_event(event):
  80. handle = event.ipc_handle()
  81. return (rebuild_event, (event.device, handle))
  82. def rebuild_tensor(cls, storage, metadata):
  83. storage_offset, size, stride, requires_grad = metadata
  84. t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
  85. if cls == torch.nn.parameter.Parameter:
  86. # we have to pass requires_grad into constructor, rather than set it as an
  87. # attribute later, because it's an important check for Integer Tensors to
  88. # have requires_grad=False (or else they raise an error)
  89. t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
  90. else:
  91. t.requires_grad = requires_grad
  92. return t
  93. def rebuild_meta_tensor(
  94. tensor_cls,
  95. tensor_size,
  96. tensor_stride,
  97. tensor_offset,
  98. dtype,
  99. storage_size_bytes,
  100. requires_grad,
  101. ):
  102. untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")
  103. typed_storage = torch.TypedStorage(
  104. wrap_storage=untyped_storage, dtype=dtype, _internal=True
  105. )
  106. t = torch._utils._rebuild_tensor(
  107. typed_storage,
  108. tensor_offset,
  109. tensor_size,
  110. tensor_stride,
  111. )
  112. if tensor_cls == torch.nn.parameter.Parameter:
  113. # It is crucial for integer tensors to receive
  114. # the requires_grad=False as an argument in the constructor
  115. t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
  116. else:
  117. t.requires_grad = requires_grad
  118. return t
  119. def rebuild_cuda_tensor(
  120. tensor_cls,
  121. tensor_size,
  122. tensor_stride,
  123. tensor_offset,
  124. storage_cls,
  125. dtype,
  126. storage_device,
  127. storage_handle,
  128. storage_size_bytes,
  129. storage_offset_bytes,
  130. requires_grad,
  131. ref_counter_handle,
  132. ref_counter_offset,
  133. event_handle,
  134. event_sync_required,
  135. ):
  136. # If storage_handle is None, storage points to nullptr.
  137. if storage_handle is None or storage_size_bytes == 0:
  138. storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
  139. else:
  140. storage = storage_from_cache(
  141. storage_cls, (storage_handle, storage_offset_bytes)
  142. )
  143. if storage is None:
  144. torch.cuda._lazy_init()
  145. storage = storage_cls._new_shared_cuda(
  146. storage_device,
  147. storage_handle,
  148. storage_size_bytes,
  149. storage_offset_bytes,
  150. ref_counter_handle,
  151. ref_counter_offset,
  152. event_handle,
  153. event_sync_required,
  154. )
  155. shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
  156. storage
  157. )
  158. else:
  159. # We already ref counting this Storage, but producer needs new ref-counters to be released.
  160. storage_cls._release_ipc_counter(
  161. ref_counter_handle, ref_counter_offset, device=storage_device
  162. )
  163. _storage = (
  164. storage
  165. if isinstance(storage, torch.UntypedStorage)
  166. else storage._untyped_storage
  167. )
  168. t = torch._utils._rebuild_tensor(
  169. torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
  170. tensor_offset,
  171. tensor_size,
  172. tensor_stride,
  173. )
  174. if tensor_cls == torch.nn.parameter.Parameter:
  175. # It is crucial for integer tensors to receive
  176. # the requires_grad=False as an argument in the constructor
  177. t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
  178. else:
  179. t.requires_grad = requires_grad
  180. return t
  181. def reduce_tensor(tensor):
  182. if tensor.requires_grad and not tensor.is_leaf:
  183. raise RuntimeError(
  184. "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
  185. "since autograd does not support crossing process boundaries. "
  186. "If you just want to transfer the data, call detach() on the tensor "
  187. "before serializing (e.g., putting it on the queue)."
  188. )
  189. check_serializing_named_tensor(tensor)
  190. torch.utils.hooks.warn_if_has_hooks(tensor)
  191. # Note [CUDA IPC and the caching allocator]
  192. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  193. # When you send a CUDA tensor over IPC, you might expect that you will
  194. # get out the same storage from the other end. However, the CUDA caching
  195. # allocator makes it difficult to preserve this invariant. Consider
  196. # the following situation: a tensor of size 0x100 points to offset 0x20 of
  197. # a storage at 0xA100 of size 0x100. (For simplicity, all of these
  198. # sizes are given in bytes). HOWEVER, with the caching allocator, this storage
  199. # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
  200. #
  201. # When we want to send this CUDA tensor over IPC, we must send the
  202. # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
  203. # the storage 0xA100 (because that is what CUDA supports). So, on the
  204. # other end, there simply isn't any way to say, "Wait, you gave me
  205. # a bigger region (0xA000) than the one I wanted (0xA100)".
  206. #
  207. # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
  208. # one storage itself? No, because this cudaMalloc allocation might contain
  209. # storages of mixed types: float, bytes, double... If you make the entire
  210. # allocation a single storage of a type A, we'll hit an error when constructing
  211. # a tensor of type B on the storage.
  212. #
  213. # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the
  214. # receiver side. However, cudaIpcMemHandles from each device in a given process may
  215. # only be opened by one context per device per other process.
  216. # If we open and close a memory handle multiples times in a process, CUDA is allowed
  217. # to give it a different address; similarly, once we close the memory, we're not
  218. # allowed to access it(and the storage/tensor built on top of it), even if it is
  219. # still live in the original process. As we cannot make a cudaMalloc allocation
  220. # to a single storage in one go, this requires us to cache the device pointer for
  221. # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep
  222. # the old ones alives.
  223. # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
  224. #
  225. # This is fine, because all we need to do is to save our position in the allocation,
  226. # and reconstruct storage and tensor from it.
  227. # 0xA000 -> -------CUDA Allocation------
  228. # | |
  229. # | |
  230. # | |
  231. # | |
  232. # 0xA100 -> --------storage1 begin------
  233. # | |
  234. # 0xA120 -> --------tensor1 begin ------
  235. # | |
  236. # | |
  237. # | |
  238. # | |
  239. # | |
  240. # 0xA160 -> --------tensor1 end---------
  241. # | |
  242. # | |
  243. # | |
  244. # 0xA200 -> --------storage1 end--------
  245. # | |
  246. # 0xE000 -> --------CUDA allocation-----
  247. #
  248. # To send tensor1, the following info are required from sender to receiver for
  249. # storage reconstruction.
  250. # 1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process).
  251. # basePtr may not be exactly 0xA000 since it's a different process.
  252. # 2. offset(0xA100) of storage1 in the CUDA allocation.
  253. # 3. size of storage1(0x100).
  254. #
  255. # On receiver side:
  256. # 1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
  257. # of the same type using (basePtr, offset, size).
  258. # 2. we can reconstruct the tensor on top of the reconstructed storage
  259. # Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
  260. #
  261. # This strategy has a few implications:
  262. #
  263. # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
  264. # go (non-compositionally), and this requires to have a global map
  265. # memHandle -> devPtr for each process.
  266. #
  267. # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize
  268. # of the storage beyond 0x100 would merely have caused us to do a
  269. # reallocation. You don't really want to do this, but if you did,
  270. # all that would happen is that you would lose IPC sharing. But if
  271. # you do this in the new world, we will happily let you write out of
  272. # bounds of your "allocation", clobbering unrelated data in the cached
  273. # allocator block. BAD!
  274. #
  275. # By the way, in old versions of PyTorch, we supported this situation
  276. # natively using a "storage view", which permitted multiple storages to be
  277. # views on each other. But this was the *only* use of storage views, so we
  278. # eliminated it so that we could just use tensor views to implement the same
  279. # thing.
  280. #
  281. # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
  282. # https://github.com/pytorch/pytorch/issues/110543
  283. from torch.nested._internal.nested_tensor import NestedTensor
  284. if tensor.is_nested and not isinstance(tensor, NestedTensor):
  285. return reduce_nested_tensor(tensor)
  286. if tensor.layout in {
  287. torch.sparse_coo,
  288. torch.sparse_csr,
  289. torch.sparse_bsr,
  290. torch.sparse_csc,
  291. torch.sparse_bsc,
  292. }:
  293. return reduce_sparse_tensor(tensor)
  294. storage = tensor._typed_storage()
  295. if storage._untyped_storage.device.type == "cuda":
  296. (
  297. device,
  298. handle,
  299. storage_size_bytes,
  300. storage_offset_bytes,
  301. ref_counter_handle,
  302. ref_counter_offset,
  303. event_handle,
  304. event_sync_required,
  305. ) = storage._share_cuda_()
  306. tensor_offset = tensor.storage_offset()
  307. shared_cache[handle] = StorageWeakRef(storage)
  308. # _backward_hooks purposely omitted here, see
  309. # Note [Don't serialize hooks]
  310. return (
  311. rebuild_cuda_tensor,
  312. (
  313. type(tensor),
  314. tensor.size(),
  315. tensor.stride(),
  316. tensor_offset, # tensor offset in its storage
  317. type(storage),
  318. tensor.dtype,
  319. device,
  320. handle, # identifier which CUDA allocation is the storage in.
  321. storage_size_bytes, # size(in bytes) of the storage
  322. storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation
  323. tensor.requires_grad,
  324. ref_counter_handle,
  325. ref_counter_offset,
  326. event_handle,
  327. event_sync_required,
  328. ),
  329. )
  330. elif storage._untyped_storage.device.type == "meta":
  331. return (
  332. rebuild_meta_tensor,
  333. (
  334. type(tensor),
  335. tensor.size(),
  336. tensor.stride(),
  337. tensor.storage_offset(),
  338. tensor.dtype,
  339. tensor.untyped_storage().size(),
  340. tensor.requires_grad,
  341. ),
  342. )
  343. # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
  344. metadata = (
  345. tensor.storage_offset(),
  346. tensor.size(),
  347. tensor.stride(),
  348. tensor.requires_grad,
  349. )
  350. return (rebuild_tensor, (type(tensor), storage, metadata))
  351. def rebuild_nested_tensor(
  352. rebuild_buffer_func,
  353. rebuild_buffer_args,
  354. rebuild_sizes_func,
  355. rebuild_sizes_args,
  356. rebuild_strides_func,
  357. rebuild_strides_args,
  358. rebuild_offsets_func,
  359. rebuild_offsets_args,
  360. ):
  361. buffer = rebuild_buffer_func(*rebuild_buffer_args)
  362. sizes = rebuild_sizes_func(*rebuild_sizes_args)
  363. strides = rebuild_strides_func(*rebuild_strides_args)
  364. offsets = rebuild_offsets_func(*rebuild_offsets_args)
  365. return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)
  366. def reduce_nested_tensor(nt):
  367. rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values())
  368. rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size())
  369. rebuild_strides_func, rebuild_strides_args = reduce_tensor(
  370. nt._nested_tensor_strides()
  371. )
  372. rebuild_offsets_func, rebuild_offsets_args = reduce_tensor(
  373. nt._nested_tensor_storage_offsets()
  374. )
  375. return (
  376. rebuild_nested_tensor,
  377. (
  378. rebuild_buffer_func,
  379. rebuild_buffer_args,
  380. rebuild_sizes_func,
  381. rebuild_sizes_args,
  382. rebuild_strides_func,
  383. rebuild_strides_args,
  384. rebuild_offsets_func,
  385. rebuild_offsets_args,
  386. ),
  387. )
  388. def rebuild_sparse_coo_tensor(
  389. rebuild_indices_func,
  390. rebuild_indices_args,
  391. rebuild_values_func,
  392. rebuild_values_args,
  393. shape,
  394. is_coalesced,
  395. ):
  396. indices = rebuild_indices_func(*rebuild_indices_args)
  397. values = rebuild_values_func(*rebuild_values_args)
  398. return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)
  399. def rebuild_sparse_compressed_tensor(
  400. rebuild_compressed_indices_func,
  401. rebuild_compressed_indices_args,
  402. rebuild_plain_indices_func,
  403. rebuild_plain_indices_args,
  404. rebuild_values_func,
  405. rebuild_values_args,
  406. shape,
  407. layout,
  408. ):
  409. compressed_indices = rebuild_compressed_indices_func(
  410. *rebuild_compressed_indices_args
  411. )
  412. plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
  413. values = rebuild_values_func(*rebuild_values_args)
  414. return torch.sparse_compressed_tensor(
  415. compressed_indices, plain_indices, values, shape, layout=layout
  416. )
  417. def reduce_sparse_tensor(sparse):
  418. if sparse.layout is torch.sparse_coo:
  419. rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices())
  420. rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values())
  421. return (
  422. rebuild_sparse_coo_tensor,
  423. (
  424. rebuild_indices_func,
  425. rebuild_indices_args,
  426. rebuild_values_func,
  427. rebuild_values_args,
  428. sparse.shape,
  429. sparse.is_coalesced(),
  430. ),
  431. )
  432. else:
  433. if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
  434. compressed_indices = sparse.crow_indices()
  435. plain_indices = sparse.col_indices()
  436. elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
  437. compressed_indices = sparse.ccol_indices()
  438. plain_indices = sparse.row_indices()
  439. else:
  440. raise NotImplementedError(sparse.layout)
  441. (
  442. rebuild_compressed_indices_func,
  443. rebuild_compressed_indices_args,
  444. ) = reduce_tensor(compressed_indices)
  445. rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor(
  446. plain_indices
  447. )
  448. rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values())
  449. return (
  450. rebuild_sparse_compressed_tensor,
  451. (
  452. rebuild_compressed_indices_func,
  453. rebuild_compressed_indices_args,
  454. rebuild_plain_indices_func,
  455. rebuild_plain_indices_args,
  456. rebuild_values_func,
  457. rebuild_values_args,
  458. sparse.shape,
  459. sparse.layout,
  460. ),
  461. )
  462. def fd_id(fd):
  463. # Returns a tuple which uniquely identifies a file descriptor. In Mac OS,
  464. # this doesn't work with shared memory handles, which is why we don't
  465. # support the "file_descriptor" sharing method on that platform.
  466. stat = os.fstat(fd)
  467. return (stat.st_ino, stat.st_dev)
  468. def storage_from_cache(cls, key):
  469. storage_ref = shared_cache.get(key)
  470. if storage_ref is None:
  471. return None
  472. return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
  473. def rebuild_storage_fd(cls, df, size):
  474. fd = df.detach()
  475. try:
  476. storage = storage_from_cache(cls, fd_id(fd))
  477. if storage is not None:
  478. return storage
  479. storage = cls._new_shared_fd_cpu(fd, size)
  480. shared_cache[fd_id(fd)] = StorageWeakRef(storage)
  481. return storage
  482. finally:
  483. os.close(fd)
  484. def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
  485. storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
  486. cls, handle
  487. )
  488. if storage is not None:
  489. return storage._shared_decref()
  490. if dtype is None:
  491. storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
  492. else:
  493. byte_size = size * torch._utils._element_size(dtype)
  494. untyped_storage: torch.UntypedStorage = (
  495. torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
  496. )
  497. storage = torch.TypedStorage(
  498. wrap_storage=untyped_storage, dtype=dtype, _internal=True
  499. )
  500. shared_cache[handle] = StorageWeakRef(storage)
  501. return storage._shared_decref()
  502. def rebuild_storage_empty(cls):
  503. return cls()
  504. def rebuild_typed_storage(storage, dtype):
  505. return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)
  506. # Use for torch.storage.TypedStorage
  507. def reduce_typed_storage(storage):
  508. return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))
  509. def rebuild_typed_storage_child(storage, storage_type):
  510. return storage_type(wrap_storage=storage, _internal=True)
  511. # Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
  512. def reduce_typed_storage_child(storage):
  513. return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))
  514. def reduce_storage(storage):
  515. from . import get_sharing_strategy
  516. if storage.is_cuda:
  517. raise RuntimeError(
  518. "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
  519. )
  520. elif storage.device.type == "meta":
  521. raise RuntimeError(
  522. "Cannot pickle meta storage; try pickling a meta tensor instead"
  523. )
  524. elif get_sharing_strategy() == "file_system":
  525. metadata = storage._share_filename_cpu_()
  526. cache_key = metadata[1]
  527. rebuild = rebuild_storage_filename
  528. if isinstance(storage, torch.TypedStorage):
  529. metadata += (storage.dtype,)
  530. storage._shared_incref()
  531. elif storage.size() == 0:
  532. # This is special cased because Empty tensors
  533. # (with size 0) cannot be mmapped.
  534. return (rebuild_storage_empty, (type(storage),))
  535. else:
  536. fd, size = storage._share_fd_cpu_()
  537. df = multiprocessing.reduction.DupFd(fd)
  538. cache_key = fd_id(fd)
  539. metadata = (df, size)
  540. rebuild = rebuild_storage_fd # type: ignore[assignment]
  541. shared_cache[cache_key] = StorageWeakRef(storage)
  542. return (rebuild, (type(storage),) + metadata)
  543. def init_reductions():
  544. reduction.register(torch.cuda.Event, reduce_event)
  545. for t in torch._storage_classes:
  546. if t.__name__ == "UntypedStorage":
  547. reduction.register(t, reduce_storage)
  548. else:
  549. reduction.register(t, reduce_typed_storage_child)
  550. reduction.register(torch.storage.TypedStorage, reduce_typed_storage)
  551. for t in torch._tensor_classes:
  552. reduction.register(t, reduce_tensor)
  553. # TODO: Maybe this should be in tensor_classes? :)
  554. reduction.register(torch.Tensor, reduce_tensor)
  555. from torch.nn.parameter import Parameter
  556. reduction.register(Parameter, reduce_tensor)