predispatch.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # mypy: ignore-errors
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. """
  8. This module contains pre-dispatch wrappers for functorch operations
  9. that enable proper tracing in PT2 non-strict export/compile fx graph.
  10. """
import threading

import torch
from torch._C._functorch import (
    _add_batch_dim as _add_batch_dim_impl,
    _remove_batch_dim as _remove_batch_dim_impl,
    _vmap_decrement_nesting as _vmap_decrement_nesting_impl,
    _vmap_increment_nesting as _vmap_increment_nesting_impl,
)
  18. def _add_batch_dim(self, batch_dim, level):
  19. """
  20. Thin wrapper around torch._C._add_batch_dim that is used to proxy in
  21. PT2 export/compile fx graph
  22. """
  23. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  24. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  25. if mode:
  26. return torch.overrides.handle_torch_function(
  27. _add_batch_dim, (self,), self, batch_dim, level
  28. )
  29. res = _add_batch_dim_impl(self, batch_dim, level)
  30. return res
  31. def _remove_batch_dim(self, level, batch_size, out_dim):
  32. """
  33. Thin wrapper around torch._C._remove_batch_dim that is used to proxy in
  34. PT2 export/compile fx graph
  35. """
  36. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  37. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  38. if mode:
  39. return torch.overrides.handle_torch_function(
  40. _remove_batch_dim, (self,), self, level, batch_size, out_dim
  41. )
  42. res = _remove_batch_dim_impl(self, level, batch_size, out_dim)
  43. return res
  44. def _vmap_increment_nesting(batch_size, randomness):
  45. """
  46. Thin wrapper around torch._C._vmap_increment_nesting that is used
  47. to proxy in export/compile graph
  48. """
  49. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  50. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  51. if mode:
  52. return torch.overrides.handle_torch_function(
  53. _vmap_increment_nesting, (batch_size,), batch_size, randomness
  54. )
  55. res = _vmap_increment_nesting_impl(batch_size, randomness)
  56. return res
  57. def _vmap_decrement_nesting():
  58. """
  59. Thin wrapper around torch._C._vmap_increment_nesting that is used
  60. to proxy in export/compile graph
  61. """
  62. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  63. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  64. if mode:
  65. return torch.overrides.handle_torch_function(
  66. _vmap_decrement_nesting,
  67. (),
  68. )
  69. return _vmap_decrement_nesting_impl()
  70. # Global variables for lazy_load_decompositions
  71. DECOMPOSITIONS_LOADED = False
  72. DECOMPOSITIONS_LOCK = None # Will be initialized when needed
  73. VMAP_DECOMPOSITIONS_LIB = None
  74. def lazy_load_decompositions():
  75. """
  76. Lazy loading of vmap decompositions with pre-dispatch support.
  77. """
  78. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  79. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  80. if mode:
  81. return torch.overrides.handle_torch_function(lazy_load_decompositions, ())
  82. global DECOMPOSITIONS_LOADED, DECOMPOSITIONS_LOCK, VMAP_DECOMPOSITIONS_LIB
  83. if DECOMPOSITIONS_LOADED:
  84. return
  85. # Initialize lock if needed
  86. if DECOMPOSITIONS_LOCK is None:
  87. import threading
  88. DECOMPOSITIONS_LOCK = threading.Lock()
  89. with DECOMPOSITIONS_LOCK:
  90. if DECOMPOSITIONS_LOADED:
  91. return
  92. import os
  93. if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
  94. DECOMPOSITIONS_LOADED = True
  95. return
  96. # use an alternate way to register an operator into the decomposition table
  97. # _register_jit_decomposition doesn't work for some operators, e.g. addr,
  98. # because the Tensor types generated cannot be unioned by torchscript
  99. # decomp should be type OpOverload
  100. VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
  101. "aten", "IMPL", "FuncTorchBatched"
  102. )
  103. from torch._decomp import decomposition_table
  104. def _register_python_decomposition_vmap(decomp):
  105. if decomp in decomposition_table:
  106. VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
  107. else:
  108. raise RuntimeError(f"could not find decomposition for {decomp}")
  109. _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
  110. _register_python_decomposition_vmap(
  111. torch.ops.aten.smooth_l1_loss_backward.default
  112. )
  113. _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
  114. _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
  115. _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
  116. _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
  117. _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
  118. _register_python_decomposition_vmap(torch.ops.aten.addr.default)
  119. DECOMPOSITIONS_LOADED = True