| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- from typing import Any, Optional, Union
- from torchgen.model import (
- Annotation,
- Argument,
- Arguments,
- BaseOperatorName,
- BaseTy,
- BaseType,
- CustomClassType,
- FunctionSchema,
- ListType,
- OperatorName,
- Return,
- )
- # Note: These aren't actually used in torchgen, they're some utilities for generating a schema
- # from real arguments. For example, this is used to generate HigherOrderOperators' schema since
- # their schemas can vary for different instances of the same HOP.
class TypeGen:
    """Infer a torchgen schema type from an example runtime value."""

    # Plain Python scalars are looked up by their *exact* type (``type(obj)``),
    # not with ``isinstance``: ``bool`` is a subclass of ``int``, so an
    # ``isinstance`` check would map ``True`` to ``BaseTy.int``.
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        """Return the schema type that describes ``obj``.

        Args:
            obj: An example value. Supported values are ``torch.fx.GraphModule``,
                ``torch.Tensor``, ``torch.SymInt``, ``torch.SymBool``,
                ``torch.ScriptObject``, ``int``, ``float``, ``str``, ``bool``,
                and non-empty lists/tuples whose elements all map to the same
                schema type.

        Returns:
            A ``BaseType`` for tensors, graph modules, symbolic values and
            Python scalars; a ``CustomClassType`` for script objects; a
            ``ListType`` sized to ``len(obj)`` for lists and tuples.

        Raises:
            AssertionError: If ``obj`` is an empty list or tuple (there is no
                element to infer the element type from).
            RuntimeError: If a list/tuple mixes element types, or if ``obj`` is
                of a type that has no schema equivalent here.
        """
        # Imported inside the function, as in the original code; presumably so
        # that importing this module does not itself require torch.
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        if isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        if isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        if isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        if isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        if isinstance(obj, (list, tuple)):
            # Raise explicitly instead of using an ``assert`` statement:
            # asserts are stripped under ``python -O``, which would let an
            # empty sequence fall through to an IndexError below.
            if len(obj) == 0:
                raise AssertionError(
                    "Cannot generate schema for an empty list or tuple: "
                    "the element type cannot be inferred without an example element."
                )
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and give proper names to them if possible "
                    "instead of using *args."
                )
            # All elements share one type; the list is sized to the example.
            return ListType(all_base_tys[0], len(obj))

        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])
class ReturnGen:
    """Build a schema ``Return`` entry from an example output value."""

    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        """Create a ``Return`` whose type is inferred from ``obj``.

        Args:
            name: Name of the return value, or ``None``.
            obj: Example output; its schema type is derived with
                ``TypeGen.from_example``.
            annotation: Annotation forwarded unchanged to ``Return``, or ``None``.

        Returns:
            The constructed ``Return``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Return(name, inferred_type, annotation)
class ArgumentGen:
    """Build a schema ``Argument`` entry from an example input value."""

    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        """Create an ``Argument`` whose type is inferred from ``obj``.

        Args:
            name: Name of the argument.
            obj: Example input; its schema type is derived with
                ``TypeGen.from_example``.
            default: Default value forwarded unchanged to ``Argument``, or ``None``.
            annotation: Annotation forwarded unchanged to ``Argument``, or ``None``.

        Returns:
            The constructed ``Argument``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Argument(name, inferred_type, default=default, annotation=annotation)
class FunctionSchemaGen:
    """Assemble a ``FunctionSchema`` from example inputs and outputs."""

    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: tuple[tuple[str, Any], ...],
        example_outputs: tuple[Any, ...],
    ) -> FunctionSchema:
        """Generate a schema for ``op_name`` from concrete example values.

        Args:
            op_name: Base name of the operator.
            example_inputs: ``(name, value)`` pairs, one per argument, in order.
            example_outputs: Example return values, in order.

        Returns:
            A ``FunctionSchema`` whose arguments and returns carry types
            inferred from the examples. Defaults and annotations are not set,
            and the returns are unnamed.
        """
        # One Argument per (name, value) pair, with no default and no annotation.
        generated_args = tuple(
            ArgumentGen.from_example(arg_name, example, None, None)
            for arg_name, example in example_inputs
        )
        # ignore the annotations and other attributes for now, we could add more when needed.
        # Only the third positional field receives the generated arguments;
        # every other field is left empty or None.
        arguments = Arguments((), None, generated_args, (), None, (), ())
        returns = tuple(
            ReturnGen.from_example(None, output, None) for output in example_outputs
        )
        # Bound to a new local so the ``op_name`` str parameter is not rebound
        # to a different type.
        schema_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
        return FunctionSchema(schema_name, arguments, returns)
|