FunctionalTensorWrapper.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. #pragma once
  2. #include <ATen/ArrayRef.h>
  3. #include <ATen/FunctionalStorageImpl.h>
  4. #include <ATen/core/IListRef.h>
  5. #include <ATen/core/List.h>
  6. #include <ATen/core/boxing/BoxedKernel.h>
  7. #include <ATen/core/boxing/impl/boxing.h>
  8. #include <ATen/core/dispatch/Dispatcher.h>
  9. #include <c10/core/DispatchKey.h>
  10. namespace at {
  11. // Note [Functionalization Pass In Core]
  12. // The Functionalization pass is used to remove aliasing from a pytorch program.
  13. //
  14. // This is useful for backends that don't support aliasing, like XLA and Vulkan.
  15. // It's also necessary in order to remove mutation from a program, which is
  16. // needed in Functorch.
  17. //
  18. // Consider this program:
  19. // a = torch.ones(...)
  20. // b = a.view(...)
  21. // b.add_(1)
  22. //
  23. // In this program, b is meant to alias with a due to the use of view(). At the
  24. // end of the program, both a and b are full of 2's. However, backends that
  25. // don't support aliasing aren't able to correctly implement the view()
  26. // operator. Instead, they can opt into the Functionalization pass, which will
  27. // sit between the user and the backend, and provide the necessary aliasing
  28. // logic.
  29. //
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
// a = torch.ones(...)
// b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                       # XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1)  # Our functionalization pass machinery knows that a and b are
//            # aliased - it applies b's mutation to a too.
  37. //
  38. // So, how does the functionalization pass keep track of which tensors are
  39. // aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its aliased tensors.
  41. //
  42. // See Note [Functionalization: Alias Removal] for details on the aliasing
  43. // machinery. See Note [Functionalization: Mutation Removal] for details on
  44. // mutation removal.
  45. struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  46. explicit FunctionalTensorWrapper(const Tensor& value);
  47. // Additional constructor to create a FunctionalTensorWrapper directly from an
  48. // underlying tensor that was created from a view. For example, the code b =
  49. // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
  50. // view1_meta)
  51. explicit FunctionalTensorWrapper(
  52. const Tensor& view_value,
  53. const FunctionalTensorWrapper* base,
  54. const std::shared_ptr<functionalization::ViewMeta>& meta);
  55. // Get the underlying, actual tensor, that doesn't know anything about
  56. // functionalization.
  57. const Tensor& value() const {
  58. return value_;
  59. }
  60. // The concept of "level" is only ever important to functorch; it's exposed
  61. // here as more of a hook for functorch to use.
  62. int64_t level() const {
  63. return level_;
  64. }
  65. void set_level(int64_t level) {
  66. level_ = level;
  67. }
  68. bool has_metadata_mutation() const {
  69. return has_metadata_mutation_;
  70. }
  71. uint64_t mutation_counter() const {
  72. return functional_storage_impl()->mutation_counter();
  73. }
  74. void mark_mutation() {
  75. functional_storage_impl()->mark_mutation();
  76. }
  77. // Denotes a mutation that's hidden from autograd,
  78. // e.g. for the purposes of passing a tensor to a triton kernel
  79. void mark_mutation_hidden_from_autograd() {
  80. functional_storage_impl()->mark_mutation_hidden_from_autograd();
  81. }
  82. void mark_mutation_during_no_grad_or_inference_mode() {
  83. functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
  84. }
  85. // Are all the mutations happening to the tensor hidden from autograd
  86. bool are_all_mutations_hidden_from_autograd() const {
  87. return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
  88. }
  89. // Did all mutations happen under no_grad or inference_mode
  90. // (We also need to ignore mutations fully hidden from autograd here)
  91. bool are_all_mutations_under_no_grad_or_inference_mode() const {
  92. return functional_storage_impl()
  93. ->are_all_mutations_under_no_grad_or_inference_mode();
  94. }
  95. void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
  96. is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
  97. }
  98. bool is_symbolic() const {
  99. return is_symbolic_;
  100. }
  101. // Retrieves the ViewMeta sequence of this tensor.
  102. const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
  103. const;
  104. // Sync's the underlying tensor with its alias, if it's out of date. This
  105. // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
  106. // Replay the views (if any) to regenerate the current tensor off of the
  107. // updated alias.
  108. void sync_();
  109. // Performs step (1) of the sync. This is its own public API because it's
  110. // needed by view_inplace ops like transpose_. See Note [Functionalization
  111. // Pass - Inplace View Ops]
  112. void regenerate_from_base();
  113. // Performs step (2) of the sync. This is its own public API because it's
  114. // needed by functorch. functorch wants to make sure that all input tensors to
  115. // a functionalized program have been properly synced so it can properly
  116. // propagate mutations to inputs. It can't just call sync_(), because the
  117. // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
  118. // a noop. We use the reference count on storage_ to determine if the wrapper
  119. // is aliased, and by the time functorch is ready to propagate updates to
  120. // inputs, any intermediate views of the input created by the program will
  121. // have been deallocated. This function also returns whether or not the base
  122. // actually had any updates to apply.
  123. bool apply_updates();
  124. // Takes the current state of value_ and snapshots it, sending it as a pending
  125. // update to the alias.
  126. void commit_update();
  127. // When any tensor is mutated, the tensor increments its alias's "generation".
  128. // Separately, each tensor maintains its own "generation" counter, which is
  129. // used to determine if it's up-to-date with its alias. The act of syncing a
  130. // tensor will set a tensor's generation equal to its alias's generation.
  131. bool is_up_to_date() const;
  132. // Freezes the storage of this tensor, preventing subsequent mutations
  133. void freeze_storage() const;
  134. // Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
  135. // describing the series of view ops that ran to generate the current tensor
  136. // from the base tensor. This method is used by inplace-view ops like
  137. // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  138. // tensor by replaying the views off of the alias.
  139. void mutate_view_meta(
  140. const std::shared_ptr<at::functionalization::ViewMeta>& meta);
  141. // Custom implementation of self.set_(src)
  142. void set__impl(const FunctionalTensorWrapper* other);
  143. // Custom implementation of resize_storage_bytes_(self, new_size)
  144. void storage_resize_(const c10::SymInt& new_size);
  145. // Returns whether the current tensor's data was ever mutated
  146. bool has_data_mutation();
  147. //
  148. // Returns whether the current FunctionalTensorWrapper
  149. // experienced a set_() call.
  150. bool was_storage_changed() {
  151. return was_storage_changed_;
  152. }
  153. void mark_storage_changed() {
  154. was_storage_changed_ = true;
  155. storage_changed_counter_++;
  156. }
  157. uint64_t storage_changed_counter() {
  158. return storage_changed_counter_;
  159. }
  160. // A FunctionalTensor is considered a base if its not a view of another
  161. // tensor.
  162. bool isBaseTensor() const {
  163. return view_metas_.empty();
  164. }
  165. c10::SymInt get_storage_size(bool before) {
  166. return functional_storage_impl()->get_storage_size(before);
  167. }
  168. // Returns whether the FunctionalTensor experienced an
  169. // untyped_storage().resize_() call
  170. bool was_inductor_storage_resized() {
  171. return functional_storage_impl()->was_inductor_storage_resized();
  172. }
  173. bool inductor_storage_resized_counter() {
  174. return functional_storage_impl()->inductor_storage_resized_counter();
  175. }
  176. // The functionalization pass can be used to remove mutations.
  177. // It does so by replacing any mutation op with it's corresponding
  178. // out-of-place op, followed by a call to replace_(). e.g:
  179. //
  180. // a.add_(1)
  181. //
  182. // will turn into:
  183. //
  184. // tmp = a.add(1)
  185. // a.replace_(tmp)
  186. //
  187. // replace_() swaps out the wrapped tensor, value_, with tmp.
  188. void replace_(const Tensor& other, bool from_lazy_regenerate = false);
  189. bool is_multi_output_view() {
  190. return is_multi_output_view_;
  191. }
  192. // See Note[resize_() in functionalization pass]
  193. void maybe_replace_storage(const Tensor& other);
  194. // Replaces the storage with a new functional storage,
  195. // and clears the view_metas_ stack.
  196. // WARNING: Calling this function will sever the aliasing relationship between
  197. // the current FunctionalTensorWrapper and any of its outstanding aliases.
  198. // Please only call if you know what you're doing.
  199. void _unsafe_reset_storage();
  200. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  201. const c10::VariableVersion& version_counter,
  202. bool allow_tensor_metadata_change) const override;
  203. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  204. c10::VariableVersion&& version_counter,
  205. bool allow_tensor_metadata_change) const override;
  206. ~FunctionalTensorWrapper() override = default;
  207. // FunctionalTensorWrapper overrides all custom size/stride function,
  208. // so that if the inner tensor has a custom implementation
  209. // we make sure to call that implementation.
  210. at::IntArrayRef sizes_custom() const override;
  211. at::IntArrayRef strides_custom() const override;
  212. int64_t dim_custom() const override;
  213. int64_t numel_custom() const override;
  214. c10::SymBool sym_is_contiguous_custom(
  215. at::MemoryFormat memory_format) const override;
  216. c10::SymIntArrayRef sym_sizes_custom() const override;
  217. c10::SymInt sym_size_custom(int64_t d) const override;
  218. c10::SymIntArrayRef sym_strides_custom() const override;
  219. c10::SymInt sym_storage_offset_custom() const override;
  220. c10::Device device_custom() const override;
  221. c10::Layout layout_impl() const override;
  222. private:
  223. const char* tensorimpl_type_name() const override;
  224. void set_constructor_metadata();
  225. functionalization::FunctionalStorageImpl* functional_storage_impl() const;
  226. // This is used to re-implement shallow_copy_and_detach for
  227. // FunctionalTensorWrapper. The implementation is identical, but we just need
  228. // to return a subclass instead of a plain TensorImpl.
  229. // TODO: maybe it's possible to arrange for that to happen automatically
  230. // without an override here?
  231. template <typename VariableVersion>
  232. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
  233. VariableVersion&& version_counter,
  234. bool allow_tensor_metadata_change) const;
  235. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
  236. void copy_tensor_metadata_and_refresh(
  237. const FunctionalTensorWrapper* src_impl,
  238. FunctionalTensorWrapper* dest_impl,
  239. const c10::VariableVersion& version_counter,
  240. bool allow_tensor_metadata_change) const;
  241. // Note that value is not taken by reference: internally, the wrapper will
  242. // change the value tensor that it points to over time.
  243. Tensor value_;
  244. int64_t level_{};
  245. // These two counters are used for identifying
  246. // whether all the mutations on a given tensor are hidden from autograd or
  247. // not. If we have an input mutation that is hidden from autograd, then once
  248. // we convert the input mutation to a copy_() we know it will be safe to hide
  249. // the copy_() from autograd as well.
  250. bool has_metadata_mutation_ = false;
  251. bool is_multi_output_view_ = false;
  252. // Did the tensor experience a set_() call.
  253. bool was_storage_changed_ = false;
  254. uint64_t storage_changed_counter_ = 0;
  255. // Did the tensor experience any view operation with symbolic int.
  256. bool is_symbolic_ = false;
  257. size_t generation_ = 0;
  258. std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
  259. protected:
  260. static void copy_tensor_metadata(
  261. const FunctionalTensorWrapper* src_impl,
  262. FunctionalTensorWrapper* dest_impl,
  263. const c10::VariableVersion& version_counter,
  264. bool allow_tensor_metadata_change);
  265. };
  266. // Utility functions for the functionalization pass.
  267. namespace functionalization {
  268. namespace impl {
  269. inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
  270. const Tensor& tensor) {
  271. auto functional_impl =
  272. static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  273. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  274. return functional_impl;
  275. }
  276. TORCH_API bool isBaseTensor(const at::Tensor& tensor);
  277. TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
  278. TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
  279. TORCH_API bool isFunctionalTensor(
  280. const c10::List<std::optional<Tensor>>& t_list);
  281. TORCH_API bool isFunctionalTensor(ITensorListRef list);
  282. TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
  283. TORCH_API std::optional<Tensor> to_functional_tensor(
  284. const std::optional<Tensor>& tensor);
  285. TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
  286. const c10::List<std::optional<Tensor>>& t_list);
  287. TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
  288. TORCH_API void freeze_functional_tensor(const Tensor& tensor);
  289. TORCH_API Tensor
  290. from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
  291. TORCH_API std::optional<Tensor> from_functional_tensor(
  292. const std::optional<Tensor>& t,
  293. bool assert_functional = true);
  294. TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
  295. const c10::List<std::optional<Tensor>>& t_list);
  296. TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
  297. TORCH_API void sync(const at::Tensor& t);
  298. TORCH_API void sync(const std::optional<Tensor>& t);
  299. TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
  300. TORCH_API void sync(ITensorListRef t_list);
  301. TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
  302. TORCH_API void replace_(
  303. const ITensorListRef functional_tensor,
  304. ITensorListRef other);
  305. TORCH_API void commit_update(const Tensor& functional_tensor);
  306. TORCH_API void commit_update(ITensorListRef functional_tensor);
  307. TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
  308. TORCH_API void mark_mutation_hidden_from_autograd(
  309. const Tensor& functional_tensor);
  310. TORCH_API bool are_all_mutations_hidden_from_autograd(
  311. const Tensor& functional_tensor);
  312. TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
  313. const Tensor& functional_tensor);
  314. // These two methods are XLA-specific logic and are no-ops
  315. // for the normal functionalization flow.
  316. TORCH_API void propagate_xla_data(
  317. const Tensor& functional_tensor,
  318. const Tensor& other);
  319. TORCH_API void propagate_xla_data(
  320. const ITensorListRef functional_tensor,
  321. ITensorListRef other);
  322. TORCH_API void propagate_xla_data_direct(
  323. const Tensor& tensor,
  324. const Tensor& other);
  325. TORCH_API void propagate_xla_data_direct(
  326. const ITensorListRef tensor,
  327. ITensorListRef other);
  328. Tensor create_functional_tensor_with_view_meta(
  329. const Tensor& view_to_wrap,
  330. const Tensor& base,
  331. const std::shared_ptr<functionalization::ViewMeta>& meta,
  332. int64_t out_idx = 0);
  333. std::vector<Tensor> create_functional_tensor_with_view_meta(
  334. ITensorListRef view_to_wrap,
  335. const Tensor& base,
  336. const std::shared_ptr<functionalization::ViewMeta>& meta);
  337. void mutate_view_meta(
  338. const Tensor& self,
  339. const std::shared_ptr<functionalization::ViewMeta>& meta);
  340. TORCH_API Tensor apply_view_meta_sequence(
  341. const Tensor& base,
  342. const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
  343. void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
  344. void set_sizes_strides_offset(
  345. const std::vector<Tensor>& outs,
  346. const std::vector<Tensor>& meta_outs);
  347. // ~~~~~ TLS used in functionalization ~~~~~
  348. TORCH_API bool getFunctionalizationReapplyViewsTLS();
  349. TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
  350. class TORCH_API FunctionalizationReapplyViewsGuard {
  351. public:
  352. FunctionalizationReapplyViewsGuard(bool reapply_views)
  353. : prev_(getFunctionalizationReapplyViewsTLS()) {
  354. setFunctionalizationReapplyViewsTLS(reapply_views);
  355. }
  356. ~FunctionalizationReapplyViewsGuard() {
  357. setFunctionalizationReapplyViewsTLS(prev_);
  358. }
  359. FunctionalizationReapplyViewsGuard(
  360. const FunctionalizationReapplyViewsGuard&) = delete;
  361. FunctionalizationReapplyViewsGuard operator=(
  362. const FunctionalizationReapplyViewsGuard&) = delete;
  363. FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
  364. delete;
  365. FunctionalizationReapplyViewsGuard operator=(
  366. FunctionalizationReapplyViewsGuard&&) = delete;
  367. private:
  368. bool prev_;
  369. };
  370. } // namespace impl
  371. // Helper function to call an out-of-place composite aten kernel that may use
  372. // mutations / views internally, and functionalize them.
  373. TORCH_API void functionalize_op_helper(
  374. const c10::OperatorHandle& op,
  375. torch::jit::Stack* stack);
  376. template <class Op, bool symint, class ReturnType, class... ParameterTypes>
  377. struct _functionalize_aten_op final {};
  378. template <class Op, bool symint, class ReturnType, class... ParameterTypes>
  379. struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  380. static ReturnType call(
  381. typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
  382. using FuncType = ReturnType(
  383. typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
  384. auto op = c10::Dispatcher::singleton()
  385. .findSchemaOrThrow(
  386. (const char*)Op::name, (const char*)Op::overload_name)
  387. .typed<FuncType>();
  388. return c10::impl::BoxedKernelWrapper<FuncType>::call(
  389. c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
  390. op,
  391. // BoxedKernelWrapper knows to ignore this keyset argument,
  392. // because functionalize_op_helper doesn't take in a DispatchKeySet
  393. c10::DispatchKeySet(),
  394. args...);
  395. }
  396. };
  397. template <class Op>
  398. using functionalize_aten_op =
  399. _functionalize_aten_op<Op, false, typename Op::schema>;
  400. template <class Op>
  401. using functionalize_aten_op_symint =
  402. _functionalize_aten_op<Op, true, typename Op::schema>;
  403. } // namespace functionalization
  404. } // namespace at