// FunctionalizeFallbackKernel.h (1.8 KB)

  1. #pragma once
  2. #include <ATen/FunctionalStorageImpl.h>
  3. namespace at::functionalization {
  4. // `ViewMeta` implementation for `resize_` operation.
  5. struct TORCH_API resize__ViewMeta : public ViewMeta {
  6. FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta)
  7. FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
  8. bool /* reapply_views */,
  9. const std::vector<int64_t>&);
  10. resize__ViewMeta(const SerializableTuple& tpl)
  11. : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
  12. resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
  13. : ViewMeta(/*has_symbolic_inputs=*/false),
  14. reapply_views(reapply_views),
  15. size(size) {}
  16. Tensor forward(const Tensor& base) override;
  17. Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
  18. SerializableTuple to_serializable_tuple() {
  19. return std::make_tuple(reapply_views, size);
  20. }
  21. bool reapply_views;
  22. std::vector<int64_t> size;
  23. };
  24. // `ViewMeta` implementation for `_unsafe_view` operation.
  25. struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
  26. FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta)
  27. FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
  28. bool /* has_symbolic_inputs */,
  29. const std::vector<c10::SymInt>&);
  30. _unsafe_view_ViewMeta(const SerializableTuple& tpl)
  31. : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
  32. _unsafe_view_ViewMeta(
  33. bool has_symbolic_inputs,
  34. const std::vector<c10::SymInt>& size)
  35. : ViewMeta(has_symbolic_inputs), size(size) {}
  36. Tensor forward(const Tensor& base) override;
  37. Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
  38. SerializableTuple to_serializable_tuple() {
  39. return std::make_tuple(has_symbolic_inputs, size);
  40. }
  41. std::vector<c10::SymInt> size;
  42. };
  43. } // namespace at::functionalization