gen_schema_utils.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from typing import Any, Optional, Union
  2. from torchgen.model import (
  3. Annotation,
  4. Argument,
  5. Arguments,
  6. BaseOperatorName,
  7. BaseTy,
  8. BaseType,
  9. CustomClassType,
  10. FunctionSchema,
  11. ListType,
  12. OperatorName,
  13. Return,
  14. )
# Note: These are not used by torchgen itself; they are utilities for generating a schema
# from real (example) arguments. For example, they are used to generate HigherOrderOperators'
# schemas, since the schema can vary between different instances of the same HOP.
  18. class TypeGen:
  19. convert_to_base_ty = {
  20. int: BaseTy.int,
  21. float: BaseTy.float,
  22. str: BaseTy.str,
  23. bool: BaseTy.bool,
  24. }
  25. @staticmethod
  26. def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
  27. import torch
  28. if isinstance(obj, torch.fx.GraphModule):
  29. return BaseType(BaseTy.GraphModule)
  30. elif isinstance(obj, torch.Tensor):
  31. return BaseType(BaseTy.Tensor)
  32. elif isinstance(obj, torch.SymInt):
  33. return BaseType(BaseTy.SymInt)
  34. elif isinstance(obj, torch.SymBool):
  35. return BaseType(BaseTy.SymBool)
  36. elif isinstance(obj, torch.ScriptObject):
  37. return CustomClassType(obj._type().name()) # type: ignore[attr-defined]
  38. elif isinstance(obj, (list, tuple)):
  39. assert len(obj) > 0
  40. all_base_tys = [TypeGen.from_example(x) for x in obj]
  41. if len(set(all_base_tys)) > 1:
  42. raise RuntimeError(
  43. f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
  44. "Consider unpacking the argument and give proper names to them if possible "
  45. "instead of using *args."
  46. )
  47. return ListType(all_base_tys[0], len(obj))
  48. tp = type(obj)
  49. if tp not in TypeGen.convert_to_base_ty:
  50. raise RuntimeError(f"unsupported type {tp}")
  51. return BaseType(TypeGen.convert_to_base_ty[tp])
  52. class ReturnGen:
  53. @staticmethod
  54. def from_example(
  55. name: Optional[str], obj: Any, annotation: Optional[Annotation]
  56. ) -> Return:
  57. return Return(name, TypeGen.from_example(obj), annotation)
  58. class ArgumentGen:
  59. @staticmethod
  60. def from_example(
  61. name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
  62. ) -> Argument:
  63. return Argument(
  64. name, TypeGen.from_example(obj), default=default, annotation=annotation
  65. )
  66. class FunctionSchemaGen:
  67. @staticmethod
  68. def from_example(
  69. op_name: str,
  70. example_inputs: tuple[tuple[str, Any], ...],
  71. example_outputs: tuple[Any, ...],
  72. ) -> FunctionSchema:
  73. args = []
  74. for name, inp in example_inputs:
  75. args.append(ArgumentGen.from_example(name, inp, None, None))
  76. # ignore the annotations and other attributes for now, we could add more when needed.
  77. arguments = Arguments(
  78. tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
  79. )
  80. returns = tuple(
  81. ReturnGen.from_example(None, out, None) for out in example_outputs
  82. )
  83. op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
  84. return FunctionSchema(op_name, arguments, returns)